aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD/AST/UnMonoid.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD/AST/UnMonoid.hs')
-rw-r--r--src/CHAD/AST/UnMonoid.hs190
1 files changed, 129 insertions, 61 deletions
diff --git a/src/CHAD/AST/UnMonoid.hs b/src/CHAD/AST/UnMonoid.hs
index 06de00c..6578438 100644
--- a/src/CHAD/AST/UnMonoid.hs
+++ b/src/CHAD/AST/UnMonoid.hs
@@ -2,7 +2,9 @@
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
module CHAD.AST.UnMonoid (unMonoid, zero, plus, acPrjCompose) where
import CHAD.AST
@@ -10,40 +12,106 @@ import CHAD.AST.Sparse.Types
import CHAD.Data
--- | Remove 'EZero', 'EDeepZero', 'EPlus' and 'EOneHot' from the program by
--- expanding them into their concrete implementations. Also ensure that
--- 'EAccum' has a dense sparsity.
-unMonoid :: Ex env t -> Ex env t
+type family UnMonoid t where
+ UnMonoid (TIdxPair n t) = TPair (Tup (Replicate n TIx)) (UnMonoid t)
+
+ UnMonoid TNil = TNil
+ UnMonoid (TPair a b) = TPair (UnMonoid a) (UnMonoid b)
+ UnMonoid (TEither a b) = TEither (UnMonoid a) (UnMonoid b)
+ UnMonoid (TLEither a b) = TLEither (UnMonoid a) (UnMonoid b)
+ UnMonoid (TMaybe a) = TMaybe (UnMonoid a)
+ UnMonoid (TArr n a) = TArr n (UnMonoid a)
+ UnMonoid (TScal t) = TScal t
+ UnMonoid (TAccum t) = TAccum (UnMonoid t)
+
+type family UnMonoidE env where
+ UnMonoidE '[] = '[]
+ UnMonoidE (t : ts) = UnMonoid t : UnMonoidE ts
+
+tUnMonoid :: STy t -> STy (UnMonoid t)
+tUnMonoid (STIdxPair n t) = STPair (tTup (sreplicate n tIx)) (tUnMonoid t)
+tUnMonoid STNil = STNil
+tUnMonoid (STPair a b) = STPair (tUnMonoid a) (tUnMonoid b)
+tUnMonoid (STEither a b) = STEither (tUnMonoid a) (tUnMonoid b)
+tUnMonoid (STLEither a b) = STLEither (tUnMonoid a) (tUnMonoid b)
+tUnMonoid (STMaybe a) = STMaybe (tUnMonoid a)
+tUnMonoid (STArr n t) = STArr n (tUnMonoid t)
+tUnMonoid (STAccum t) = STAccum (mtUnMonoid t)
+tUnMonoid (STScal t) = STScal t
+
+mtUnMonoid :: SMTy t -> SMTy (UnMonoid t)
+mtUnMonoid (SMTIdxPair n t) = SMTPair (mkTup SMTNil SMTPair (sreplicate n (SMTScal STI64))) (mtUnMonoid t)
+mtUnMonoid SMTNil = SMTNil
+mtUnMonoid (SMTPair a b) = SMTPair (mtUnMonoid a) (mtUnMonoid b)
+mtUnMonoid (SMTLEither a b) = SMTLEither (mtUnMonoid a) (mtUnMonoid b)
+mtUnMonoid (SMTMaybe a) = SMTMaybe (mtUnMonoid a)
+mtUnMonoid (SMTArr n t) = SMTArr n (mtUnMonoid t)
+mtUnMonoid (SMTScal t) = SMTScal t
+
+unMonoidIndexId :: SNat n -> UnMonoid (Tup (Replicate n TIx)) :~: Tup (Replicate n TIx)
+unMonoidIndexId SZ = Refl
+unMonoidIndexId (SS n) | Refl <- unMonoidIndexId n = Refl
+
+unMonoidZeroInfoId :: SMTy t -> UnMonoid (ZeroInfo t) :~: ZeroInfo t
+unMonoidZeroInfoId SMTNil = Refl
+unMonoidZeroInfoId (SMTPair a b) | Refl <- unMonoidZeroInfoId a, Refl <- unMonoidZeroInfoId b = Refl
+unMonoidZeroInfoId (SMTLEither a b) | Refl <- unMonoidZeroInfoId a, Refl <- unMonoidZeroInfoId b = Refl
+unMonoidZeroInfoId (SMTMaybe a) | Refl <- unMonoidZeroInfoId a = Refl
+unMonoidZeroInfoId (SMTArr _ a) | Refl <- unMonoidZeroInfoId a = Refl
+unMonoidZeroInfoId (SMTScal _) = Refl
+unMonoidZeroInfoId (SMTIdxPair n a) | Refl <- unMonoidIndexId n, Refl <- unMonoidZeroInfoId a = Refl
+
+unMonoidDeepZeroInfoId :: SMTy t -> UnMonoid (DeepZeroInfo t) :~: DeepZeroInfo t
+unMonoidDeepZeroInfoId SMTNil = Refl
+unMonoidDeepZeroInfoId (SMTPair a b) | Refl <- unMonoidDeepZeroInfoId a, Refl <- unMonoidDeepZeroInfoId b = Refl
+unMonoidDeepZeroInfoId (SMTLEither a b) | Refl <- unMonoidDeepZeroInfoId a, Refl <- unMonoidDeepZeroInfoId b = Refl
+unMonoidDeepZeroInfoId (SMTMaybe a) | Refl <- unMonoidDeepZeroInfoId a = Refl
+unMonoidDeepZeroInfoId (SMTArr _ a) | Refl <- unMonoidDeepZeroInfoId a = Refl
+unMonoidDeepZeroInfoId (SMTScal _) = Refl
+unMonoidDeepZeroInfoId (SMTIdxPair n a) | Refl <- unMonoidIndexId n, Refl <- unMonoidDeepZeroInfoId a = Refl
+
+
+-- | Removes monoidal stuff from the program. In particular:
+--
+-- * Removes 'EZero', 'EDeepZero', 'EPlus', 'EOneHot' from the program by
+-- expanding them into their concrete implementations.
+-- * Removes 'EIdxPair' and 'EUnIdxPair', as well as all occurrences of the
+-- 'TIdxPair' type in general.
+-- * Ensures that 'EAccum' has a dense sparsity.
+unMonoid :: Ex env t -> Ex (UnMonoidE env) (UnMonoid t)
unMonoid = \case
- EZero _ t e -> zero t e
- EDeepZero _ t e -> deepZero t e
- EPlus _ t a b -> plus t (unMonoid a) (unMonoid b)
+ EZero _ t e | Refl <- unMonoidZeroInfoId t -> zero t (unMonoid e)
+ EDeepZero _ t e | Refl <- unMonoidDeepZeroInfoId t -> deepZero t (unMonoid e)
+ EPlus _ t a b -> plus (mtUnMonoid t) (unMonoid a) (unMonoid b)
EOneHot _ t p a b -> onehot t p (unMonoid a) (unMonoid b)
+ EIdxPair _ n a b | Refl <- unMonoidIndexId n -> EPair ext (unMonoid a) (unMonoid b)
+ EUnIdxPair _ e | STIdxPair n _ <- typeOf e, Refl <- unMonoidIndexId n -> unMonoid e
EAccum _ t p eidx sp eval eacc ->
- elet (unMonoid eacc) $
- elet (weakenExpr WSink (unMonoid eidx)) $
- accumulateSparse (acPrjTy p t) sp (weakenExpr (WSink .> WSink) (unMonoid eval)) $ \w prj2 idx2 val2 ->
- acPrjCompose SAID p (evar (w @> IZ)) prj2 idx2 $ \prj' idx' ->
- EAccum ext t prj' (unMonoid idx') (spDense (acPrjTy prj' t)) (unMonoid val2) (evar (w @> IS IZ))
+ accumulateSparse (acPrjTy p t) sp eval $ \w prj2 idx2 val2 ->
+ acPrjCompose SAID p (weakenExpr w eidx) prj2 idx2 $ \prj' idx' ->
+ EAccum ext t prj' (unMonoid idx') (spDense (acPrjTy prj' t)) (unMonoid val2) (weakenExpr w (unMonoid eacc))
- EVar _ t i -> EVar ext t i
+ EVar _ t i -> EVar ext (tUnMonoid t) (go i)
+ where go :: Idx env t -> Idx (UnMonoidE env) (UnMonoid t)
+ go IZ = IZ
+ go (IS j) = IS (go j)
ELet _ rhs body -> ELet ext (unMonoid rhs) (unMonoid body)
EPair _ a b -> EPair ext (unMonoid a) (unMonoid b)
EFst _ e -> EFst ext (unMonoid e)
ESnd _ e -> ESnd ext (unMonoid e)
ENil _ -> ENil ext
- EInl _ t e -> EInl ext t (unMonoid e)
- EInr _ t e -> EInr ext t (unMonoid e)
+ EInl _ t e -> EInl ext (tUnMonoid t) (unMonoid e)
+ EInr _ t e -> EInr ext (tUnMonoid t) (unMonoid e)
ECase _ e a b -> ECase ext (unMonoid e) (unMonoid a) (unMonoid b)
- ENothing _ t -> ENothing ext t
+ ENothing _ t -> ENothing ext (tUnMonoid t)
EJust _ e -> EJust ext (unMonoid e)
EMaybe _ a b e -> EMaybe ext (unMonoid a) (unMonoid b) (unMonoid e)
- ELNil _ t1 t2 -> ELNil ext t1 t2
- ELInl _ t e -> ELInl ext t (unMonoid e)
- ELInr _ t e -> ELInr ext t (unMonoid e)
+ ELNil _ t1 t2 -> ELNil ext (tUnMonoid t1) (tUnMonoid t2)
+ ELInl _ t e -> ELInl ext (tUnMonoid t) (unMonoid e)
+ ELInr _ t e -> ELInr ext (tUnMonoid t) (unMonoid e)
ELCase _ e a b c -> ELCase ext (unMonoid e) (unMonoid a) (unMonoid b) (unMonoid c)
EConstArr _ n t x -> EConstArr ext n t x
- EBuild _ n a b -> EBuild ext n (unMonoid a) (unMonoid b)
+ EBuild _ n a b | Refl <- unMonoidIndexId n -> EBuild ext n (unMonoid a) (unMonoid b)
EMap _ a b -> EMap ext (unMonoid a) (unMonoid b)
EFold1Inner _ cm a b c -> EFold1Inner ext cm (unMonoid a) (unMonoid b) (unMonoid c)
ESum1Inner _ e -> ESum1Inner ext (unMonoid e)
@@ -51,49 +119,52 @@ unMonoid = \case
EReplicate1Inner _ a b -> EReplicate1Inner ext (unMonoid a) (unMonoid b)
EMaximum1Inner _ e -> EMaximum1Inner ext (unMonoid e)
EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e)
- EReshape _ n a b -> EReshape ext n (unMonoid a) (unMonoid b)
+ EReshape _ n a b | Refl <- unMonoidIndexId n -> EReshape ext n (unMonoid a) (unMonoid b)
EZip _ a b -> EZip ext (unMonoid a) (unMonoid b)
EFold1InnerD1 _ cm a b c -> EFold1InnerD1 ext cm (unMonoid a) (unMonoid b) (unMonoid c)
EFold1InnerD2 _ cm a b c -> EFold1InnerD2 ext cm (unMonoid a) (unMonoid b) (unMonoid c)
EConst _ t x -> EConst ext t x
EIdx0 _ e -> EIdx0 ext (unMonoid e)
EIdx1 _ a b -> EIdx1 ext (unMonoid a) (unMonoid b)
- EIdx _ a b -> EIdx ext (unMonoid a) (unMonoid b)
- EShape _ e -> EShape ext (unMonoid e)
+ EIdx _ a b | STArr n _ <- typeOf a, Refl <- unMonoidIndexId n -> EIdx ext (unMonoid a) (unMonoid b)
+ EShape _ e | STArr n _ <- typeOf e, Refl <- unMonoidIndexId n -> EShape ext (unMonoid e)
EOp _ op e -> EOp ext op (unMonoid e)
- ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2)
+ ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext (tUnMonoid t1) (tUnMonoid t2) (tUnMonoid t3) (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2)
ERecompute _ e -> ERecompute ext (unMonoid e)
EWith _ t a b -> EWith ext t (unMonoid a) (unMonoid b)
EError _ t s -> EError ext t s
-zero :: SMTy t -> Ex env (ZeroInfo t) -> Ex env t
+zero :: SMTy t -> Ex env (ZeroInfo t) -> Ex env (UnMonoid t)
-- don't destroy the effects!
-zero SMTNil e = ELet ext e $ ENil ext
+zero SMTNil e = use e $ ENil ext
zero (SMTPair t1 t2) e =
ELet ext e $ EPair ext (zero t1 (EFst ext (EVar ext (typeOf e) IZ)))
(zero t2 (ESnd ext (EVar ext (typeOf e) IZ)))
-zero (SMTLEither t1 t2) _ = ELNil ext (fromSMTy t1) (fromSMTy t2)
-zero (SMTMaybe t) _ = ENothing ext (fromSMTy t)
+zero (SMTLEither t1 t2) _ = ELNil ext (tUnMonoid (fromSMTy t1)) (tUnMonoid (fromSMTy t2))
+zero (SMTMaybe t) _ = ENothing ext (tUnMonoid (fromSMTy t))
zero (SMTArr _ t) e = emap (zero t (EVar ext (tZeroInfo t) IZ)) e
zero (SMTScal t) _ = case t of
STI32 -> EConst ext STI32 0
STI64 -> EConst ext STI64 0
STF32 -> EConst ext STF32 0.0
STF64 -> EConst ext STF64 0.0
+zero (SMTIdxPair _ t) e =
+ eunPair e $ \_ e1 e2 ->
+ EPair ext e1 (zero t e2)
-deepZero :: SMTy t -> Ex env (DeepZeroInfo t) -> Ex env t
+deepZero :: SMTy t -> Ex env (DeepZeroInfo t) -> Ex env (UnMonoid t)
deepZero SMTNil e = elet e $ ENil ext
deepZero (SMTPair t1 t2) e =
ELet ext e $ EPair ext (deepZero t1 (EFst ext (EVar ext (typeOf e) IZ)))
(deepZero t2 (ESnd ext (EVar ext (typeOf e) IZ)))
deepZero (SMTLEither t1 t2) e =
elcase e
- (ELNil ext (fromSMTy t1) (fromSMTy t2))
- (ELInl ext (fromSMTy t2) (deepZero t1 (evar IZ)))
- (ELInr ext (fromSMTy t1) (deepZero t2 (evar IZ)))
+ (ELNil ext (tUnMonoid (fromSMTy t1)) (tUnMonoid (fromSMTy t2)))
+ (ELInl ext (tUnMonoid (fromSMTy t2)) (deepZero t1 (evar IZ)))
+ (ELInr ext (tUnMonoid (fromSMTy t1)) (deepZero t2 (evar IZ)))
deepZero (SMTMaybe t) e =
emaybe e
- (ENothing ext (fromSMTy t))
+ (ENothing ext (tUnMonoid (fromSMTy t)))
(EJust ext (deepZero t (evar IZ)))
deepZero (SMTArr _ t) e = emap (deepZero t (evar IZ)) e
deepZero (SMTScal t) _ = case t of
@@ -101,6 +172,9 @@ deepZero (SMTScal t) _ = case t of
STI64 -> EConst ext STI64 0
STF32 -> EConst ext STF32 0.0
STF64 -> EConst ext STF64 0.0
+deepZero (SMTIdxPair _ t) e =
+ eunPair e $ \_ e1 e2 ->
+ EPair ext e1 (deepZero t e2)
plus :: SMTy t -> Ex env t -> Ex env t -> Ex env t
-- don't destroy the effects!
@@ -134,44 +208,38 @@ plus (SMTArr _ t) a b =
a b
plus (SMTScal t) a b = EOp ext (OAdd t) (EPair ext a b)
-onehot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env t
+onehot :: SMTy t -> SAcPrj p t a -> Ex env (UnMonoid (AcIdxS p t)) -> Ex env (UnMonoid a) -> Ex env (UnMonoid t)
onehot typ topprj idx arg = case (typ, topprj) of
(_, SAPHere) ->
- ELet ext arg $
- EVar ext (fromSMTy typ) IZ
+ arg
+
+ (SMTPair t1 t2, SAPFst prj) | Refl <- unMonoidZeroInfoId t2 ->
+ elet idx $
+ elet (onehot t1 prj (EFst ext (evar IZ)) (weakenExpr WSink arg)) $
+ EPair ext (evar IZ)
+ (zero t2 (ESnd ext (evar (IS IZ))))
- (SMTPair t1 t2, SAPFst prj) ->
- ELet ext idx $
- let tidx = typeOf idx in
- ELet ext (onehot t1 prj (EFst ext (EVar ext (typeOf idx) IZ)) (weakenExpr WSink arg)) $
- let toh = fromSMTy t1 in
- EPair ext (EVar ext toh IZ)
- (zero t2 (ESnd ext (EVar ext tidx (IS IZ))))
-
- (SMTPair t1 t2, SAPSnd prj) ->
- ELet ext idx $
- let tidx = typeOf idx in
- ELet ext (onehot t2 prj (ESnd ext (EVar ext (typeOf idx) IZ)) (weakenExpr WSink arg)) $
- let toh = fromSMTy t2 in
- EPair ext (zero t1 (EFst ext (EVar ext tidx (IS IZ))))
- (EVar ext toh IZ)
+ (SMTPair t1 t2, SAPSnd prj) | Refl <- unMonoidZeroInfoId t1 ->
+ elet idx $
+ elet (onehot t2 prj (ESnd ext (evar IZ)) (weakenExpr WSink arg)) $
+ EPair ext (zero t1 (EFst ext (evar (IS IZ))))
+ (evar IZ)
(SMTLEither t1 t2, SAPLeft prj) ->
- ELInl ext (fromSMTy t2) (onehot t1 prj idx arg)
+ ELInl ext (tUnMonoid (fromSMTy t2)) (onehot t1 prj idx arg)
(SMTLEither t1 t2, SAPRight prj) ->
- ELInr ext (fromSMTy t1) (onehot t2 prj idx arg)
+ ELInr ext (tUnMonoid (fromSMTy t1)) (onehot t2 prj idx arg)
(SMTMaybe t1, SAPJust prj) ->
EJust ext (onehot t1 prj idx arg)
- (SMTArr n t1, SAPArrIdx prj) ->
- let tidx = tTup (sreplicate n tIx)
- in ELet ext idx $
- EBuild ext n (EShape ext (ESnd ext (EFst ext (EVar ext (typeOf idx) IZ)))) $
- eif (eidxEq n (EVar ext tidx IZ) (EFst ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))))
- (onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg))
- (ELet ext (EIdx ext (ESnd ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))) (EVar ext tidx IZ)) $
- zero t1 (EVar ext (tZeroInfo t1) IZ))
+ (SMTArr n t1, SAPArrIdx prj) | Refl <- unMonoidIndexId n, Refl <- unMonoidZeroInfoId t1 ->
+ elet idx $
+ EBuild ext n (EShape ext (ESnd ext (EFst ext (evar IZ)))) $
+ eif (eidxEq n (evar IZ) (EFst ext (EFst ext (evar (IS IZ)))))
+ (onehot t1 prj (ESnd ext (evar (IS IZ))) (weakenExpr (WSink .> WSink) arg))
+ (elet (EIdx ext (ESnd ext (EFst ext (evar (IS IZ)))) (evar IZ)) $
+ zero t1 (evar IZ))
accumulateSparse
:: SMTy t -> Sparse t t' -> Ex env t'