summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <t.j.smeding@uu.nl>2024-12-06 16:59:59 +0100
committerTom Smeding <t.j.smeding@uu.nl>2024-12-06 16:59:59 +0100
commit0ccd55fc7b3d5511935111d0e2712f452da035f4 (patch)
tree91f4625dd2bcc5db14ff319084efabff36aa1e15
parent728909852208587c3c4c63da302d22e67d5cc915 (diff)
WIP UnMonoid (to be used for compiling to C)
-rw-r--r--chad-fast.cabal1
-rw-r--r--src/AST/UnMonoid.hs110
-rw-r--r--src/CHAD.hs52
3 files changed, 111 insertions, 52 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal
index 893c92f..c3c2682 100644
--- a/chad-fast.cabal
+++ b/chad-fast.cabal
@@ -17,6 +17,7 @@ library
AST.Env
AST.Pretty
AST.Types
+ AST.UnMonoid
AST.Weaken
AST.Weaken.Auto
CHAD
diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs
new file mode 100644
index 0000000..1675dab
--- /dev/null
+++ b/src/AST/UnMonoid.hs
@@ -0,0 +1,110 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE TypeOperators #-}
+module AST.UnMonoid where
+
+import AST
+import CHAD.Types
+import Data
+
+
+unMonoid :: Ex env t -> Ex env t
+unMonoid = \case
+ EZero t -> zero t
+ EPlus t a b -> plus t a b
+ EOneHot t i a b -> _ t i a b
+
+ EVar _ t i -> EVar ext t i
+ ELet _ rhs body -> ELet ext (unMonoid rhs) (unMonoid body)
+ EPair _ a b -> EPair ext (unMonoid a) (unMonoid b)
+ EFst _ e -> EFst ext (unMonoid e)
+ ESnd _ e -> ESnd ext (unMonoid e)
+ ENil _ -> ENil ext
+ EInl _ t e -> EInl ext t (unMonoid e)
+ EInr _ t e -> EInr ext t (unMonoid e)
+ ECase _ e a b -> ECase ext (unMonoid e) (unMonoid a) (unMonoid b)
+ ENothing _ t -> ENothing ext t
+ EJust _ e -> EJust ext (unMonoid e)
+ EMaybe _ a b e -> EMaybe ext (unMonoid a) (unMonoid b) (unMonoid e)
+ EConstArr _ n t x -> EConstArr ext n t x
+ EBuild _ n a b -> EBuild ext n (unMonoid a) (unMonoid b)
+ EFold1Inner _ a b c -> EFold1Inner ext (unMonoid a) (unMonoid b) (unMonoid c)
+ ESum1Inner _ e -> ESum1Inner ext (unMonoid e)
+ EUnit _ e -> EUnit ext (unMonoid e)
+ EReplicate1Inner _ a b -> EReplicate1Inner ext (unMonoid a) (unMonoid b)
+ EMaximum1Inner _ e -> EMaximum1Inner ext (unMonoid e)
+ EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e)
+ EConst _ t x -> EConst ext t x
+ EIdx0 _ e -> EIdx0 ext (unMonoid e)
+ EIdx1 _ a b -> EIdx1 ext (unMonoid a) (unMonoid b)
+ EIdx _ a b -> EIdx ext (unMonoid a) (unMonoid b)
+ EShape _ e -> EShape ext (unMonoid e)
+ EOp _ op e -> EOp ext op (unMonoid e)
+ ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2)
+ EWith a b -> EWith (unMonoid a) (unMonoid b)
+ EAccum n a b e -> EAccum n (unMonoid a) (unMonoid b) (unMonoid e)
+ EError t s -> EError t s
+
+zero :: STy t -> Ex env (D2 t)
+zero STNil = ENil ext
+zero (STPair t1 t2) = ENothing ext (STPair (d2 t1) (d2 t2))
+zero (STEither t1 t2) = ENothing ext (STEither (d2 t1) (d2 t2))
+zero (STMaybe t) = ENothing ext (d2 t)
+zero (STArr n t) = EBuild ext n (eTup (sreplicate n (EConst ext STI64 0))) (zero t)
+zero (STScal t) = case t of
+ STI32 -> ENil ext
+ STI64 -> ENil ext
+ STF32 -> EConst ext STF32 0.0
+ STF64 -> EConst ext STF64 0.0
+ STBool -> ENil ext
+zero STAccum{} = error "Accumulators not allowed in input program"
+
+plus :: STy t -> Ex env (D2 t) -> Ex env (D2 t) -> Ex env (D2 t)
+plus STNil _ _ = ENil ext
+plus (STPair t1 t2) a b =
+ let t = STPair (d2 t1) (d2 t2)
+ in plusSparse t a b $
+ EPair ext (plus t1 (EFst ext (EVar ext t (IS IZ)))
+ (EFst ext (EVar ext t IZ)))
+ (plus t2 (ESnd ext (EVar ext t (IS IZ)))
+ (ESnd ext (EVar ext t IZ)))
+plus (STEither t1 t2) a b =
+ let t = STEither (d2 t1) (d2 t2)
+ in plusSparse t a b $
+ ECase ext (EVar ext t (IS IZ))
+ (ECase ext (EVar ext t (IS IZ))
+ (EInl ext (d2 t2) (plus t1 (EVar ext (d2 t1) (IS IZ)) (EVar ext (d2 t1) IZ)))
+ (EError t "plus l+r"))
+ (ECase ext (EVar ext t (IS IZ))
+ (EError t "plus r+l")
+ (EInr ext (d2 t1) (plus t2 (EVar ext (d2 t2) (IS IZ)) (EVar ext (d2 t2) IZ))))
+plus (STMaybe t) a b =
+ plusSparse (d2 t) a b $
+ plus t (EVar ext (d2 t) (IS IZ)) (EVar ext (d2 t) IZ)
+plus (STArr n t) a b =
+ ELet ext a $
+ ELet ext (weakenExpr WSink b) $
+ ECase
+plus (STScal t) a b = case t of
+ STI32 -> ENil ext
+ STI64 -> ENil ext
+ STF32 -> EOp ext (OAdd STF32) (EPair ext a b)
+ STF64 -> EOp ext (OAdd STF64) (EPair ext a b)
+ STBool -> ENil ext
+plus STAccum{} _ _ = error "Accumulators not allowed in input program"
+
+plusSparse :: STy a
+ -> Ex env (TMaybe a) -> Ex env (TMaybe a)
+ -> Ex (a : a : env) a
+ -> Ex env (TMaybe a)
+plusSparse t a b adder =
+ ELet ext b $
+ EMaybe ext
+ (EVar ext (STMaybe t) IZ)
+ (EJust ext
+ (EMaybe ext
+ (EVar ext t IZ)
+ (weakenExpr (WCopy (WCopy WSink)) adder)
+ (EVar ext (STMaybe t) (IS IZ))))
+ (weakenExpr WSink a)
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 04e3ac4..aa5bd4c 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -320,61 +320,9 @@ indexTupD1Id (SS n) | Refl <- indexTupD1Id n = Refl
zero :: STy t -> Ex env (D2 t)
zero = EZero
--- TODO: this original definition needs to be used as the post-processing after
--- simplification, to eliminate the monoid operations from the AST
--- zero STNil = ENil ext
--- zero (STPair t1 t2) = EInl ext (STPair (d2 t1) (d2 t2)) (ENil ext)
--- zero (STEither t1 t2) = EInl ext (STEither (d2 t1) (d2 t2)) (ENil ext)
--- zero (STArr n t) = EBuild ext n (eTup (sreplicate n (EConst ext STI64 0))) (zero t)
--- zero (STScal t) = case t of
--- STI32 -> ENil ext
--- STI64 -> ENil ext
--- STF32 -> EConst ext STF32 0.0
--- STF64 -> EConst ext STF64 0.0
--- STBool -> ENil ext
--- zero STAccum{} = error "Accumulators not allowed in input program"
plus :: STy t -> Ex env (D2 t) -> Ex env (D2 t) -> Ex env (D2 t)
plus = EPlus
--- plus STNil _ _ = ENil ext
--- plus (STPair t1 t2) a b =
--- let t = STPair (d2 t1) (d2 t2)
--- in plusSparse t a b $
--- EPair ext (plus t1 (EFst ext (EVar ext t (IS IZ)))
--- (EFst ext (EVar ext t IZ)))
--- (plus t2 (ESnd ext (EVar ext t (IS IZ)))
--- (ESnd ext (EVar ext t IZ)))
--- plus (STEither t1 t2) a b =
--- let t = STEither (d2 t1) (d2 t2)
--- in plusSparse t a b $
--- ECase ext (EVar ext t (IS IZ))
--- (ECase ext (EVar ext t (IS IZ))
--- (EInl ext (d2 t2) (plus t1 (EVar ext (d2 t1) (IS IZ)) (EVar ext (d2 t1) IZ)))
--- (EError t "plus l+r"))
--- (ECase ext (EVar ext t (IS IZ))
--- (EError t "plus r+l")
--- (EInr ext (d2 t1) (plus t2 (EVar ext (d2 t2) (IS IZ)) (EVar ext (d2 t2) IZ))))
--- plus STArr{} _ _ = error "TODO plus on arrays"
--- plus (STScal t) a b = case t of
--- STI32 -> ENil ext
--- STI64 -> ENil ext
--- STF32 -> EOp ext (OAdd STF32) (EPair ext a b)
--- STF64 -> EOp ext (OAdd STF64) (EPair ext a b)
--- STBool -> ENil ext
--- plus STAccum{} _ _ = error "Accumulators not allowed in input program"
-
--- plusSparse :: STy a
--- -> Ex env (TEither TNil a) -> Ex env (TEither TNil a)
--- -> Ex (a : a : env) a
--- -> Ex env (TEither TNil a)
--- plusSparse t a b adder =
--- ELet ext b $
--- ECase ext (weakenExpr WSink a)
--- (EVar ext (STEither STNil t) (IS IZ))
--- (EInr ext STNil
--- (ECase ext (EVar ext (STEither STNil t) (IS IZ))
--- (EVar ext t (IS IZ))
--- (weakenExpr (WCopy (WCopy WSink)) adder)))
zeroTup :: SList STy env0 -> Ex env (Tup (D2E env0))
zeroTup SNil = ENil ext