aboutsummaryrefslogtreecommitdiff
path: root/src/AST/UnMonoid.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/AST/UnMonoid.hs')
-rw-r--r--src/AST/UnMonoid.hs253
1 files changed, 183 insertions, 70 deletions
diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs
index ae9728a..48dd709 100644
--- a/src/AST/UnMonoid.hs
+++ b/src/AST/UnMonoid.hs
@@ -1,17 +1,22 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeOperators #-}
-module AST.UnMonoid (unMonoid, zero, plus) where
+module AST.UnMonoid (unMonoid, zero, plus, acPrjCompose) where
import AST
-import CHAD.Types
+import AST.Sparse.Types
import Data
+-- | Remove 'EZero', 'EDeepZero', 'EPlus' and 'EOneHot' from the program by
+-- expanding them into their concrete implementations. Also ensure that
+-- 'EAccum' has a dense sparsity.
unMonoid :: Ex env t -> Ex env t
unMonoid = \case
- EZero _ t -> zero t
+ EZero _ t e -> zero t e
+ EDeepZero _ t e -> deepZero t e
EPlus _ t a b -> plus t (unMonoid a) (unMonoid b)
EOneHot _ t p a b -> onehot t p (unMonoid a) (unMonoid b)
@@ -27,9 +32,13 @@ unMonoid = \case
ENothing _ t -> ENothing ext t
EJust _ e -> EJust ext (unMonoid e)
EMaybe _ a b e -> EMaybe ext (unMonoid a) (unMonoid b) (unMonoid e)
+ ELNil _ t1 t2 -> ELNil ext t1 t2
+ ELInl _ t e -> ELInl ext t (unMonoid e)
+ ELInr _ t e -> ELInr ext t (unMonoid e)
+ ELCase _ e a b c -> ELCase ext (unMonoid e) (unMonoid a) (unMonoid b) (unMonoid c)
EConstArr _ n t x -> EConstArr ext n t x
EBuild _ n a b -> EBuild ext n (unMonoid a) (unMonoid b)
- EFold1Inner _ a b c -> EFold1Inner ext (unMonoid a) (unMonoid b) (unMonoid c)
+ EFold1Inner _ cm a b c -> EFold1Inner ext cm (unMonoid a) (unMonoid b) (unMonoid c)
ESum1Inner _ e -> ESum1Inner ext (unMonoid e)
EUnit _ e -> EUnit ext (unMonoid e)
EReplicate1Inner _ a b -> EReplicate1Inner ext (unMonoid a) (unMonoid b)
@@ -42,96 +51,200 @@ unMonoid = \case
EShape _ e -> EShape ext (unMonoid e)
EOp _ op e -> EOp ext op (unMonoid e)
ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2)
+ ERecompute _ e -> ERecompute ext (unMonoid e)
EWith _ t a b -> EWith ext t (unMonoid a) (unMonoid b)
- EAccum _ t p a b e -> EAccum ext t p (unMonoid a) (unMonoid b) (unMonoid e)
+ EAccum _ t p eidx sp eval eacc ->
+ accumulateSparse (acPrjTy p t) sp eval $ \w prj2 idx2 val2 ->
+ acPrjCompose SAID p (weakenExpr w eidx) prj2 idx2 $ \prj' idx' ->
+ EAccum ext t prj' (unMonoid idx') (spDense (acPrjTy prj' t)) (unMonoid val2) (weakenExpr w (unMonoid eacc))
EError _ t s -> EError ext t s
-zero :: STy t -> Ex env (D2 t)
-zero STNil = ENil ext
-zero (STPair t1 t2) = ENothing ext (STPair (d2 t1) (d2 t2))
-zero (STEither t1 t2) = ENothing ext (STEither (d2 t1) (d2 t2))
-zero (STMaybe t) = ENothing ext (d2 t)
-zero (STArr SZ t) = EUnit ext (zero t)
-zero (STArr n t) = EBuild ext n (eTup (sreplicate n (EConst ext STI64 0))) (EError ext (d2 t) "empty")
-zero (STScal t) = case t of
- STI32 -> ENil ext
- STI64 -> ENil ext
+zero :: SMTy t -> Ex env (ZeroInfo t) -> Ex env t
+-- don't destroy the effects!
+zero SMTNil e = ELet ext e $ ENil ext
+zero (SMTPair t1 t2) e =
+ ELet ext e $ EPair ext (zero t1 (EFst ext (EVar ext (typeOf e) IZ)))
+ (zero t2 (ESnd ext (EVar ext (typeOf e) IZ)))
+zero (SMTLEither t1 t2) _ = ELNil ext (fromSMTy t1) (fromSMTy t2)
+zero (SMTMaybe t) _ = ENothing ext (fromSMTy t)
+zero (SMTArr _ t) e = emap (zero t (EVar ext (tZeroInfo t) IZ)) e
+zero (SMTScal t) _ = case t of
+ STI32 -> EConst ext STI32 0
+ STI64 -> EConst ext STI64 0
STF32 -> EConst ext STF32 0.0
STF64 -> EConst ext STF64 0.0
- STBool -> ENil ext
-zero STAccum{} = error "Accumulators not allowed in input program"
-plus :: STy t -> Ex env (D2 t) -> Ex env (D2 t) -> Ex env (D2 t)
-plus STNil _ _ = ENil ext
-plus (STPair t1 t2) a b =
- let t = STPair (d2 t1) (d2 t2)
- in plusSparse t a b $
+deepZero :: SMTy t -> Ex env (DeepZeroInfo t) -> Ex env t
+deepZero SMTNil e = elet e $ ENil ext
+deepZero (SMTPair t1 t2) e =
+ ELet ext e $ EPair ext (deepZero t1 (EFst ext (EVar ext (typeOf e) IZ)))
+ (deepZero t2 (ESnd ext (EVar ext (typeOf e) IZ)))
+deepZero (SMTLEither t1 t2) e =
+ elcase e
+ (ELNil ext (fromSMTy t1) (fromSMTy t2))
+ (ELInl ext (fromSMTy t2) (deepZero t1 (evar IZ)))
+ (ELInr ext (fromSMTy t1) (deepZero t2 (evar IZ)))
+deepZero (SMTMaybe t) e =
+ emaybe e
+ (ENothing ext (fromSMTy t))
+ (EJust ext (deepZero t (evar IZ)))
+deepZero (SMTArr _ t) e = emap (deepZero t (evar IZ)) e
+deepZero (SMTScal t) _ = case t of
+ STI32 -> EConst ext STI32 0
+ STI64 -> EConst ext STI64 0
+ STF32 -> EConst ext STF32 0.0
+ STF64 -> EConst ext STF64 0.0
+
+plus :: SMTy t -> Ex env t -> Ex env t -> Ex env t
+-- don't destroy the effects!
+plus SMTNil a b = ELet ext a $ ELet ext (weakenExpr WSink b) $ ENil ext
+plus (SMTPair t1 t2) a b =
+ let t = STPair (fromSMTy t1) (fromSMTy t2)
+ in ELet ext a $
+ ELet ext (weakenExpr WSink b) $
EPair ext (plus t1 (EFst ext (EVar ext t (IS IZ)))
(EFst ext (EVar ext t IZ)))
(plus t2 (ESnd ext (EVar ext t (IS IZ)))
(ESnd ext (EVar ext t IZ)))
-plus (STEither t1 t2) a b =
- let t = STEither (d2 t1) (d2 t2)
- in plusSparse t a b $
- ECase ext (EVar ext t (IS IZ))
- (ECase ext (EVar ext t (IS IZ))
- (EInl ext (d2 t2) (plus t1 (EVar ext (d2 t1) (IS IZ)) (EVar ext (d2 t1) IZ)))
+plus (SMTLEither t1 t2) a b =
+ let t = STLEither (fromSMTy t1) (fromSMTy t2)
+ in ELet ext a $
+ ELet ext (weakenExpr WSink b) $
+ ELCase ext (EVar ext t (IS IZ))
+ (EVar ext t IZ)
+ (ELCase ext (EVar ext t (IS IZ))
+ (EVar ext t (IS (IS IZ)))
+ (ELInl ext (fromSMTy t2) (plus t1 (EVar ext (fromSMTy t1) (IS IZ)) (EVar ext (fromSMTy t1) IZ)))
(EError ext t "plus l+r"))
- (ECase ext (EVar ext t (IS IZ))
+ (ELCase ext (EVar ext t (IS IZ))
+ (EVar ext t (IS (IS IZ)))
(EError ext t "plus r+l")
- (EInr ext (d2 t1) (plus t2 (EVar ext (d2 t2) (IS IZ)) (EVar ext (d2 t2) IZ))))
-plus (STMaybe t) a b =
- plusSparse (d2 t) a b $
- plus t (EVar ext (d2 t) (IS IZ)) (EVar ext (d2 t) IZ)
-plus (STArr n t) a b =
- ELet ext a $
- ELet ext (weakenExpr WSink b) $
- eif (eshapeEmpty n (EShape ext (EVar ext (STArr n (d2 t)) (IS IZ))))
- (EVar ext (STArr n (d2 t)) IZ)
- (eif (eshapeEmpty n (EShape ext (EVar ext (STArr n (d2 t)) IZ)))
- (EVar ext (STArr n (d2 t)) (IS IZ))
- (ezipWith (plus t (EVar ext (d2 t) (IS IZ)) (EVar ext (d2 t) IZ))
- (EVar ext (STArr n (d2 t)) (IS IZ))
- (EVar ext (STArr n (d2 t)) IZ)))
-plus (STScal t) a b = case t of
- STI32 -> ENil ext
- STI64 -> ENil ext
- STF32 -> EOp ext (OAdd STF32) (EPair ext a b)
- STF64 -> EOp ext (OAdd STF64) (EPair ext a b)
- STBool -> ENil ext
-plus STAccum{} _ _ = error "Accumulators not allowed in input program"
-
-plusSparse :: STy a
- -> Ex env (TMaybe a) -> Ex env (TMaybe a)
- -> Ex (a : a : env) a
- -> Ex env (TMaybe a)
-plusSparse t a b adder =
+ (ELInr ext (fromSMTy t1) (plus t2 (EVar ext (fromSMTy t2) (IS IZ)) (EVar ext (fromSMTy t2) IZ))))
+plus (SMTMaybe t) a b =
ELet ext b $
EMaybe ext
- (EVar ext (STMaybe t) IZ)
+ (EVar ext (STMaybe (fromSMTy t)) IZ)
(EJust ext
(EMaybe ext
- (EVar ext t IZ)
- (weakenExpr (WCopy (WCopy WSink)) adder)
- (EVar ext (STMaybe t) (IS IZ))))
+ (EVar ext (fromSMTy t) IZ)
+ (plus t (EVar ext (fromSMTy t) (IS IZ)) (EVar ext (fromSMTy t) IZ))
+ (EVar ext (STMaybe (fromSMTy t)) (IS IZ))))
(weakenExpr WSink a)
+plus (SMTArr _ t) a b =
+ ezipWith (plus t (EVar ext (fromSMTy t) (IS IZ)) (EVar ext (fromSMTy t) IZ))
+ a b
+plus (SMTScal t) a b = EOp ext (OAdd t) (EPair ext a b)
-onehot :: STy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env (D2 a) -> Ex env (D2 t)
+onehot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env t
onehot typ topprj idx arg = case (typ, topprj) of
- (_, SAPHere) -> arg
+ (_, SAPHere) ->
+ ELet ext arg $
+ EVar ext (fromSMTy typ) IZ
- (STPair t1 t2, SAPFst prj) -> EJust ext (EPair ext (onehot t1 prj idx arg) (zero t2))
- (STPair t1 t2, SAPSnd prj) -> EJust ext (EPair ext (zero t1) (onehot t2 prj idx arg))
+ (SMTPair t1 t2, SAPFst prj) ->
+ ELet ext idx $
+ let tidx = typeOf idx in
+ ELet ext (onehot t1 prj (EFst ext (EVar ext (typeOf idx) IZ)) (weakenExpr WSink arg)) $
+ let toh = fromSMTy t1 in
+ EPair ext (EVar ext toh IZ)
+ (zero t2 (ESnd ext (EVar ext tidx (IS IZ))))
+
+ (SMTPair t1 t2, SAPSnd prj) ->
+ ELet ext idx $
+ let tidx = typeOf idx in
+ ELet ext (onehot t2 prj (ESnd ext (EVar ext (typeOf idx) IZ)) (weakenExpr WSink arg)) $
+ let toh = fromSMTy t2 in
+ EPair ext (zero t1 (EFst ext (EVar ext tidx (IS IZ))))
+ (EVar ext toh IZ)
- (STEither t1 t2, SAPLeft prj) -> EJust ext (EInl ext (d2 t2) (onehot t1 prj idx arg))
- (STEither t1 t2, SAPRight prj) -> EJust ext (EInr ext (d2 t1) (onehot t2 prj idx arg))
+ (SMTLEither t1 t2, SAPLeft prj) ->
+ ELInl ext (fromSMTy t2) (onehot t1 prj idx arg)
+ (SMTLEither t1 t2, SAPRight prj) ->
+ ELInr ext (fromSMTy t1) (onehot t2 prj idx arg)
- (STMaybe t1, SAPJust prj) -> EJust ext (onehot t1 prj idx arg)
+ (SMTMaybe t1, SAPJust prj) ->
+ EJust ext (onehot t1 prj idx arg)
- (STArr n t1, SAPArrIdx prj _) ->
+ (SMTArr n t1, SAPArrIdx prj) ->
let tidx = tTup (sreplicate n tIx)
in ELet ext idx $
- EBuild ext n (ESnd ext (EFst ext (EVar ext (typeOf idx) IZ))) $
+ EBuild ext n (EShape ext (ESnd ext (EFst ext (EVar ext (typeOf idx) IZ)))) $
eif (eidxEq n (EVar ext tidx IZ) (EFst ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))))
(onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg))
- (zero t1)
+ (ELet ext (EIdx ext (ESnd ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))) (EVar ext tidx IZ)) $
+ zero t1 (EVar ext (tZeroInfo t1) IZ))
+
+accumulateSparse
+ :: SMTy t -> Sparse t t' -> Ex env t'
+ -> (forall p b env'. env :> env' -> SAcPrj p t b -> Ex env' (AcIdxD p t) -> Ex env' b -> Ex env' TNil)
+ -> Ex env TNil
+accumulateSparse topty topsp arg accum = case (topty, topsp) of
+ (_, s) | Just Refl <- isDense topty s ->
+ accum WId SAPHere (ENil ext) arg
+ (SMTScal _, SpScal) ->
+ accum WId SAPHere (ENil ext) arg -- should be handled by isDense already, but meh
+ (_, SpSparse s) ->
+ emaybe arg
+ (ENil ext)
+ (accumulateSparse topty s (evar IZ) (\w -> accum (WPop w)))
+ (_, SpAbsent) ->
+ ENil ext
+ (SMTPair t1 t2, SpPair s1 s2) ->
+ eunPair arg $ \w1 e1 e2 ->
+ elet (accumulateSparse t1 s1 e1 (\w prj -> accum (w .> w1) (SAPFst prj))) $
+ accumulateSparse t2 s2 (weakenExpr WSink e2) (\w prj -> accum (w .> WSink .> w1) (SAPSnd prj))
+ (SMTLEither t1 t2, SpLEither s1 s2) ->
+ elcase arg
+ (ENil ext)
+ (accumulateSparse t1 s1 (evar IZ) (\w prj -> accum (WPop w) (SAPLeft prj)))
+ (accumulateSparse t2 s2 (evar IZ) (\w prj -> accum (WPop w) (SAPRight prj)))
+ (SMTMaybe t, SpMaybe s) ->
+ emaybe arg
+ (ENil ext)
+ (accumulateSparse t s (evar IZ) (\w prj -> accum (WPop w) (SAPJust prj)))
+ (SMTArr n t, SpArr s) ->
+ let tn = tTup (sreplicate n tIx) in
+ elet arg $
+ elet (EBuild ext n (EShape ext (evar IZ)) $
+ accumulateSparse t s
+ (EIdx ext (evar (IS IZ)) (EVar ext tn IZ))
+ (\w prj idx val -> accum (WPop (WPop w)) (SAPArrIdx prj) (EPair ext (EVar ext tn (w @> IZ)) idx) val)) $
+ ENil ext
+
+acPrjCompose
+ :: SAIDense dense
+ -> SAcPrj p1 a b -> Ex env (AcIdx dense p1 a)
+ -> SAcPrj p2 b c -> Ex env (AcIdx dense p2 b)
+ -> (forall p'. SAcPrj p' a c -> Ex env (AcIdx dense p' a) -> r) -> r
+acPrjCompose _ SAPHere _ p2 idx2 k = k p2 idx2
+acPrjCompose SAID (SAPFst p1) idx1 p2 idx2 k =
+ acPrjCompose SAID p1 idx1 p2 idx2 $ \p' idx' ->
+ k (SAPFst p') idx'
+acPrjCompose SAID (SAPSnd p1) idx1 p2 idx2 k =
+ acPrjCompose SAID p1 idx1 p2 idx2 $ \p' idx' ->
+ k (SAPSnd p') idx'
+acPrjCompose SAIS (SAPFst p1) idx1 p2 idx2 k
+ | Dict <- styKnown (typeOf idx1) =
+ acPrjCompose SAIS p1 (efst (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' ->
+ k (SAPFst p') (elet idx1 $ EPair ext idx' (esnd (evar IZ)))
+acPrjCompose SAIS (SAPSnd p1) idx1 p2 idx2 k
+ | Dict <- styKnown (typeOf idx1) =
+ acPrjCompose SAIS p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' ->
+ k (SAPSnd p') (elet idx1 $ EPair ext (efst (evar IZ)) idx')
+acPrjCompose d (SAPLeft p1) idx1 p2 idx2 k =
+ acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' ->
+ k (SAPLeft p') idx'
+acPrjCompose d (SAPRight p1) idx1 p2 idx2 k =
+ acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' ->
+ k (SAPRight p') idx'
+acPrjCompose d (SAPJust p1) idx1 p2 idx2 k =
+ acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' ->
+ k (SAPJust p') idx'
+acPrjCompose SAID (SAPArrIdx p1) idx1 p2 idx2 k
+ | Dict <- styKnown (typeOf idx1) =
+ acPrjCompose SAID p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' ->
+ k (SAPArrIdx p') (elet idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx')
+acPrjCompose SAIS (SAPArrIdx p1) idx1 p2 idx2 k
+ | Dict <- styKnown (typeOf idx1) =
+ acPrjCompose SAIS p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' ->
+ k (SAPArrIdx p') (elet idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx')