{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeOperators #-}
module AST.UnMonoid (unMonoid, zero, plus, acPrjCompose) where

import AST
import AST.Sparse.Types
import Data


-- | Remove 'EZero', 'EDeepZero', 'EPlus' and 'EOneHot' from the program by
-- expanding them into their concrete implementations. Also ensure that
-- 'EAccum' has a dense sparsity.
--
-- Every subexpression is processed, including the zero-info argument of
-- 'EZero' / 'EDeepZero' and the value argument of 'EAccum', so that no
-- monoid construct survives anywhere in the output.
unMonoid :: Ex env t -> Ex env t
unMonoid = \case
  -- The zero-info expressions may themselves contain monoid constructs, so
  -- they are un-monoided before being expanded.
  EZero _ t e -> zero t (unMonoid e)
  EDeepZero _ t e -> deepZero t (unMonoid e)
  EPlus _ t a b -> plus t (unMonoid a) (unMonoid b)
  EOneHot _ t p a b -> onehot t p (unMonoid a) (unMonoid b)

  EVar _ t i -> EVar ext t i
  ELet _ rhs body -> ELet ext (unMonoid rhs) (unMonoid body)
  EPair _ a b -> EPair ext (unMonoid a) (unMonoid b)
  EFst _ e -> EFst ext (unMonoid e)
  ESnd _ e -> ESnd ext (unMonoid e)
  ENil _ -> ENil ext
  EInl _ t e -> EInl ext t (unMonoid e)
  EInr _ t e -> EInr ext t (unMonoid e)
  ECase _ e a b -> ECase ext (unMonoid e) (unMonoid a) (unMonoid b)
  ENothing _ t -> ENothing ext t
  EJust _ e -> EJust ext (unMonoid e)
  EMaybe _ a b e -> EMaybe ext (unMonoid a) (unMonoid b) (unMonoid e)
  ELNil _ t1 t2 -> ELNil ext t1 t2
  ELInl _ t e -> ELInl ext t (unMonoid e)
  ELInr _ t e -> ELInr ext t (unMonoid e)
  ELCase _ e a b c -> ELCase ext (unMonoid e) (unMonoid a) (unMonoid b) (unMonoid c)
  EConstArr _ n t x -> EConstArr ext n t x
  EBuild _ n a b -> EBuild ext n (unMonoid a) (unMonoid b)
  EFold1Inner _ cm a b c -> EFold1Inner ext cm (unMonoid a) (unMonoid b) (unMonoid c)
  ESum1Inner _ e -> ESum1Inner ext (unMonoid e)
  EUnit _ e -> EUnit ext (unMonoid e)
  EReplicate1Inner _ a b -> EReplicate1Inner ext (unMonoid a) (unMonoid b)
  EMaximum1Inner _ e -> EMaximum1Inner ext (unMonoid e)
  EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e)
  EConst _ t x -> EConst ext t x
  EIdx0 _ e -> EIdx0 ext (unMonoid e)
  EIdx1 _ a b -> EIdx1 ext (unMonoid a) (unMonoid b)
  EIdx _ a b -> EIdx ext (unMonoid a) (unMonoid b)
  EShape _ e -> EShape ext (unMonoid e)
  EOp _ op e -> EOp ext op (unMonoid e)
  ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2)
  ERecompute _ e -> ERecompute ext (unMonoid e)
  EWith _ t a b -> EWith ext t (unMonoid a) (unMonoid b)
  EAccum _ t p eidx sp eval eacc ->
    -- The accumulator is un-monoided once and shared by every leaf
    -- accumulation that 'accumulateSparse' generates.
    let eacc' = unMonoid eacc
    -- 'eval' must be un-monoided *before* 'accumulateSparse': in the sparse
    -- cases it is embedded in a case/let scrutinee and 'val2' is only a
    -- variable projection of it. 'accumulateSparse' itself introduces no
    -- monoid constructs, so 'val2' needs no further processing.
    in accumulateSparse (acPrjTy p t) sp (unMonoid eval) $ \w prj2 idx2 val2 ->
         acPrjCompose SAID p (weakenExpr w eidx) prj2 idx2 $ \prj' idx' ->
           EAccum ext t prj' (unMonoid idx') (spDense (acPrjTy prj' t)) val2 (weakenExpr w eacc')
  EError _ t s -> EError ext t s

-- | Expand a zero of monoid type @t@, given an expression computing its zero
-- information.
--
-- Pairs are zeroed componentwise, 'LEither' zeros to 'ELNil' and 'Maybe' to
-- 'ENothing' (both ignoring the zero info), arrays map 'zero' over the
-- zero-info array, and scalars become a literal @0@.
zero :: SMTy t -> Ex env (ZeroInfo t) -> Ex env t
zero SMTNil _ = ENil ext
zero (SMTPair t1 t2) e =
  -- Bind the zero info once so both components can project from it.
  ELet ext e $ EPair ext (zero t1 (EFst ext (EVar ext (typeOf e) IZ)))
                         (zero t2 (ESnd ext (EVar ext (typeOf e) IZ)))
zero (SMTLEither t1 t2) _ = ELNil ext (fromSMTy t1) (fromSMTy t2)
zero (SMTMaybe t) _ = ENothing ext (fromSMTy t)
zero (SMTArr _ t) e = emap (zero t (EVar ext (tZeroInfo t) IZ)) e
zero (SMTScal t) _ = case t of
  STI32 -> EConst ext STI32 0
  STI64 -> EConst ext STI64 0
  STF32 -> EConst ext STF32 0.0
  STF64 -> EConst ext STF64 0.0

-- | Like 'zero', but driven by a 'DeepZeroInfo' expression.
--
-- Unlike 'zero', the 'LEither' and 'Maybe' cases case-analyse the info
-- expression and rebuild the same constructor with a deep zero inside,
-- rather than collapsing to the empty constructor.
deepZero :: SMTy t -> Ex env (DeepZeroInfo t) -> Ex env t
deepZero SMTNil _ = ENil ext
deepZero (SMTPair t1 t2) e =
  ELet ext e $ EPair ext (deepZero t1 (EFst ext (EVar ext (typeOf e) IZ)))
                         (deepZero t2 (ESnd ext (EVar ext (typeOf e) IZ)))
deepZero (SMTLEither t1 t2) e =
  elcase e
    (ELNil ext (fromSMTy t1) (fromSMTy t2))
    (ELInl ext (fromSMTy t2) (deepZero t1 (evar IZ)))
    (ELInr ext (fromSMTy t1) (deepZero t2 (evar IZ)))
deepZero (SMTMaybe t) e =
  emaybe e
    (ENothing ext (fromSMTy t))
    (EJust ext (deepZero t (evar IZ)))
deepZero (SMTArr _ t) e = emap (deepZero t (evar IZ)) e
deepZero (SMTScal t) _ = case t of
  STI32 -> EConst ext STI32 0
  STI64 -> EConst ext STI64 0
  STF32 -> EConst ext STF32 0.0
  STF64 -> EConst ext STF64 0.0

-- | Expand monoid addition of two values of monoid type @t@.
--
-- Pairs add componentwise. 'LEither' and 'Maybe' treat the empty
-- constructor as the identity; adding an 'ELInl' to an 'ELInr' (in either
-- order) produces a runtime 'EError'. Arrays are added elementwise with
-- 'ezipWith'; scalars use 'OAdd'.
plus :: SMTy t -> Ex env t -> Ex env t -> Ex env t
plus SMTNil _ _ = ENil ext
plus (SMTPair t1 t2) a b =
  let t = STPair (fromSMTy t1) (fromSMTy t2)
  -- Environment below: b at IZ, a at IS IZ.
  in ELet ext a $
     ELet ext (weakenExpr WSink b) $
       EPair ext (plus t1 (EFst ext (EVar ext t (IS IZ))) (EFst ext (EVar ext t IZ)))
                 (plus t2 (ESnd ext (EVar ext t (IS IZ))) (ESnd ext (EVar ext t IZ)))
plus (SMTLEither t1 t2) a b =
  let t = STLEither (fromSMTy t1) (fromSMTy t2)
  -- Environment below: b at IZ, a at IS IZ; each ELCase branch that binds a
  -- payload shifts these by one.
  in ELet ext a $
     ELet ext (weakenExpr WSink b) $
       ELCase ext (EVar ext t (IS IZ))
         -- a is LNil: result is b
         (EVar ext t IZ)
         -- a is LInl: inspect b (now at IS IZ)
         (ELCase ext (EVar ext t (IS IZ))
           (EVar ext t (IS (IS IZ)))
           (ELInl ext (fromSMTy t2) (plus t1 (EVar ext (fromSMTy t1) (IS IZ)) (EVar ext (fromSMTy t1) IZ)))
           (EError ext t "plus l+r"))
         -- a is LInr: inspect b (now at IS IZ)
         (ELCase ext (EVar ext t (IS IZ))
           (EVar ext t (IS (IS IZ)))
           (EError ext t "plus r+l")
           (ELInr ext (fromSMTy t1) (plus t2 (EVar ext (fromSMTy t2) (IS IZ)) (EVar ext (fromSMTy t2) IZ))))
plus (SMTMaybe t) a b =
  -- b is bound at IZ; the outer EMaybe scrutinises a.
  ELet ext b $
    EMaybe ext
      -- a is Nothing: result is b
      (EVar ext (STMaybe (fromSMTy t)) IZ)
      -- a is Just x (x at IZ, b at IS IZ): if b is Nothing keep x, else add
      (EJust ext
        (EMaybe ext
          (EVar ext (fromSMTy t) IZ)
          (plus t (EVar ext (fromSMTy t) (IS IZ)) (EVar ext (fromSMTy t) IZ))
          (EVar ext (STMaybe (fromSMTy t)) (IS IZ))))
      (weakenExpr WSink a)
plus (SMTArr _ t) a b =
  ezipWith (plus t (EVar ext (fromSMTy t) (IS IZ)) (EVar ext (fromSMTy t) IZ)) a b
plus (SMTScal t) a b = EOp ext (OAdd t) (EPair ext a b)

-- | Expand a one-hot value: a value of monoid type @t@ that is zero
-- everywhere except at the position selected by the projection @topprj@,
-- where it holds @arg@.
--
-- The index expression supplies both the position (for array projections)
-- and the zero information used to build the zero parts around it.
onehot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env t
onehot typ topprj idx arg = case (typ, topprj) of
  (_, SAPHere) ->
    ELet ext arg $
      EVar ext (fromSMTy typ) IZ

  (SMTPair t1 t2, SAPFst prj) ->
    ELet ext idx $
      let tidx = typeOf idx in
      ELet ext (onehot t1 prj (EFst ext (EVar ext (typeOf idx) IZ)) (weakenExpr WSink arg)) $
        let toh = fromSMTy t1 in
        -- The second component is zeroed from the snd of the index.
        EPair ext (EVar ext toh IZ) (zero t2 (ESnd ext (EVar ext tidx (IS IZ))))

  (SMTPair t1 t2, SAPSnd prj) ->
    ELet ext idx $
      let tidx = typeOf idx in
      ELet ext (onehot t2 prj (ESnd ext (EVar ext (typeOf idx) IZ)) (weakenExpr WSink arg)) $
        let toh = fromSMTy t2 in
        -- The first component is zeroed from the fst of the index.
        EPair ext (zero t1 (EFst ext (EVar ext tidx (IS IZ)))) (EVar ext toh IZ)

  (SMTLEither t1 t2, SAPLeft prj) ->
    ELInl ext (fromSMTy t2) (onehot t1 prj idx arg)
  (SMTLEither t1 t2, SAPRight prj) ->
    ELInr ext (fromSMTy t1) (onehot t2 prj idx arg)

  (SMTMaybe t1, SAPJust prj) ->
    EJust ext (onehot t1 prj idx arg)

  (SMTArr n t1, SAPArrIdx prj) ->
    let tidx = tTup (sreplicate n tIx)
    -- The built array takes the shape of the array in snd (fst idx). At the
    -- position equal to fst (fst idx) it recurses into the element; every
    -- other element is a 'zero' built from the matching element of that
    -- array.
    in ELet ext idx $
         EBuild ext n (EShape ext (ESnd ext (EFst ext (EVar ext (typeOf idx) IZ)))) $
           eif (eidxEq n (EVar ext tidx IZ) (EFst ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))))
             (onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg))
             (ELet ext (EIdx ext (ESnd ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))) (EVar ext tidx IZ)) $
                zero t1 (EVar ext (tZeroInfo t1) IZ))

-- | Walk the sparsity structure of an accumulation value and invoke the
-- continuation once for each densely-represented leaf.
--
-- The continuation receives a weakening into the current environment, the
-- projection from @t@ to the leaf, the dense index of that leaf, and the
-- leaf value. Absent parts ('SpAbsent') and empty sparse/either/maybe
-- values produce 'ENil' and no accumulation.
--
-- Only structural expressions are introduced here (no monoid constructs),
-- which 'unMonoid' relies on.
accumulateSparse
  :: SMTy t -> Sparse t t' -> Ex env t'
  -> (forall p b env'. env :> env' -> SAcPrj p t b -> Ex env' (AcIdxD p t) -> Ex env' b -> Ex env' TNil)
  -> Ex env TNil
accumulateSparse topty topsp arg accum = case (topty, topsp) of
  (_, s) | Just Refl <- isDense topty s ->
    accum WId SAPHere (ENil ext) arg
  (SMTScal _, SpScal) ->
    -- Normally already caught by the isDense guard above; kept as a fallback.
    accum WId SAPHere (ENil ext) arg
  (_, SpSparse s) ->
    emaybe arg
      (ENil ext)
      (accumulateSparse topty s (evar IZ) (\w -> accum (WPop w)))
  (_, SpAbsent) ->
    ENil ext
  (SMTPair t1 t2, SpPair s1 s2) ->
    eunPair arg $ \w1 e1 e2 ->
      elet (accumulateSparse t1 s1 e1 (\w prj -> accum (w .> w1) (SAPFst prj))) $
        accumulateSparse t2 s2 (weakenExpr WSink e2) (\w prj -> accum (w .> WSink .> w1) (SAPSnd prj))
  (SMTLEither t1 t2, SpLEither s1 s2) ->
    elcase arg
      (ENil ext)
      (accumulateSparse t1 s1 (evar IZ) (\w prj -> accum (WPop w) (SAPLeft prj)))
      (accumulateSparse t2 s2 (evar IZ) (\w prj -> accum (WPop w) (SAPRight prj)))
  (SMTMaybe t, SpMaybe s) ->
    emaybe arg
      (ENil ext)
      (accumulateSparse t s (evar IZ) (\w prj -> accum (WPop w) (SAPJust prj)))
  (SMTArr n t, SpArr s) ->
    let tn = tTup (sreplicate n tIx) in
    -- Bind the array, then build (and discard) an array of TNil whose body
    -- accumulates each element; the element's index is prepended to the
    -- dense index passed to the continuation.
    elet arg $
      elet (EBuild ext n (EShape ext (evar IZ)) $
              accumulateSparse t s (EIdx ext (evar (IS IZ)) (EVar ext tn IZ))
                (\w prj idx val -> accum (WPop (WPop w)) (SAPArrIdx prj) (EPair ext (EVar ext tn (w @> IZ)) idx) val)) $
        ENil ext

-- | Compose two accumulation projections (@a -> b@ then @b -> c@) together
-- with their index expressions, yielding a projection @a -> c@ and its
-- index, which are passed to the continuation.
--
-- For 'SAID' pair projections the index passes straight through. For
-- 'SAIS' pair projections, and for array projections in both modes, the
-- outer index is let-bound and the composed index is rebuilt around the
-- recursively composed inner index.
acPrjCompose
  :: SAIDense dense
  -> SAcPrj p1 a b -> Ex env (AcIdx dense p1 a)
  -> SAcPrj p2 b c -> Ex env (AcIdx dense p2 b)
  -> (forall p'. SAcPrj p' a c -> Ex env (AcIdx dense p' a) -> r) -> r
acPrjCompose _ SAPHere _ p2 idx2 k = k p2 idx2
acPrjCompose SAID (SAPFst p1) idx1 p2 idx2 k =
  acPrjCompose SAID p1 idx1 p2 idx2 $ \p' idx' -> k (SAPFst p') idx'
acPrjCompose SAID (SAPSnd p1) idx1 p2 idx2 k =
  acPrjCompose SAID p1 idx1 p2 idx2 $ \p' idx' -> k (SAPSnd p') idx'
acPrjCompose SAIS (SAPFst p1) idx1 p2 idx2 k
  | Dict <- styKnown (typeOf idx1) =
      acPrjCompose SAIS p1 (efst (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' ->
        k (SAPFst p') (elet idx1 $ EPair ext idx' (esnd (evar IZ)))
acPrjCompose SAIS (SAPSnd p1) idx1 p2 idx2 k
  | Dict <- styKnown (typeOf idx1) =
      acPrjCompose SAIS p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' ->
        k (SAPSnd p') (elet idx1 $ EPair ext (efst (evar IZ)) idx')
acPrjCompose d (SAPLeft p1) idx1 p2 idx2 k =
  acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' -> k (SAPLeft p') idx'
acPrjCompose d (SAPRight p1) idx1 p2 idx2 k =
  acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' -> k (SAPRight p') idx'
acPrjCompose d (SAPJust p1) idx1 p2 idx2 k =
  acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' -> k (SAPJust p') idx'
acPrjCompose SAID (SAPArrIdx p1) idx1 p2 idx2 k
  | Dict <- styKnown (typeOf idx1) =
      acPrjCompose SAID p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' ->
        k (SAPArrIdx p') (elet idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx')
acPrjCompose SAIS (SAPArrIdx p1) idx1 p2 idx2 k
  | Dict <- styKnown (typeOf idx1) =
      acPrjCompose SAIS p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' ->
        k (SAPArrIdx p') (elet idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx')