aboutsummaryrefslogtreecommitdiff
path: root/src/AST/UnMonoid.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/AST/UnMonoid.hs')
-rw-r--r--src/AST/UnMonoid.hs255
1 files changed, 0 insertions, 255 deletions
diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs
deleted file mode 100644
index 1712ba5..0000000
--- a/src/AST/UnMonoid.hs
+++ /dev/null
@@ -1,255 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE LambdaCase #-}
-{-# LANGUAGE RankNTypes #-}
-{-# LANGUAGE TypeOperators #-}
-module AST.UnMonoid (unMonoid, zero, plus, acPrjCompose) where
-
-import AST
-import AST.Sparse.Types
-import Data
-
-
--- | Remove 'EZero', 'EDeepZero', 'EPlus' and 'EOneHot' from the program by
--- expanding them into their concrete implementations. Also ensure that
--- 'EAccum' has a dense sparsity.
-unMonoid :: Ex env t -> Ex env t
-unMonoid = \case
- EZero _ t e -> zero t e
- EDeepZero _ t e -> deepZero t e
- EPlus _ t a b -> plus t (unMonoid a) (unMonoid b)
- EOneHot _ t p a b -> onehot t p (unMonoid a) (unMonoid b)
-
- EVar _ t i -> EVar ext t i
- ELet _ rhs body -> ELet ext (unMonoid rhs) (unMonoid body)
- EPair _ a b -> EPair ext (unMonoid a) (unMonoid b)
- EFst _ e -> EFst ext (unMonoid e)
- ESnd _ e -> ESnd ext (unMonoid e)
- ENil _ -> ENil ext
- EInl _ t e -> EInl ext t (unMonoid e)
- EInr _ t e -> EInr ext t (unMonoid e)
- ECase _ e a b -> ECase ext (unMonoid e) (unMonoid a) (unMonoid b)
- ENothing _ t -> ENothing ext t
- EJust _ e -> EJust ext (unMonoid e)
- EMaybe _ a b e -> EMaybe ext (unMonoid a) (unMonoid b) (unMonoid e)
- ELNil _ t1 t2 -> ELNil ext t1 t2
- ELInl _ t e -> ELInl ext t (unMonoid e)
- ELInr _ t e -> ELInr ext t (unMonoid e)
- ELCase _ e a b c -> ELCase ext (unMonoid e) (unMonoid a) (unMonoid b) (unMonoid c)
- EConstArr _ n t x -> EConstArr ext n t x
- EBuild _ n a b -> EBuild ext n (unMonoid a) (unMonoid b)
- EMap _ a b -> EMap ext (unMonoid a) (unMonoid b)
- EFold1Inner _ cm a b c -> EFold1Inner ext cm (unMonoid a) (unMonoid b) (unMonoid c)
- ESum1Inner _ e -> ESum1Inner ext (unMonoid e)
- EUnit _ e -> EUnit ext (unMonoid e)
- EReplicate1Inner _ a b -> EReplicate1Inner ext (unMonoid a) (unMonoid b)
- EMaximum1Inner _ e -> EMaximum1Inner ext (unMonoid e)
- EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e)
- EReshape _ n a b -> EReshape ext n (unMonoid a) (unMonoid b)
- EZip _ a b -> EZip ext (unMonoid a) (unMonoid b)
- EFold1InnerD1 _ cm a b c -> EFold1InnerD1 ext cm (unMonoid a) (unMonoid b) (unMonoid c)
- EFold1InnerD2 _ cm a b c -> EFold1InnerD2 ext cm (unMonoid a) (unMonoid b) (unMonoid c)
- EConst _ t x -> EConst ext t x
- EIdx0 _ e -> EIdx0 ext (unMonoid e)
- EIdx1 _ a b -> EIdx1 ext (unMonoid a) (unMonoid b)
- EIdx _ a b -> EIdx ext (unMonoid a) (unMonoid b)
- EShape _ e -> EShape ext (unMonoid e)
- EOp _ op e -> EOp ext op (unMonoid e)
- ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2)
- ERecompute _ e -> ERecompute ext (unMonoid e)
- EWith _ t a b -> EWith ext t (unMonoid a) (unMonoid b)
- EAccum _ t p eidx sp eval eacc ->
- accumulateSparse (acPrjTy p t) sp eval $ \w prj2 idx2 val2 ->
- acPrjCompose SAID p (weakenExpr w eidx) prj2 idx2 $ \prj' idx' ->
- EAccum ext t prj' (unMonoid idx') (spDense (acPrjTy prj' t)) (unMonoid val2) (weakenExpr w (unMonoid eacc))
- EError _ t s -> EError ext t s
-
-zero :: SMTy t -> Ex env (ZeroInfo t) -> Ex env t
--- don't destroy the effects!
-zero SMTNil e = ELet ext e $ ENil ext
-zero (SMTPair t1 t2) e =
- ELet ext e $ EPair ext (zero t1 (EFst ext (EVar ext (typeOf e) IZ)))
- (zero t2 (ESnd ext (EVar ext (typeOf e) IZ)))
-zero (SMTLEither t1 t2) _ = ELNil ext (fromSMTy t1) (fromSMTy t2)
-zero (SMTMaybe t) _ = ENothing ext (fromSMTy t)
-zero (SMTArr _ t) e = emap (zero t (EVar ext (tZeroInfo t) IZ)) e
-zero (SMTScal t) _ = case t of
- STI32 -> EConst ext STI32 0
- STI64 -> EConst ext STI64 0
- STF32 -> EConst ext STF32 0.0
- STF64 -> EConst ext STF64 0.0
-
-deepZero :: SMTy t -> Ex env (DeepZeroInfo t) -> Ex env t
-deepZero SMTNil e = elet e $ ENil ext
-deepZero (SMTPair t1 t2) e =
- ELet ext e $ EPair ext (deepZero t1 (EFst ext (EVar ext (typeOf e) IZ)))
- (deepZero t2 (ESnd ext (EVar ext (typeOf e) IZ)))
-deepZero (SMTLEither t1 t2) e =
- elcase e
- (ELNil ext (fromSMTy t1) (fromSMTy t2))
- (ELInl ext (fromSMTy t2) (deepZero t1 (evar IZ)))
- (ELInr ext (fromSMTy t1) (deepZero t2 (evar IZ)))
-deepZero (SMTMaybe t) e =
- emaybe e
- (ENothing ext (fromSMTy t))
- (EJust ext (deepZero t (evar IZ)))
-deepZero (SMTArr _ t) e = emap (deepZero t (evar IZ)) e
-deepZero (SMTScal t) _ = case t of
- STI32 -> EConst ext STI32 0
- STI64 -> EConst ext STI64 0
- STF32 -> EConst ext STF32 0.0
- STF64 -> EConst ext STF64 0.0
-
-plus :: SMTy t -> Ex env t -> Ex env t -> Ex env t
--- don't destroy the effects!
-plus SMTNil a b = ELet ext a $ ELet ext (weakenExpr WSink b) $ ENil ext
-plus (SMTPair t1 t2) a b =
- let t = STPair (fromSMTy t1) (fromSMTy t2)
- in ELet ext a $
- ELet ext (weakenExpr WSink b) $
- EPair ext (plus t1 (EFst ext (EVar ext t (IS IZ)))
- (EFst ext (EVar ext t IZ)))
- (plus t2 (ESnd ext (EVar ext t (IS IZ)))
- (ESnd ext (EVar ext t IZ)))
-plus (SMTLEither t1 t2) a b =
- let t = STLEither (fromSMTy t1) (fromSMTy t2)
- in ELet ext a $
- ELet ext (weakenExpr WSink b) $
- ELCase ext (EVar ext t (IS IZ))
- (EVar ext t IZ)
- (ELCase ext (EVar ext t (IS IZ))
- (EVar ext t (IS (IS IZ)))
- (ELInl ext (fromSMTy t2) (plus t1 (EVar ext (fromSMTy t1) (IS IZ)) (EVar ext (fromSMTy t1) IZ)))
- (EError ext t "plus l+r"))
- (ELCase ext (EVar ext t (IS IZ))
- (EVar ext t (IS (IS IZ)))
- (EError ext t "plus r+l")
- (ELInr ext (fromSMTy t1) (plus t2 (EVar ext (fromSMTy t2) (IS IZ)) (EVar ext (fromSMTy t2) IZ))))
-plus (SMTMaybe t) a b =
- ELet ext b $
- EMaybe ext
- (EVar ext (STMaybe (fromSMTy t)) IZ)
- (EJust ext
- (EMaybe ext
- (EVar ext (fromSMTy t) IZ)
- (plus t (EVar ext (fromSMTy t) (IS IZ)) (EVar ext (fromSMTy t) IZ))
- (EVar ext (STMaybe (fromSMTy t)) (IS IZ))))
- (weakenExpr WSink a)
-plus (SMTArr _ t) a b =
- ezipWith (plus t (EVar ext (fromSMTy t) (IS IZ)) (EVar ext (fromSMTy t) IZ))
- a b
-plus (SMTScal t) a b = EOp ext (OAdd t) (EPair ext a b)
-
-onehot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env t
-onehot typ topprj idx arg = case (typ, topprj) of
- (_, SAPHere) ->
- ELet ext arg $
- EVar ext (fromSMTy typ) IZ
-
- (SMTPair t1 t2, SAPFst prj) ->
- ELet ext idx $
- let tidx = typeOf idx in
- ELet ext (onehot t1 prj (EFst ext (EVar ext (typeOf idx) IZ)) (weakenExpr WSink arg)) $
- let toh = fromSMTy t1 in
- EPair ext (EVar ext toh IZ)
- (zero t2 (ESnd ext (EVar ext tidx (IS IZ))))
-
- (SMTPair t1 t2, SAPSnd prj) ->
- ELet ext idx $
- let tidx = typeOf idx in
- ELet ext (onehot t2 prj (ESnd ext (EVar ext (typeOf idx) IZ)) (weakenExpr WSink arg)) $
- let toh = fromSMTy t2 in
- EPair ext (zero t1 (EFst ext (EVar ext tidx (IS IZ))))
- (EVar ext toh IZ)
-
- (SMTLEither t1 t2, SAPLeft prj) ->
- ELInl ext (fromSMTy t2) (onehot t1 prj idx arg)
- (SMTLEither t1 t2, SAPRight prj) ->
- ELInr ext (fromSMTy t1) (onehot t2 prj idx arg)
-
- (SMTMaybe t1, SAPJust prj) ->
- EJust ext (onehot t1 prj idx arg)
-
- (SMTArr n t1, SAPArrIdx prj) ->
- let tidx = tTup (sreplicate n tIx)
- in ELet ext idx $
- EBuild ext n (EShape ext (ESnd ext (EFst ext (EVar ext (typeOf idx) IZ)))) $
- eif (eidxEq n (EVar ext tidx IZ) (EFst ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))))
- (onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg))
- (ELet ext (EIdx ext (ESnd ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))) (EVar ext tidx IZ)) $
- zero t1 (EVar ext (tZeroInfo t1) IZ))
-
-accumulateSparse
- :: SMTy t -> Sparse t t' -> Ex env t'
- -> (forall p b env'. env :> env' -> SAcPrj p t b -> Ex env' (AcIdxD p t) -> Ex env' b -> Ex env' TNil)
- -> Ex env TNil
-accumulateSparse topty topsp arg accum = case (topty, topsp) of
- (_, s) | Just Refl <- isDense topty s ->
- accum WId SAPHere (ENil ext) arg
- (SMTScal _, SpScal) ->
- accum WId SAPHere (ENil ext) arg -- should be handled by isDense already, but meh
- (_, SpSparse s) ->
- emaybe arg
- (ENil ext)
- (accumulateSparse topty s (evar IZ) (\w -> accum (WPop w)))
- (_, SpAbsent) ->
- ENil ext
- (SMTPair t1 t2, SpPair s1 s2) ->
- eunPair arg $ \w1 e1 e2 ->
- elet (accumulateSparse t1 s1 e1 (\w prj -> accum (w .> w1) (SAPFst prj))) $
- accumulateSparse t2 s2 (weakenExpr WSink e2) (\w prj -> accum (w .> WSink .> w1) (SAPSnd prj))
- (SMTLEither t1 t2, SpLEither s1 s2) ->
- elcase arg
- (ENil ext)
- (accumulateSparse t1 s1 (evar IZ) (\w prj -> accum (WPop w) (SAPLeft prj)))
- (accumulateSparse t2 s2 (evar IZ) (\w prj -> accum (WPop w) (SAPRight prj)))
- (SMTMaybe t, SpMaybe s) ->
- emaybe arg
- (ENil ext)
- (accumulateSparse t s (evar IZ) (\w prj -> accum (WPop w) (SAPJust prj)))
- (SMTArr n t, SpArr s) ->
- let tn = tTup (sreplicate n tIx) in
- elet arg $
- elet (EBuild ext n (EShape ext (evar IZ)) $
- accumulateSparse t s
- (EIdx ext (evar (IS IZ)) (EVar ext tn IZ))
- (\w prj idx val -> accum (WPop (WPop w)) (SAPArrIdx prj) (EPair ext (EVar ext tn (w @> IZ)) idx) val)) $
- ENil ext
-
-acPrjCompose
- :: SAIDense dense
- -> SAcPrj p1 a b -> Ex env (AcIdx dense p1 a)
- -> SAcPrj p2 b c -> Ex env (AcIdx dense p2 b)
- -> (forall p'. SAcPrj p' a c -> Ex env (AcIdx dense p' a) -> r) -> r
-acPrjCompose _ SAPHere _ p2 idx2 k = k p2 idx2
-acPrjCompose SAID (SAPFst p1) idx1 p2 idx2 k =
- acPrjCompose SAID p1 idx1 p2 idx2 $ \p' idx' ->
- k (SAPFst p') idx'
-acPrjCompose SAID (SAPSnd p1) idx1 p2 idx2 k =
- acPrjCompose SAID p1 idx1 p2 idx2 $ \p' idx' ->
- k (SAPSnd p') idx'
-acPrjCompose SAIS (SAPFst p1) idx1 p2 idx2 k
- | Dict <- styKnown (typeOf idx1) =
- acPrjCompose SAIS p1 (efst (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' ->
- k (SAPFst p') (elet idx1 $ EPair ext idx' (esnd (evar IZ)))
-acPrjCompose SAIS (SAPSnd p1) idx1 p2 idx2 k
- | Dict <- styKnown (typeOf idx1) =
- acPrjCompose SAIS p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' ->
- k (SAPSnd p') (elet idx1 $ EPair ext (efst (evar IZ)) idx')
-acPrjCompose d (SAPLeft p1) idx1 p2 idx2 k =
- acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' ->
- k (SAPLeft p') idx'
-acPrjCompose d (SAPRight p1) idx1 p2 idx2 k =
- acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' ->
- k (SAPRight p') idx'
-acPrjCompose d (SAPJust p1) idx1 p2 idx2 k =
- acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' ->
- k (SAPJust p') idx'
-acPrjCompose SAID (SAPArrIdx p1) idx1 p2 idx2 k
- | Dict <- styKnown (typeOf idx1) =
- acPrjCompose SAID p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' ->
- k (SAPArrIdx p') (elet idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx')
-acPrjCompose SAIS (SAPArrIdx p1) idx1 p2 idx2 k
- | Dict <- styKnown (typeOf idx1) =
- acPrjCompose SAIS p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' ->
- k (SAPArrIdx p') (elet idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx')