{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module CHAD.AST.UnMonoid (unMonoid, zero, plus, acPrjCompose) where

import CHAD.AST
import CHAD.AST.Sparse.Types
import CHAD.Data


-- | Type-level translation performed by 'unMonoid'.
--
-- A 'TIdxPair' becomes an ordinary 'TPair' of an index tuple and the
-- translated payload; scalars are left alone; every other type former is
-- translated structurally.
type family UnMonoid t where
  UnMonoid (TIdxPair n t) = TPair (Tup (Replicate n TIx)) (UnMonoid t)
  UnMonoid TNil = TNil
  UnMonoid (TPair a b) = TPair (UnMonoid a) (UnMonoid b)
  UnMonoid (TEither a b) = TEither (UnMonoid a) (UnMonoid b)
  UnMonoid (TLEither a b) = TLEither (UnMonoid a) (UnMonoid b)
  UnMonoid (TMaybe a) = TMaybe (UnMonoid a)
  UnMonoid (TArr n a) = TArr n (UnMonoid a)
  UnMonoid (TScal t) = TScal t
  UnMonoid (TAccum t) = TAccum (UnMonoid t)

-- | 'UnMonoid' applied pointwise to a typing environment.
type family UnMonoidE env where
  UnMonoidE '[] = '[]
  UnMonoidE (t : ts) = UnMonoid t : UnMonoidE ts

-- | Singleton-level counterpart of 'UnMonoid' for general types.
tUnMonoid :: STy t -> STy (UnMonoid t)
tUnMonoid (STIdxPair n t) = STPair (tTup (sreplicate n tIx)) (tUnMonoid t)
tUnMonoid STNil = STNil
tUnMonoid (STPair a b) = STPair (tUnMonoid a) (tUnMonoid b)
tUnMonoid (STEither a b) = STEither (tUnMonoid a) (tUnMonoid b)
tUnMonoid (STLEither a b) = STLEither (tUnMonoid a) (tUnMonoid b)
tUnMonoid (STMaybe a) = STMaybe (tUnMonoid a)
tUnMonoid (STArr n t) = STArr n (tUnMonoid t)
tUnMonoid (STAccum t) = STAccum (mtUnMonoid t)
tUnMonoid (STScal t) = STScal t

-- | Singleton-level counterpart of 'UnMonoid' for monoid types.
--
-- The result never contains 'SMTIdxPair': an index pair is turned into an
-- 'SMTPair' whose first component is a tuple of @n@ 'STI64' scalars.
mtUnMonoid :: SMTy t -> SMTy (UnMonoid t)
mtUnMonoid (SMTIdxPair n t) =
  SMTPair (mkTup SMTNil SMTPair (sreplicate n (SMTScal STI64))) (mtUnMonoid t)
mtUnMonoid SMTNil = SMTNil
mtUnMonoid (SMTPair a b) = SMTPair (mtUnMonoid a) (mtUnMonoid b)
mtUnMonoid (SMTLEither a b) = SMTLEither (mtUnMonoid a) (mtUnMonoid b)
mtUnMonoid (SMTMaybe a) = SMTMaybe (mtUnMonoid a)
mtUnMonoid (SMTArr n t) = SMTArr n (mtUnMonoid t)
mtUnMonoid (SMTScal t) = SMTScal t

-- | Proof, by induction on @n@, that an @n@-dimensional index tuple is
-- unchanged by 'UnMonoid'.
unMonoidIndexId :: SNat n -> UnMonoid (Tup (Replicate n TIx)) :~: Tup (Replicate n TIx)
unMonoidIndexId SZ = Refl
unMonoidIndexId (SS n) | Refl <- unMonoidIndexId n = Refl

-- | Proof, by induction on the monoid type, that 'ZeroInfo' types are
-- unchanged by 'UnMonoid'.
unMonoidZeroInfoId :: SMTy t -> UnMonoid (ZeroInfo t) :~: ZeroInfo t
unMonoidZeroInfoId SMTNil = Refl
unMonoidZeroInfoId (SMTPair a b)
  | Refl <- unMonoidZeroInfoId a, Refl <- unMonoidZeroInfoId b = Refl
unMonoidZeroInfoId (SMTLEither a b)
  | Refl <- unMonoidZeroInfoId a, Refl <- unMonoidZeroInfoId b = Refl
unMonoidZeroInfoId (SMTMaybe a)
  | Refl <- unMonoidZeroInfoId a = Refl
unMonoidZeroInfoId (SMTArr _ a)
  | Refl <- unMonoidZeroInfoId a = Refl
unMonoidZeroInfoId (SMTScal _) = Refl
unMonoidZeroInfoId (SMTIdxPair n a)
  | Refl <- unMonoidIndexId n, Refl <- unMonoidZeroInfoId a = Refl

-- | Proof, by induction on the monoid type, that 'DeepZeroInfo' types are
-- unchanged by 'UnMonoid'.
unMonoidDeepZeroInfoId :: SMTy t -> UnMonoid (DeepZeroInfo t) :~: DeepZeroInfo t
unMonoidDeepZeroInfoId SMTNil = Refl
unMonoidDeepZeroInfoId (SMTPair a b)
  | Refl <- unMonoidDeepZeroInfoId a, Refl <- unMonoidDeepZeroInfoId b = Refl
unMonoidDeepZeroInfoId (SMTLEither a b)
  | Refl <- unMonoidDeepZeroInfoId a, Refl <- unMonoidDeepZeroInfoId b = Refl
unMonoidDeepZeroInfoId (SMTMaybe a)
  | Refl <- unMonoidDeepZeroInfoId a = Refl
unMonoidDeepZeroInfoId (SMTArr _ a)
  | Refl <- unMonoidDeepZeroInfoId a = Refl
unMonoidDeepZeroInfoId (SMTScal _) = Refl
unMonoidDeepZeroInfoId (SMTIdxPair n a)
  | Refl <- unMonoidIndexId n, Refl <- unMonoidDeepZeroInfoId a = Refl

-- | Removes monoidal stuff from the program. In particular:
--
-- * Removes 'EZero', 'EDeepZero', 'EPlus', 'EOneHot' from the program by
--   expanding them into their concrete implementations.
-- * Removes 'EIdxPair' and 'EUnIdxPair', as well as all occurrences of the
--   'TIdxPair' type in general.
-- * Ensures that 'EAccum' has a dense sparsity.
--
-- All remaining constructors are rebuilt structurally; wherever a constructor
-- carries a type witness for a type that may contain 'TIdxPair', the witness
-- is translated with 'tUnMonoid' / 'mtUnMonoid' so that it describes the
-- translated type.
unMonoid :: Ex env t -> Ex (UnMonoidE env) (UnMonoid t)
unMonoid = \case
  -- Monoid operations: expanded into ordinary expressions.
  EZero _ t e | Refl <- unMonoidZeroInfoId t -> zero t (unMonoid e)
  EDeepZero _ t e | Refl <- unMonoidDeepZeroInfoId t -> deepZero t (unMonoid e)
  EPlus _ t a b -> plus (mtUnMonoid t) (unMonoid a) (unMonoid b)
  EOneHot _ t p a b -> onehot t p (unMonoid a) (unMonoid b)

  -- Index pairs: represented as plain pairs after translation, so
  -- construction is 'EPair' and unwrapping is the identity.
  EIdxPair _ n a b | Refl <- unMonoidIndexId n -> EPair ext (unMonoid a) (unMonoid b)
  EUnIdxPair _ e
    | STIdxPair n _ <- typeOf e, Refl <- unMonoidIndexId n -> unMonoid e

  -- Sparse accumulation: split into one dense 'EAccum' per reachable leaf.
  -- NOTE(review): 'eval' and 'eidx' are consumed here before being passed
  -- through 'unMonoid', and 'EAccum' receives @t@ rather than @mtUnMonoid t@;
  -- confirm this case is consistent with the 'UnMonoidE' / 'UnMonoid' result
  -- type.
  EAccum _ t p eidx sp eval eacc ->
    accumulateSparse (acPrjTy p t) sp eval $ \w prj2 idx2 val2 ->
    acPrjCompose SAID p (weakenExpr w eidx) prj2 idx2 $ \prj' idx' ->
      EAccum ext t prj' (unMonoid idx') (spDense (acPrjTy prj' t)) (unMonoid val2) (weakenExpr w (unMonoid eacc))

  -- Everything below is a structural traversal.
  EVar _ t i -> EVar ext (tUnMonoid t) (go i)
    where
      -- The environment keeps its shape, so a de Bruijn index keeps its
      -- position; only its type changes.
      go :: Idx env t -> Idx (UnMonoidE env) (UnMonoid t)
      go IZ = IZ
      go (IS j) = IS (go j)
  ELet _ rhs body -> ELet ext (unMonoid rhs) (unMonoid body)
  EPair _ a b -> EPair ext (unMonoid a) (unMonoid b)
  EFst _ e -> EFst ext (unMonoid e)
  ESnd _ e -> ESnd ext (unMonoid e)
  ENil _ -> ENil ext
  EInl _ t e -> EInl ext (tUnMonoid t) (unMonoid e)
  EInr _ t e -> EInr ext (tUnMonoid t) (unMonoid e)
  ECase _ e a b -> ECase ext (unMonoid e) (unMonoid a) (unMonoid b)
  ENothing _ t -> ENothing ext (tUnMonoid t)
  EJust _ e -> EJust ext (unMonoid e)
  EMaybe _ a b e -> EMaybe ext (unMonoid a) (unMonoid b) (unMonoid e)
  ELNil _ t1 t2 -> ELNil ext (tUnMonoid t1) (tUnMonoid t2)
  ELInl _ t e -> ELInl ext (tUnMonoid t) (unMonoid e)
  ELInr _ t e -> ELInr ext (tUnMonoid t) (unMonoid e)
  ELCase _ e a b c -> ELCase ext (unMonoid e) (unMonoid a) (unMonoid b) (unMonoid c)
  EConstArr _ n t x -> EConstArr ext n t x
  EBuild _ n a b | Refl <- unMonoidIndexId n -> EBuild ext n (unMonoid a) (unMonoid b)
  EMap _ a b -> EMap ext (unMonoid a) (unMonoid b)
  EFold1Inner _ cm a b c -> EFold1Inner ext cm (unMonoid a) (unMonoid b) (unMonoid c)
  ESum1Inner _ e -> ESum1Inner ext (unMonoid e)
  EUnit _ e -> EUnit ext (unMonoid e)
  EReplicate1Inner _ a b -> EReplicate1Inner ext (unMonoid a) (unMonoid b)
  EMaximum1Inner _ e -> EMaximum1Inner ext (unMonoid e)
  EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e)
  EReshape _ n a b | Refl <- unMonoidIndexId n -> EReshape ext n (unMonoid a) (unMonoid b)
  EZip _ a b -> EZip ext (unMonoid a) (unMonoid b)
  EFold1InnerD1 _ cm a b c -> EFold1InnerD1 ext cm (unMonoid a) (unMonoid b) (unMonoid c)
  EFold1InnerD2 _ cm a b c -> EFold1InnerD2 ext cm (unMonoid a) (unMonoid b) (unMonoid c)
  EConst _ t x -> EConst ext t x
  EIdx0 _ e -> EIdx0 ext (unMonoid e)
  EIdx1 _ a b -> EIdx1 ext (unMonoid a) (unMonoid b)
  EIdx _ a b
    | STArr n _ <- typeOf a, Refl <- unMonoidIndexId n -> EIdx ext (unMonoid a) (unMonoid b)
  EShape _ e
    | STArr n _ <- typeOf e, Refl <- unMonoidIndexId n -> EShape ext (unMonoid e)
  EOp _ op e -> EOp ext op (unMonoid e)
  ECustom _ t1 t2 t3 a b c e1 e2 ->
    ECustom ext (tUnMonoid t1) (tUnMonoid t2) (tUnMonoid t3)
            (unMonoid a) (unMonoid b) (unMonoid c)
            (unMonoid e1) (unMonoid e2)
  ERecompute _ e -> ERecompute ext (unMonoid e)
  -- The accumulator witness must describe the translated type, because both
  -- sub-expressions are translated.
  EWith _ t a b -> EWith ext (mtUnMonoid t) (unMonoid a) (unMonoid b)
  -- Likewise the result-type witness of an error node.
  EError _ t s -> EError ext (tUnMonoid t) s

-- | Expands a zero of monoid type @t@ into a concrete expression of the
-- translated type, given the expression computing its 'ZeroInfo'.
--
-- The nil, pair, array and index-pair cases keep the info expression in the
-- result; the 'SMTLEither', 'SMTMaybe' and scalar cases ignore it. For an
-- index pair the index component of the info is reused as-is and only the
-- payload is zeroed.
zero :: SMTy t -> Ex env (ZeroInfo t) -> Ex env (UnMonoid t)
-- don't destroy the effects!
zero SMTNil e = use e $ ENil ext
zero (SMTPair t1 t2) e =
  ELet ext e $ EPair ext (zero t1 (EFst ext (EVar ext (typeOf e) IZ)))
                         (zero t2 (ESnd ext (EVar ext (typeOf e) IZ)))
zero (SMTLEither t1 t2) _ = ELNil ext (tUnMonoid (fromSMTy t1)) (tUnMonoid (fromSMTy t2))
zero (SMTMaybe t) _ = ENothing ext (tUnMonoid (fromSMTy t))
zero (SMTArr _ t) e = emap (zero t (EVar ext (tZeroInfo t) IZ)) e
zero (SMTScal t) _ = case t of
  STI32 -> EConst ext STI32 0
  STI64 -> EConst ext STI64 0
  STF32 -> EConst ext STF32 0.0
  STF64 -> EConst ext STF64 0.0
zero (SMTIdxPair _ t) e = eunPair e $ \_ e1 e2 -> EPair ext e1 (zero t e2)

-- | Expands a deep zero of monoid type @t@ into a concrete expression of the
-- translated type, given the expression computing its 'DeepZeroInfo'.
--
-- Unlike 'zero', the 'SMTLEither' and 'SMTMaybe' cases inspect the info
-- expression and rebuild the same constructor around a recursively zeroed
-- payload.
deepZero :: SMTy t -> Ex env (DeepZeroInfo t) -> Ex env (UnMonoid t)
deepZero SMTNil e = elet e $ ENil ext
deepZero (SMTPair t1 t2) e =
  ELet ext e $ EPair ext (deepZero t1 (EFst ext (EVar ext (typeOf e) IZ)))
                         (deepZero t2 (ESnd ext (EVar ext (typeOf e) IZ)))
deepZero (SMTLEither t1 t2) e =
  elcase e
    (ELNil ext (tUnMonoid (fromSMTy t1)) (tUnMonoid (fromSMTy t2)))
    (ELInl ext (tUnMonoid (fromSMTy t2)) (deepZero t1 (evar IZ)))
    (ELInr ext (tUnMonoid (fromSMTy t1)) (deepZero t2 (evar IZ)))
deepZero (SMTMaybe t) e =
  emaybe e
    (ENothing ext (tUnMonoid (fromSMTy t)))
    (EJust ext (deepZero t (evar IZ)))
deepZero (SMTArr _ t) e = emap (deepZero t (evar IZ)) e
deepZero (SMTScal t) _ = case t of
  STI32 -> EConst ext STI32 0
  STI64 -> EConst ext STI64 0
  STF32 -> EConst ext STF32 0.0
  STF64 -> EConst ext STF64 0.0
deepZero (SMTIdxPair _ t) e = eunPair e $ \_ e1 e2 -> EPair ext e1 (deepZero t e2)

-- | Expands the monoid addition of two expressions of monoid type @t@.
--
-- Both operands are kept in the result in every case. For 'SMTLEither',
-- adding a left to a right value (or vice versa) produces an 'EError' node.
--
-- There is no addition rule for 'SMTIdxPair'. 'unMonoid' only calls this
-- function with a witness produced by 'mtUnMonoid', which never contains
-- 'SMTIdxPair', so that equation is reachable only by external callers.
plus :: SMTy t -> Ex env t -> Ex env t -> Ex env t
-- don't destroy the effects!
plus SMTNil a b = use a $ use b $ ENil ext
plus (SMTPair t1 t2) a b =
  eunPair a $ \w1 a1 a2 ->
  eunPair (weakenExpr w1 b) $ \w2 b1 b2 ->
    EPair ext (plus t1 (weakenExpr w2 a1) b1)
              (plus t2 (weakenExpr w2 a2) b2)
plus (SMTLEither t1 t2) a b =
  -- @b@ is let-bound first; inside the branches of the outer case on @a@ it
  -- is therefore one binder further out than the payload of @a@.
  elet b $
    elcase (weakenExpr WSink a)
      (evar IZ)
      (elcase (evar (IS IZ))
        (ELInl ext (fromSMTy t2) (evar IZ))
        (ELInl ext (fromSMTy t2) (plus t1 (evar (IS IZ)) (evar IZ)))
        (EError ext (fromSMTy (SMTLEither t1 t2)) "splus ll+lr"))
      (elcase (evar (IS IZ))
        (ELInr ext (fromSMTy t1) (evar IZ))
        (EError ext (fromSMTy (SMTLEither t1 t2)) "splus lr+ll")
        (ELInr ext (fromSMTy t1) (plus t2 (evar (IS IZ)) (evar IZ))))
plus (SMTMaybe t) a b =
  elet b $
    emaybe (weakenExpr WSink a)
      (evar IZ)
      (emaybe (evar (IS IZ))
        (EJust ext (evar IZ))
        (EJust ext (plus t (evar (IS IZ)) (evar IZ))))
plus (SMTArr _ t) a b =
  ezipWith (plus t (EVar ext (fromSMTy t) (IS IZ)) (EVar ext (fromSMTy t) IZ))
           a b
plus (SMTScal t) a b = EOp ext (OAdd t) (EPair ext a b)
plus (SMTIdxPair _ _) _ _ =
  error "CHAD.AST.UnMonoid.plus: no addition rule for SMTIdxPair; expected a monoid type produced by mtUnMonoid"

-- | Expands a one-hot value: a value of (translated) monoid type @t@ that
-- holds @arg@ at the position selected by the projection and its index
-- expression, and zeros (built with 'zero') in the sibling positions of pairs
-- and arrays.
onehot :: SMTy t -> SAcPrj p t a -> Ex env (UnMonoid (AcIdxS p t)) -> Ex env (UnMonoid a) -> Ex env (UnMonoid t)
onehot typ topprj idx arg = case (typ, topprj) of
  (_, SAPHere) ->
    arg

  (SMTPair t1 t2, SAPFst prj)
    | Refl <- unMonoidZeroInfoId t2 ->
      elet idx $
      elet (onehot t1 prj (EFst ext (evar IZ)) (weakenExpr WSink arg)) $
        EPair ext (evar IZ) (zero t2 (ESnd ext (evar (IS IZ))))

  (SMTPair t1 t2, SAPSnd prj)
    | Refl <- unMonoidZeroInfoId t1 ->
      elet idx $
      elet (onehot t2 prj (ESnd ext (evar IZ)) (weakenExpr WSink arg)) $
        EPair ext (zero t1 (EFst ext (evar (IS IZ)))) (evar IZ)

  (SMTLEither t1 t2, SAPLeft prj) ->
    ELInl ext (tUnMonoid (fromSMTy t2)) (onehot t1 prj idx arg)
  (SMTLEither t1 t2, SAPRight prj) ->
    ELInr ext (tUnMonoid (fromSMTy t1)) (onehot t2 prj idx arg)

  (SMTMaybe t1, SAPJust prj) ->
    EJust ext (onehot t1 prj idx arg)

  (SMTArr n t1, SAPArrIdx prj)
    | Refl <- unMonoidIndexId n, Refl <- unMonoidZeroInfoId t1 ->
      -- Build an array of the shape found in the index info; the element at
      -- the target index gets the nested one-hot, all others get a zero built
      -- from the corresponding element of the info array.
      elet idx $
      EBuild ext n (EShape ext (ESnd ext (EFst ext (evar IZ)))) $
        eif (eidxEq n (evar IZ) (EFst ext (EFst ext (evar (IS IZ)))))
          (onehot t1 prj (ESnd ext (evar (IS IZ))) (weakenExpr (WSink .> WSink) arg))
          (elet (EIdx ext (ESnd ext (EFst ext (evar (IS IZ)))) (evar IZ)) $
            zero t1 (evar IZ))

-- | Walks a sparse value @arg@ along its sparsity description and calls
-- @accum@ once for every place where a dense sub-value is reached, passing the
-- weakening into the current scope, the projection leading to that sub-value,
-- the index expression for that projection and the sub-value itself.
--
-- 'SpAbsent' parts produce no call at all, and the nil branches of
-- 'SpSparse', 'SpLEither' and 'SpMaybe' produce 'ENil'.
accumulateSparse
  :: SMTy t -> Sparse t t' -> Ex env t'
  -> (forall p b env'. env :> env' -> SAcPrj p t b -> Ex env' (AcIdxD p t) -> Ex env' b -> Ex env' TNil)
  -> Ex env TNil
accumulateSparse topty topsp arg accum = case (topty, topsp) of
  (_, s) | Just Refl <- isDense topty s ->
    accum WId SAPHere (ENil ext) arg
  (SMTScal _, SpScal) ->
    accum WId SAPHere (ENil ext) arg  -- should be handled by isDense already, but meh
  (_, SpSparse s) ->
    emaybe arg
      (ENil ext)
      (accumulateSparse topty s (evar IZ) (\w -> accum (WPop w)))
  (_, SpAbsent) ->
    ENil ext
  (SMTPair t1 t2, SpPair s1 s2) ->
    eunPair arg $ \w1 e1 e2 ->
    elet (accumulateSparse t1 s1 e1 (\w prj -> accum (w .> w1) (SAPFst prj))) $
      accumulateSparse t2 s2 (weakenExpr WSink e2) (\w prj -> accum (w .> WSink .> w1) (SAPSnd prj))
  (SMTLEither t1 t2, SpLEither s1 s2) ->
    elcase arg
      (ENil ext)
      (accumulateSparse t1 s1 (evar IZ) (\w prj -> accum (WPop w) (SAPLeft prj)))
      (accumulateSparse t2 s2 (evar IZ) (\w prj -> accum (WPop w) (SAPRight prj)))
  (SMTMaybe t, SpMaybe s) ->
    emaybe arg
      (ENil ext)
      (accumulateSparse t s (evar IZ) (\w prj -> accum (WPop w) (SAPJust prj)))
  (SMTArr n t, SpArr s) ->
    -- Accumulate every element from inside an 'EBuild' over the array's own
    -- shape; the array built this way is bound and then dropped.
    let tn = tTup (sreplicate n tIx)
    in elet arg $
       elet (EBuild ext n (EShape ext (evar IZ)) $
               accumulateSparse t s
                 (EIdx ext (evar (IS IZ)) (EVar ext tn IZ))
                 (\w prj idx val -> accum (WPop (WPop w)) (SAPArrIdx prj) (EPair ext (EVar ext tn (w @> IZ)) idx) val)) $
         ENil ext
  (SMTArr _ t, SpArrIdx s) ->
    eunPair arg $ \w1 e1 e2 ->
    elet (accumulateSparse t s e2 (\w prj idx val -> accum (w .> w1) (SAPArrIdx prj) (EPair ext (weakenExpr w e1) idx) val)) $
      ENil ext

-- | Composes two accumulator projections, @p1@ from @a@ to @b@ and @p2@ from
-- @b@ to @c@, together with their index expressions, and passes the composite
-- projection from @a@ to @c@ and its index expression to the continuation.
--
-- The first argument selects the index representation ('SAID' or 'SAIS'); it
-- is matched explicitly wherever the shape of the index depends on it.
acPrjCompose
  :: SAIDense dense
  -> SAcPrj p1 a b -> Ex env (AcIdx dense p1 a)
  -> SAcPrj p2 b c -> Ex env (AcIdx dense p2 b)
  -> (forall p'. SAcPrj p' a c -> Ex env (AcIdx dense p' a) -> r) -> r
acPrjCompose _ SAPHere _ p2 idx2 k = k p2 idx2
acPrjCompose SAID (SAPFst p1) idx1 p2 idx2 k =
  acPrjCompose SAID p1 idx1 p2 idx2 $ \p' idx' ->
    k (SAPFst p') idx'
acPrjCompose SAID (SAPSnd p1) idx1 p2 idx2 k =
  acPrjCompose SAID p1 idx1 p2 idx2 $ \p' idx' ->
    k (SAPSnd p') idx'
acPrjCompose SAIS (SAPFst p1) idx1 p2 idx2 k
  | Dict <- styKnown (typeOf idx1) =
    acPrjCompose SAIS p1 (efst (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' ->
      k (SAPFst p') (elet idx1 $ EPair ext idx' (esnd (evar IZ)))
acPrjCompose SAIS (SAPSnd p1) idx1 p2 idx2 k
  | Dict <- styKnown (typeOf idx1) =
    acPrjCompose SAIS p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' ->
      k (SAPSnd p') (elet idx1 $ EPair ext (efst (evar IZ)) idx')
acPrjCompose d (SAPLeft p1) idx1 p2 idx2 k =
  acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' ->
    k (SAPLeft p') idx'
acPrjCompose d (SAPRight p1) idx1 p2 idx2 k =
  acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' ->
    k (SAPRight p') idx'
acPrjCompose d (SAPJust p1) idx1 p2 idx2 k =
  acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' ->
    k (SAPJust p') idx'
acPrjCompose SAID (SAPArrIdx p1) idx1 p2 idx2 k
  | Dict <- styKnown (typeOf idx1) =
    acPrjCompose SAID p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' ->
      k (SAPArrIdx p') (elet idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx')
acPrjCompose SAIS (SAPArrIdx p1) idx1 p2 idx2 k
  | Dict <- styKnown (typeOf idx1) =
    acPrjCompose SAIS p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' ->
      k (SAPArrIdx p') (elet idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx')