aboutsummaryrefslogtreecommitdiff
path: root/src/AST/UnMonoid.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/AST/UnMonoid.hs')
-rw-r--r--src/AST/UnMonoid.hs137
1 files changed, 0 insertions, 137 deletions
diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs
deleted file mode 100644
index 0da1afc..0000000
--- a/src/AST/UnMonoid.hs
+++ /dev/null
@@ -1,137 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE LambdaCase #-}
-{-# LANGUAGE TypeOperators #-}
-module AST.UnMonoid (unMonoid, zero, plus) where
-
-import AST
-import CHAD.Types
-import Data
-
-
-unMonoid :: Ex env t -> Ex env t
-unMonoid = \case
- EZero _ t -> zero t
- EPlus _ t a b -> plus t (unMonoid a) (unMonoid b)
- EOneHot _ t p a b -> onehot t p (unMonoid a) (unMonoid b)
-
- EVar _ t i -> EVar ext t i
- ELet _ rhs body -> ELet ext (unMonoid rhs) (unMonoid body)
- EPair _ a b -> EPair ext (unMonoid a) (unMonoid b)
- EFst _ e -> EFst ext (unMonoid e)
- ESnd _ e -> ESnd ext (unMonoid e)
- ENil _ -> ENil ext
- EInl _ t e -> EInl ext t (unMonoid e)
- EInr _ t e -> EInr ext t (unMonoid e)
- ECase _ e a b -> ECase ext (unMonoid e) (unMonoid a) (unMonoid b)
- ENothing _ t -> ENothing ext t
- EJust _ e -> EJust ext (unMonoid e)
- EMaybe _ a b e -> EMaybe ext (unMonoid a) (unMonoid b) (unMonoid e)
- EConstArr _ n t x -> EConstArr ext n t x
- EBuild _ n a b -> EBuild ext n (unMonoid a) (unMonoid b)
- EFold1Inner _ cm a b c -> EFold1Inner ext cm (unMonoid a) (unMonoid b) (unMonoid c)
- ESum1Inner _ e -> ESum1Inner ext (unMonoid e)
- EUnit _ e -> EUnit ext (unMonoid e)
- EReplicate1Inner _ a b -> EReplicate1Inner ext (unMonoid a) (unMonoid b)
- EMaximum1Inner _ e -> EMaximum1Inner ext (unMonoid e)
- EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e)
- EConst _ t x -> EConst ext t x
- EIdx0 _ e -> EIdx0 ext (unMonoid e)
- EIdx1 _ a b -> EIdx1 ext (unMonoid a) (unMonoid b)
- EIdx _ a b -> EIdx ext (unMonoid a) (unMonoid b)
- EShape _ e -> EShape ext (unMonoid e)
- EOp _ op e -> EOp ext op (unMonoid e)
- ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2)
- EWith _ t a b -> EWith ext t (unMonoid a) (unMonoid b)
- EAccum _ t p a b e -> EAccum ext t p (unMonoid a) (unMonoid b) (unMonoid e)
- EError _ t s -> EError ext t s
-
-zero :: STy t -> Ex env (D2 t)
-zero STNil = ENil ext
-zero (STPair t1 t2) = ENothing ext (STPair (d2 t1) (d2 t2))
-zero (STEither t1 t2) = ENothing ext (STEither (d2 t1) (d2 t2))
-zero (STMaybe t) = ENothing ext (d2 t)
-zero (STArr SZ t) = ENothing ext (STArr SZ (d2 t))
-zero (STArr n t) = ENothing ext (STArr n (d2 t))
-zero (STScal t) = case t of
- STI32 -> ENil ext
- STI64 -> ENil ext
- STF32 -> EConst ext STF32 0.0
- STF64 -> EConst ext STF64 0.0
- STBool -> ENil ext
-zero STAccum{} = error "Accumulators not allowed in input program"
-
-plus :: STy t -> Ex env (D2 t) -> Ex env (D2 t) -> Ex env (D2 t)
-plus STNil _ _ = ENil ext
-plus (STPair t1 t2) a b =
- let t = STPair (d2 t1) (d2 t2)
- in plusSparse t a b $
- EPair ext (plus t1 (EFst ext (EVar ext t (IS IZ)))
- (EFst ext (EVar ext t IZ)))
- (plus t2 (ESnd ext (EVar ext t (IS IZ)))
- (ESnd ext (EVar ext t IZ)))
-plus (STEither t1 t2) a b =
- let t = STEither (d2 t1) (d2 t2)
- in plusSparse t a b $
- ECase ext (EVar ext t (IS IZ))
- (ECase ext (EVar ext t (IS IZ))
- (EInl ext (d2 t2) (plus t1 (EVar ext (d2 t1) (IS IZ)) (EVar ext (d2 t1) IZ)))
- (EError ext t "plus l+r"))
- (ECase ext (EVar ext t (IS IZ))
- (EError ext t "plus r+l")
- (EInr ext (d2 t1) (plus t2 (EVar ext (d2 t2) (IS IZ)) (EVar ext (d2 t2) IZ))))
-plus (STMaybe t) a b =
- plusSparse (d2 t) a b $
- plus t (EVar ext (d2 t) (IS IZ)) (EVar ext (d2 t) IZ)
-plus (STArr n t) a b =
- plusSparse (STArr n (d2 t)) a b $
- eif (eshapeEmpty n (EShape ext (EVar ext (STArr n (d2 t)) (IS IZ))))
- (EVar ext (STArr n (d2 t)) IZ)
- (eif (eshapeEmpty n (EShape ext (EVar ext (STArr n (d2 t)) IZ)))
- (EVar ext (STArr n (d2 t)) (IS IZ))
- (ezipWith (plus t (EVar ext (d2 t) (IS IZ)) (EVar ext (d2 t) IZ))
- (EVar ext (STArr n (d2 t)) (IS IZ))
- (EVar ext (STArr n (d2 t)) IZ)))
-plus (STScal t) a b = case t of
- STI32 -> ENil ext
- STI64 -> ENil ext
- STF32 -> EOp ext (OAdd STF32) (EPair ext a b)
- STF64 -> EOp ext (OAdd STF64) (EPair ext a b)
- STBool -> ENil ext
-plus STAccum{} _ _ = error "Accumulators not allowed in input program"
-
-plusSparse :: STy a
- -> Ex env (TMaybe a) -> Ex env (TMaybe a)
- -> Ex (a : a : env) a
- -> Ex env (TMaybe a)
-plusSparse t a b adder =
- ELet ext b $
- EMaybe ext
- (EVar ext (STMaybe t) IZ)
- (EJust ext
- (EMaybe ext
- (EVar ext t IZ)
- (weakenExpr (WCopy (WCopy WSink)) adder)
- (EVar ext (STMaybe t) (IS IZ))))
- (weakenExpr WSink a)
-
-onehot :: STy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env (D2 a) -> Ex env (D2 t)
-onehot typ topprj idx arg = case (typ, topprj) of
- (_, SAPHere) -> arg
-
- (STPair t1 t2, SAPFst prj) -> EJust ext (EPair ext (onehot t1 prj idx arg) (zero t2))
- (STPair t1 t2, SAPSnd prj) -> EJust ext (EPair ext (zero t1) (onehot t2 prj idx arg))
-
- (STEither t1 t2, SAPLeft prj) -> EJust ext (EInl ext (d2 t2) (onehot t1 prj idx arg))
- (STEither t1 t2, SAPRight prj) -> EJust ext (EInr ext (d2 t1) (onehot t2 prj idx arg))
-
- (STMaybe t1, SAPJust prj) -> EJust ext (onehot t1 prj idx arg)
-
- (STArr n t1, SAPArrIdx prj _) ->
- let tidx = tTup (sreplicate n tIx)
- in ELet ext idx $
- EJust ext $
- EBuild ext n (ESnd ext (EFst ext (EVar ext (typeOf idx) IZ))) $
- eif (eidxEq n (EVar ext tidx IZ) (EFst ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))))
- (onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg))
- (zero t1)