diff options
author | Tom Smeding <t.j.smeding@uu.nl> | 2024-12-06 16:59:59 +0100 |
---|---|---|
committer | Tom Smeding <t.j.smeding@uu.nl> | 2024-12-06 16:59:59 +0100 |
commit | 0ccd55fc7b3d5511935111d0e2712f452da035f4 (patch) | |
tree | 91f4625dd2bcc5db14ff319084efabff36aa1e15 | |
parent | 728909852208587c3c4c63da302d22e67d5cc915 (diff) |
WIP UnMonoid (to be used for compiling to C)
-rw-r--r-- | chad-fast.cabal | 1 | ||||
-rw-r--r-- | src/AST/UnMonoid.hs | 110 | ||||
-rw-r--r-- | src/CHAD.hs | 52 |
3 files changed, 111 insertions, 52 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal index 893c92f..c3c2682 100644 --- a/chad-fast.cabal +++ b/chad-fast.cabal @@ -17,6 +17,7 @@ library AST.Env AST.Pretty AST.Types + AST.UnMonoid AST.Weaken AST.Weaken.Auto CHAD diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs new file mode 100644 index 0000000..1675dab --- /dev/null +++ b/src/AST/UnMonoid.hs @@ -0,0 +1,110 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE TypeOperators #-} +module AST.UnMonoid where + +import AST +import CHAD.Types +import Data + + +unMonoid :: Ex env t -> Ex env t +unMonoid = \case + EZero t -> zero t + EPlus t a b -> plus t a b + EOneHot t i a b -> _ t i a b + + EVar _ t i -> EVar ext t i + ELet _ rhs body -> ELet ext (unMonoid rhs) (unMonoid body) + EPair _ a b -> EPair ext (unMonoid a) (unMonoid b) + EFst _ e -> EFst ext (unMonoid e) + ESnd _ e -> ESnd ext (unMonoid e) + ENil _ -> ENil ext + EInl _ t e -> EInl ext t (unMonoid e) + EInr _ t e -> EInr ext t (unMonoid e) + ECase _ e a b -> ECase ext (unMonoid e) (unMonoid a) (unMonoid b) + ENothing _ t -> ENothing ext t + EJust _ e -> EJust ext (unMonoid e) + EMaybe _ a b e -> EMaybe ext (unMonoid a) (unMonoid b) (unMonoid e) + EConstArr _ n t x -> EConstArr ext n t x + EBuild _ n a b -> EBuild ext n (unMonoid a) (unMonoid b) + EFold1Inner _ a b c -> EFold1Inner ext (unMonoid a) (unMonoid b) (unMonoid c) + ESum1Inner _ e -> ESum1Inner ext (unMonoid e) + EUnit _ e -> EUnit ext (unMonoid e) + EReplicate1Inner _ a b -> EReplicate1Inner ext (unMonoid a) (unMonoid b) + EMaximum1Inner _ e -> EMaximum1Inner ext (unMonoid e) + EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e) + EConst _ t x -> EConst ext t x + EIdx0 _ e -> EIdx0 ext (unMonoid e) + EIdx1 _ a b -> EIdx1 ext (unMonoid a) (unMonoid b) + EIdx _ a b -> EIdx ext (unMonoid a) (unMonoid b) + EShape _ e -> EShape ext (unMonoid e) + EOp _ op e -> EOp ext op (unMonoid e) + ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2) + EWith a b -> EWith (unMonoid a) (unMonoid b) + EAccum n a b e -> EAccum n (unMonoid a) (unMonoid b) (unMonoid e) + EError t s -> EError t s + +zero :: STy t -> Ex env (D2 t) +zero STNil = ENil ext +zero (STPair t1 t2) = ENothing ext (STPair (d2 t1) (d2 t2)) +zero (STEither t1 t2) = ENothing ext (STEither (d2 t1) (d2 t2)) +zero (STMaybe t) = ENothing ext (d2 t) +zero (STArr n t) = EBuild ext n (eTup (sreplicate n (EConst ext STI64 0))) (zero t) +zero (STScal t) = case t of + STI32 -> ENil ext + STI64 -> ENil ext + STF32 -> EConst ext STF32 0.0 + STF64 -> EConst ext STF64 0.0 + STBool -> ENil ext +zero STAccum{} = error "Accumulators not allowed in input program" + +plus :: STy t -> Ex env (D2 t) -> Ex env (D2 t) -> Ex env (D2 t) +plus STNil _ _ = ENil ext +plus (STPair t1 t2) a b = + let t = STPair (d2 t1) (d2 t2) + in plusSparse t a b $ + EPair ext (plus t1 (EFst ext (EVar ext t (IS IZ))) + (EFst ext (EVar ext t IZ))) + (plus t2 (ESnd ext (EVar ext t (IS IZ))) + (ESnd ext (EVar ext t IZ))) +plus (STEither t1 t2) a b = + let t = STEither (d2 t1) (d2 t2) + in plusSparse t a b $ + ECase ext (EVar ext t (IS IZ)) + (ECase ext (EVar ext t (IS IZ)) + (EInl ext (d2 t2) (plus t1 (EVar ext (d2 t1) (IS IZ)) (EVar ext (d2 t1) IZ))) + (EError t "plus l+r")) + (ECase ext (EVar ext t (IS IZ)) + (EError t "plus r+l") + (EInr ext (d2 t1) (plus t2 (EVar ext (d2 t2) (IS IZ)) (EVar ext (d2 t2) IZ)))) +plus (STMaybe t) a b = + plusSparse (d2 t) a b $ + plus t (EVar ext (d2 t) (IS IZ)) (EVar ext (d2 t) IZ) +plus (STArr n t) a b = + ELet ext a $ + ELet ext (weakenExpr WSink b) $ + ECase +plus (STScal t) a b = case t of + STI32 -> ENil ext + STI64 -> ENil ext + STF32 -> EOp ext (OAdd STF32) (EPair ext a b) + STF64 -> EOp ext (OAdd STF64) (EPair ext a b) + STBool -> ENil ext +plus STAccum{} _ _ = error "Accumulators not allowed in input program" + +plusSparse :: STy a + -> Ex env (TMaybe a) -> Ex env (TMaybe a) + -> Ex (a : a : env) a + -> Ex env (TMaybe a) +plusSparse t a b adder = + ELet ext b $ + EMaybe ext + (EVar ext (STMaybe t) IZ) + (EJust ext + (EMaybe ext + (EVar ext t IZ) + (weakenExpr (WCopy (WCopy WSink)) adder) + (EVar ext (STMaybe t) (IS IZ)))) + (weakenExpr WSink a) diff --git a/src/CHAD.hs b/src/CHAD.hs index 04e3ac4..aa5bd4c 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -320,61 +320,9 @@ indexTupD1Id (SS n) | Refl <- indexTupD1Id n = Refl zero :: STy t -> Ex env (D2 t) zero = EZero --- TODO: this original definition needs to be used as the post-processing after --- simplification, to eliminate the monoid operations from the AST --- zero STNil = ENil ext --- zero (STPair t1 t2) = EInl ext (STPair (d2 t1) (d2 t2)) (ENil ext) --- zero (STEither t1 t2) = EInl ext (STEither (d2 t1) (d2 t2)) (ENil ext) --- zero (STArr n t) = EBuild ext n (eTup (sreplicate n (EConst ext STI64 0))) (zero t) --- zero (STScal t) = case t of --- STI32 -> ENil ext --- STI64 -> ENil ext --- STF32 -> EConst ext STF32 0.0 --- STF64 -> EConst ext STF64 0.0 --- STBool -> ENil ext --- zero STAccum{} = error "Accumulators not allowed in input program" plus :: STy t -> Ex env (D2 t) -> Ex env (D2 t) -> Ex env (D2 t) plus = EPlus --- plus STNil _ _ = ENil ext --- plus (STPair t1 t2) a b = --- let t = STPair (d2 t1) (d2 t2) --- in plusSparse t a b $ --- EPair ext (plus t1 (EFst ext (EVar ext t (IS IZ))) --- (EFst ext (EVar ext t IZ))) --- (plus t2 (ESnd ext (EVar ext t (IS IZ))) --- (ESnd ext (EVar ext t IZ))) --- plus (STEither t1 t2) a b = --- let t = STEither (d2 t1) (d2 t2) --- in plusSparse t a b $ --- ECase ext (EVar ext t (IS IZ)) --- (ECase ext (EVar ext t (IS IZ)) --- (EInl ext (d2 t2) (plus t1 (EVar ext (d2 t1) (IS IZ)) (EVar ext (d2 t1) IZ))) --- (EError t "plus l+r")) --- (ECase ext (EVar ext t (IS IZ)) --- (EError t "plus r+l") --- (EInr ext (d2 t1) (plus t2 (EVar ext (d2 t2) (IS IZ)) (EVar ext (d2 t2) IZ)))) --- plus STArr{} _ _ = error "TODO plus on arrays" --- plus (STScal t) a b = case t of --- STI32 -> ENil ext --- STI64 -> ENil ext --- STF32 -> EOp ext (OAdd STF32) (EPair ext a b) --- STF64 -> EOp ext (OAdd STF64) (EPair ext a b) --- STBool -> ENil ext --- plus STAccum{} _ _ = error "Accumulators not allowed in input program" - --- plusSparse :: STy a --- -> Ex env (TEither TNil a) -> Ex env (TEither TNil a) --- -> Ex (a : a : env) a --- -> Ex env (TEither TNil a) --- plusSparse t a b adder = --- ELet ext b $ --- ECase ext (weakenExpr WSink a) --- (EVar ext (STEither STNil t) (IS IZ)) --- (EInr ext STNil --- (ECase ext (EVar ext (STEither STNil t) (IS IZ)) --- (EVar ext t (IS IZ)) --- (weakenExpr (WCopy (WCopy WSink)) adder))) zeroTup :: SList STy env0 -> Ex env (Tup (D2E env0)) zeroTup SNil = ENil ext |