{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeOperators #-}
module CHAD.AST.UnMonoid (unMonoid, zero, plus, acPrjCompose) where

import CHAD.AST
import CHAD.AST.Sparse.Types
import CHAD.Data


-- | Remove 'EZero', 'EDeepZero', 'EPlus' and 'EOneHot' from the program by
-- expanding them into their concrete implementations. Also ensure that
-- 'EAccum' has a dense sparsity.
--
-- Every subexpression is traversed, including the zero-information argument
-- of 'EZero' and 'EDeepZero'. No monoid operation survives anywhere in the
-- result.
unMonoid :: Ex env t -> Ex env t
unMonoid = \case
  -- The zero-information argument is an ordinary expression that may itself
  -- contain monoid operations, so it is desugared as well before expansion.
  EZero _ t e -> zero t (unMonoid e)
  EDeepZero _ t e -> deepZero t (unMonoid e)
  EPlus _ t a b -> plus t (unMonoid a) (unMonoid b)
  EOneHot _ t p a b -> onehot t p (unMonoid a) (unMonoid b)

  -- All remaining constructors are rebuilt structurally.
  EVar _ t i -> EVar ext t i
  ELet _ rhs body -> ELet ext (unMonoid rhs) (unMonoid body)
  EPair _ a b -> EPair ext (unMonoid a) (unMonoid b)
  EFst _ e -> EFst ext (unMonoid e)
  ESnd _ e -> ESnd ext (unMonoid e)
  ENil _ -> ENil ext
  EInl _ t e -> EInl ext t (unMonoid e)
  EInr _ t e -> EInr ext t (unMonoid e)
  ECase _ e a b -> ECase ext (unMonoid e) (unMonoid a) (unMonoid b)
  ENothing _ t -> ENothing ext t
  EJust _ e -> EJust ext (unMonoid e)
  EMaybe _ a b e -> EMaybe ext (unMonoid a) (unMonoid b) (unMonoid e)
  ELNil _ t1 t2 -> ELNil ext t1 t2
  ELInl _ t e -> ELInl ext t (unMonoid e)
  ELInr _ t e -> ELInr ext t (unMonoid e)
  ELCase _ e a b c -> ELCase ext (unMonoid e) (unMonoid a) (unMonoid b) (unMonoid c)
  EConstArr _ n t x -> EConstArr ext n t x
  EBuild _ n a b -> EBuild ext n (unMonoid a) (unMonoid b)
  EMap _ a b -> EMap ext (unMonoid a) (unMonoid b)
  EFold1Inner _ cm a b c -> EFold1Inner ext cm (unMonoid a) (unMonoid b) (unMonoid c)
  ESum1Inner _ e -> ESum1Inner ext (unMonoid e)
  EUnit _ e -> EUnit ext (unMonoid e)
  EReplicate1Inner _ a b -> EReplicate1Inner ext (unMonoid a) (unMonoid b)
  EMaximum1Inner _ e -> EMaximum1Inner ext (unMonoid e)
  EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e)
  EReshape _ n a b -> EReshape ext n (unMonoid a) (unMonoid b)
  EZip _ a b -> EZip ext (unMonoid a) (unMonoid b)
  EFold1InnerD1 _ cm a b c -> EFold1InnerD1 ext cm (unMonoid a) (unMonoid b) (unMonoid c)
  EFold1InnerD2 _ cm a b c -> EFold1InnerD2 ext cm (unMonoid a) (unMonoid b) (unMonoid c)
  EConst _ t x -> EConst ext t x
  EIdx0 _ e -> EIdx0 ext (unMonoid e)
  EIdx1 _ a b -> EIdx1 ext (unMonoid a) (unMonoid b)
  EIdx _ a b -> EIdx ext (unMonoid a) (unMonoid b)
  EShape _ e -> EShape ext (unMonoid e)
  EOp _ op e -> EOp ext op (unMonoid e)
  ECustom _ t1 t2 t3 a b c e1 e2 ->
    ECustom ext t1 t2 t3 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2)
  ERecompute _ e -> ERecompute ext (unMonoid e)
  EWith _ t a b -> EWith ext t (unMonoid a) (unMonoid b)

  -- Bind the accumulator (IS IZ) and the index (IZ) first. Then walk the
  -- possibly-sparse value with 'accumulateSparse', emitting a dense 'EAccum'
  -- for each component that is present. Each one uses the outer projection
  -- composed with the projection to that component.
  EAccum _ t p eidx sp eval eacc ->
    elet (unMonoid eacc) $
    elet (weakenExpr WSink (unMonoid eidx)) $
    accumulateSparse (acPrjTy p t) sp (weakenExpr (WSink .> WSink) (unMonoid eval)) $ \w prj2 idx2 val2 ->
      acPrjCompose SAID p (evar (w @> IZ)) prj2 idx2 $ \prj' idx' ->
        EAccum ext t prj' (unMonoid idx') (spDense (acPrjTy prj' t)) (unMonoid val2) (evar (w @> IS IZ))

  EError _ t s -> EError ext t s


-- | Build the zero value of monoid type @t@ from its zero information.
--
-- The nil and pair cases let-bind the zero-information expression, so it is
-- kept in the output. The linear-sum, maybe and scalar cases ignore it. The
-- array case maps over it element-wise.
zero :: SMTy t -> Ex env (ZeroInfo t) -> Ex env t
-- don't destroy the effects!
zero SMTNil e = ELet ext e $ ENil ext
zero (SMTPair t1 t2) e =
  ELet ext e $
    EPair ext (zero t1 (EFst ext (EVar ext (typeOf e) IZ)))
              (zero t2 (ESnd ext (EVar ext (typeOf e) IZ)))
zero (SMTLEither t1 t2) _ = ELNil ext (fromSMTy t1) (fromSMTy t2)
zero (SMTMaybe t) _ = ENothing ext (fromSMTy t)
zero (SMTArr _ t) e = emap (zero t (EVar ext (tZeroInfo t) IZ)) e
zero (SMTScal t) _ = case t of
  STI32 -> EConst ext STI32 0
  STI64 -> EConst ext STI64 0
  STF32 -> EConst ext STF32 0.0
  STF64 -> EConst ext STF64 0.0


-- | Like 'zero', but driven by deep zero information.
--
-- Linear sums and maybes are cased on, so their shape is reproduced
-- (nil/left/right, nothing/just). Zeros are placed at the leaves.
deepZero :: SMTy t -> Ex env (DeepZeroInfo t) -> Ex env t
deepZero SMTNil e = elet e $ ENil ext
deepZero (SMTPair t1 t2) e =
  ELet ext e $
    EPair ext (deepZero t1 (EFst ext (EVar ext (typeOf e) IZ)))
              (deepZero t2 (ESnd ext (EVar ext (typeOf e) IZ)))
deepZero (SMTLEither t1 t2) e =
  elcase e
    (ELNil ext (fromSMTy t1) (fromSMTy t2))
    (ELInl ext (fromSMTy t2) (deepZero t1 (evar IZ)))
    (ELInr ext (fromSMTy t1) (deepZero t2 (evar IZ)))
deepZero (SMTMaybe t) e =
  emaybe e
    (ENothing ext (fromSMTy t))
    (EJust ext (deepZero t (evar IZ)))
deepZero (SMTArr _ t) e = emap (deepZero t (evar IZ)) e
deepZero (SMTScal t) _ = case t of
  STI32 -> EConst ext STI32 0
  STI64 -> EConst ext STI64 0
  STF32 -> EConst ext STF32 0.0
  STF64 -> EConst ext STF64 0.0


-- | Monoid addition of two values of type @t@.
--
-- Neither operand is dropped from the generated expression.
--
-- For linear sums, nil is the identity. Adding a left injection to a right
-- injection produces a runtime 'EError'. For maybes, 'Nothing' is the
-- identity.
plus :: SMTy t -> Ex env t -> Ex env t -> Ex env t
-- don't destroy the effects!
plus SMTNil a b = use a $ use b $ ENil ext
plus (SMTPair t1 t2) a b =
  eunPair a $ \w1 a1 a2 ->
  eunPair (weakenExpr w1 b) $ \w2 b1 b2 ->
    EPair ext (plus t1 (weakenExpr w2 a1) b1)
              (plus t2 (weakenExpr w2 a2) b2)
plus (SMTLEither t1 t2) a b =
  elet b $
    elcase (weakenExpr WSink a)
      -- a is nil: the result is b
      (evar IZ)
      -- a is (inl x): inspect b
      (elcase (evar (IS IZ))
        (ELInl ext (fromSMTy t2) (evar IZ))
        (ELInl ext (fromSMTy t2) (plus t1 (evar (IS IZ)) (evar IZ)))
        (EError ext (fromSMTy (SMTLEither t1 t2)) "splus ll+lr"))
      -- a is (inr x): inspect b
      (elcase (evar (IS IZ))
        (ELInr ext (fromSMTy t1) (evar IZ))
        (EError ext (fromSMTy (SMTLEither t1 t2)) "splus lr+ll")
        (ELInr ext (fromSMTy t1) (plus t2 (evar (IS IZ)) (evar IZ))))
plus (SMTMaybe t) a b =
  elet b $
    emaybe (weakenExpr WSink a)
      -- a is Nothing: the result is b
      (evar IZ)
      -- a is (Just x): inspect b
      (emaybe (evar (IS IZ))
        (EJust ext (evar IZ))
        (EJust ext (plus t (evar (IS IZ)) (evar IZ))))
plus (SMTArr _ t) a b =
  ezipWith (plus t (EVar ext (fromSMTy t) (IS IZ)) (EVar ext (fromSMTy t) IZ)) a b
plus (SMTScal t) a b = EOp ext (OAdd t) (EPair ext a b)


-- | @onehot t prj idx arg@ builds a value of type @t@ that holds @arg@ at the
-- position selected by the projection @prj@ and the runtime index @idx@, and
-- a zero (built with 'zero' from the zero information in @idx@) elsewhere.
-- Linear sums and maybes on the path are simply injected.
onehot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env t
onehot typ topprj idx arg = case (typ, topprj) of
  (_, SAPHere) ->
    ELet ext arg $
      EVar ext (fromSMTy typ) IZ

  (SMTPair t1 t2, SAPFst prj) ->
    ELet ext idx $
      let tidx = typeOf idx in
      ELet ext (onehot t1 prj (EFst ext (EVar ext tidx IZ)) (weakenExpr WSink arg)) $
        let toh = fromSMTy t1 in
        EPair ext (EVar ext toh IZ) (zero t2 (ESnd ext (EVar ext tidx (IS IZ))))

  (SMTPair t1 t2, SAPSnd prj) ->
    ELet ext idx $
      let tidx = typeOf idx in
      ELet ext (onehot t2 prj (ESnd ext (EVar ext tidx IZ)) (weakenExpr WSink arg)) $
        let toh = fromSMTy t2 in
        EPair ext (zero t1 (EFst ext (EVar ext tidx (IS IZ)))) (EVar ext toh IZ)

  (SMTLEither t1 t2, SAPLeft prj) ->
    ELInl ext (fromSMTy t2) (onehot t1 prj idx arg)
  (SMTLEither t1 t2, SAPRight prj) ->
    ELInr ext (fromSMTy t1) (onehot t2 prj idx arg)

  (SMTMaybe t1, SAPJust prj) ->
    EJust ext (onehot t1 prj idx arg)

  -- Build an array shaped like the zero-info array found in the index. The
  -- element whose position equals the target index recurses into the
  -- element-level one-hot. Every other element is the zero built from its
  -- own zero information.
  (SMTArr n t1, SAPArrIdx prj) ->
    let tidx = tTup (sreplicate n tIx) in
    ELet ext idx $
      EBuild ext n (EShape ext (ESnd ext (EFst ext (EVar ext (typeOf idx) IZ)))) $
        eif (eidxEq n (EVar ext tidx IZ) (EFst ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))))
          (onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg))
          (ELet ext (EIdx ext (ESnd ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))) (EVar ext tidx IZ)) $
             zero t1 (EVar ext (tZeroInfo t1) IZ))


-- | Traverse a value of sparse type @t'@ (a sparse representation of @t@, as
-- described by the 'Sparse' witness). Call the continuation for each
-- component that is actually present, passing the projection and dense index
-- that locate that component in @t@.
--
-- 'SpAbsent' parts produce nothing. Empty sparse wrappers (nothing, linear
-- nil) produce nothing either. Array elements are visited inside an
-- 'EBuild'.
--
-- The continuation's results (all of type 'TNil') are sequenced, and the
-- overall result is 'TNil'.
accumulateSparse
  :: SMTy t -> Sparse t t' -> Ex env t'
  -> (forall p b env'. env :> env' -> SAcPrj p t b -> Ex env' (AcIdxD p t) -> Ex env' b -> Ex env' TNil)
  -> Ex env TNil
accumulateSparse topty topsp arg accum = case (topty, topsp) of
  (_, s) | Just Refl <- isDense topty s ->
    accum WId SAPHere (ENil ext) arg
  -- Should already be caught by the 'isDense' check above; kept as a fallback.
  (SMTScal _, SpScal) ->
    accum WId SAPHere (ENil ext) arg
  (_, SpSparse s) ->
    emaybe arg
      (ENil ext)
      (accumulateSparse topty s (evar IZ) (\w -> accum (WPop w)))
  (_, SpAbsent) ->
    ENil ext
  (SMTPair t1 t2, SpPair s1 s2) ->
    eunPair arg $ \w1 e1 e2 ->
      elet (accumulateSparse t1 s1 e1 (\w prj -> accum (w .> w1) (SAPFst prj))) $
        accumulateSparse t2 s2 (weakenExpr WSink e2) (\w prj -> accum (w .> WSink .> w1) (SAPSnd prj))
  (SMTLEither t1 t2, SpLEither s1 s2) ->
    elcase arg
      (ENil ext)
      (accumulateSparse t1 s1 (evar IZ) (\w prj -> accum (WPop w) (SAPLeft prj)))
      (accumulateSparse t2 s2 (evar IZ) (\w prj -> accum (WPop w) (SAPRight prj)))
  (SMTMaybe t, SpMaybe s) ->
    emaybe arg
      (ENil ext)
      (accumulateSparse t s (evar IZ) (\w prj -> accum (WPop w) (SAPJust prj)))
  (SMTArr n t, SpArr s) ->
    let tn = tTup (sreplicate n tIx) in
    elet arg $
      elet (EBuild ext n (EShape ext (evar IZ)) $
              accumulateSparse t s (EIdx ext (evar (IS IZ)) (EVar ext tn IZ))
                (\w prj idx val -> accum (WPop (WPop w)) (SAPArrIdx prj) (EPair ext (EVar ext tn (w @> IZ)) idx) val)) $
        ENil ext


-- | Compose two accumulation projections: first @p1@ (from @a@ to @b@), then
-- @p2@ (from @b@ to @c@). The continuation receives the composite projection
-- from @a@ to @c@ together with its combined index expression.
--
-- The 'SAIDense' witness selects the index representation. 'unMonoid' uses
-- 'SAID' with the dense indices produced by 'accumulateSparse'. When an
-- index must be split ('SAIS' pairs, array indices), @idx1@ is let-bound
-- once around the combined index.
acPrjCompose
  :: SAIDense dense
  -> SAcPrj p1 a b -> Ex env (AcIdx dense p1 a)
  -> SAcPrj p2 b c -> Ex env (AcIdx dense p2 b)
  -> (forall p'. SAcPrj p' a c -> Ex env (AcIdx dense p' a) -> r)
  -> r
acPrjCompose _ SAPHere _ p2 idx2 k = k p2 idx2
acPrjCompose SAID (SAPFst p1) idx1 p2 idx2 k =
  acPrjCompose SAID p1 idx1 p2 idx2 $ \p' idx' -> k (SAPFst p') idx'
acPrjCompose SAID (SAPSnd p1) idx1 p2 idx2 k =
  acPrjCompose SAID p1 idx1 p2 idx2 $ \p' idx' -> k (SAPSnd p') idx'
acPrjCompose SAIS (SAPFst p1) idx1 p2 idx2 k
  | Dict <- styKnown (typeOf idx1) =
      acPrjCompose SAIS p1 (efst (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' ->
        k (SAPFst p') (elet idx1 $ EPair ext idx' (esnd (evar IZ)))
acPrjCompose SAIS (SAPSnd p1) idx1 p2 idx2 k
  | Dict <- styKnown (typeOf idx1) =
      acPrjCompose SAIS p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' ->
        k (SAPSnd p') (elet idx1 $ EPair ext (efst (evar IZ)) idx')
acPrjCompose d (SAPLeft p1) idx1 p2 idx2 k =
  acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' -> k (SAPLeft p') idx'
acPrjCompose d (SAPRight p1) idx1 p2 idx2 k =
  acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' -> k (SAPRight p') idx'
acPrjCompose d (SAPJust p1) idx1 p2 idx2 k =
  acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' -> k (SAPJust p') idx'
acPrjCompose SAID (SAPArrIdx p1) idx1 p2 idx2 k
  | Dict <- styKnown (typeOf idx1) =
      acPrjCompose SAID p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' ->
        k (SAPArrIdx p') (elet idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx')
acPrjCompose SAIS (SAPArrIdx p1) idx1 p2 idx2 k
  | Dict <- styKnown (typeOf idx1) =
      acPrjCompose SAIS p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' ->
        k (SAPArrIdx p') (elet idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx')