summaryrefslogtreecommitdiff
path: root/src/AST/UnMonoid.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-04-27 23:34:59 +0200
committerTom Smeding <tom@tomsmeding.com>2025-04-27 23:34:59 +0200
commitb1664532eaebdf0409ab6d93fc0ba2ef8dfbf372 (patch)
treea40c16fd082bbe4183e7b4194b8cea1408cec379 /src/AST/UnMonoid.hs
parentc750f8f9f1275d49ff74297e6648e1bfc1c6d918 (diff)
WIP revamp accumulators again: explicit monoid types
No more D2 in accumulators! Paving the way for configurable sparsity of products and arrays. The idea is to make separate monoid types for a "product cotangent" and an "array cotangent" that can be lowered to either a sparse monoid or a non-sparse monoid. Downsides of this approach: lots of API duplication.
Diffstat (limited to 'src/AST/UnMonoid.hs')
-rw-r--r--src/AST/UnMonoid.hs145
1 files changed, 76 insertions, 69 deletions
diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs
index 0da1afc..3d5f544 100644
--- a/src/AST/UnMonoid.hs
+++ b/src/AST/UnMonoid.hs
@@ -5,13 +5,14 @@
module AST.UnMonoid (unMonoid, zero, plus) where
import AST
-import CHAD.Types
import Data
+-- | Remove 'EZero', 'EPlus' and 'EOneHot' from the program by expanding them
+-- into their concrete implementations.
unMonoid :: Ex env t -> Ex env t
unMonoid = \case
- EZero _ t -> zero t
+ EZero _ t e -> zero t e
EPlus _ t a b -> plus t (unMonoid a) (unMonoid b)
EOneHot _ t p a b -> onehot t p (unMonoid a) (unMonoid b)
@@ -27,6 +28,10 @@ unMonoid = \case
ENothing _ t -> ENothing ext t
EJust _ e -> EJust ext (unMonoid e)
EMaybe _ a b e -> EMaybe ext (unMonoid a) (unMonoid b) (unMonoid e)
+ ELNil _ t1 t2 -> ELNil ext t1 t2
+ ELInl _ t e -> ELInl ext t (unMonoid e)
+ ELInr _ t e -> ELInr ext t (unMonoid e)
+ ELCase _ e a b c -> ELCase ext (unMonoid e) (unMonoid a) (unMonoid b) (unMonoid c)
EConstArr _ n t x -> EConstArr ext n t x
EBuild _ n a b -> EBuild ext n (unMonoid a) (unMonoid b)
EFold1Inner _ cm a b c -> EFold1Inner ext cm (unMonoid a) (unMonoid b) (unMonoid c)
@@ -46,92 +51,94 @@ unMonoid = \case
EAccum _ t p a b e -> EAccum ext t p (unMonoid a) (unMonoid b) (unMonoid e)
EError _ t s -> EError ext t s
-zero :: STy t -> Ex env (D2 t)
-zero STNil = ENil ext
-zero (STPair t1 t2) = ENothing ext (STPair (d2 t1) (d2 t2))
-zero (STEither t1 t2) = ENothing ext (STEither (d2 t1) (d2 t2))
-zero (STMaybe t) = ENothing ext (d2 t)
-zero (STArr SZ t) = ENothing ext (STArr SZ (d2 t))
-zero (STArr n t) = ENothing ext (STArr n (d2 t))
-zero (STScal t) = case t of
- STI32 -> ENil ext
- STI64 -> ENil ext
+zero :: SMTy t -> Ex env (ZeroInfo t) -> Ex env t
+zero SMTNil _ = ENil ext
+zero (SMTPair t1 t2) e =
+ ELet ext e $ EPair ext (zero t1 (EFst ext (EVar ext (typeOf e) IZ)))
+ (zero t2 (ESnd ext (EVar ext (typeOf e) IZ)))
+zero (SMTLEither t1 t2) _ = ELNil ext (fromSMTy t1) (fromSMTy t2)
+zero (SMTMaybe t) _ = ENothing ext (fromSMTy t)
+zero (SMTArr _ t) e = emap (zero t (EVar ext (tZeroInfo t) IZ)) e
+zero (SMTScal t) _ = case t of
+ STI32 -> EConst ext STI32 0
+ STI64 -> EConst ext STI64 0
STF32 -> EConst ext STF32 0.0
STF64 -> EConst ext STF64 0.0
- STBool -> ENil ext
-zero STAccum{} = error "Accumulators not allowed in input program"
-plus :: STy t -> Ex env (D2 t) -> Ex env (D2 t) -> Ex env (D2 t)
-plus STNil _ _ = ENil ext
-plus (STPair t1 t2) a b =
- let t = STPair (d2 t1) (d2 t2)
- in plusSparse t a b $
+plus :: SMTy t -> Ex env t -> Ex env t -> Ex env t
+plus SMTNil _ _ = ENil ext
+plus (SMTPair t1 t2) a b =
+ let t = STPair (fromSMTy t1) (fromSMTy t2)
+ in ELet ext a $
+ ELet ext (weakenExpr WSink b) $
EPair ext (plus t1 (EFst ext (EVar ext t (IS IZ)))
(EFst ext (EVar ext t IZ)))
(plus t2 (ESnd ext (EVar ext t (IS IZ)))
(ESnd ext (EVar ext t IZ)))
-plus (STEither t1 t2) a b =
- let t = STEither (d2 t1) (d2 t2)
- in plusSparse t a b $
- ECase ext (EVar ext t (IS IZ))
- (ECase ext (EVar ext t (IS IZ))
- (EInl ext (d2 t2) (plus t1 (EVar ext (d2 t1) (IS IZ)) (EVar ext (d2 t1) IZ)))
+plus (SMTLEither t1 t2) a b =
+ let t = STLEither (fromSMTy t1) (fromSMTy t2)
+ in ELet ext a $
+ ELet ext (weakenExpr WSink b) $
+ ELCase ext (EVar ext t (IS IZ))
+ (EVar ext t IZ)
+ (ELCase ext (EVar ext t (IS IZ))
+ (EVar ext t (IS (IS IZ)))
+ (ELInl ext (fromSMTy t2) (plus t1 (EVar ext (fromSMTy t1) (IS IZ)) (EVar ext (fromSMTy t1) IZ)))
(EError ext t "plus l+r"))
- (ECase ext (EVar ext t (IS IZ))
+ (ELCase ext (EVar ext t (IS IZ))
+ (EVar ext t (IS (IS IZ)))
(EError ext t "plus r+l")
- (EInr ext (d2 t1) (plus t2 (EVar ext (d2 t2) (IS IZ)) (EVar ext (d2 t2) IZ))))
-plus (STMaybe t) a b =
- plusSparse (d2 t) a b $
- plus t (EVar ext (d2 t) (IS IZ)) (EVar ext (d2 t) IZ)
-plus (STArr n t) a b =
- plusSparse (STArr n (d2 t)) a b $
- eif (eshapeEmpty n (EShape ext (EVar ext (STArr n (d2 t)) (IS IZ))))
- (EVar ext (STArr n (d2 t)) IZ)
- (eif (eshapeEmpty n (EShape ext (EVar ext (STArr n (d2 t)) IZ)))
- (EVar ext (STArr n (d2 t)) (IS IZ))
- (ezipWith (plus t (EVar ext (d2 t) (IS IZ)) (EVar ext (d2 t) IZ))
- (EVar ext (STArr n (d2 t)) (IS IZ))
- (EVar ext (STArr n (d2 t)) IZ)))
-plus (STScal t) a b = case t of
- STI32 -> ENil ext
- STI64 -> ENil ext
- STF32 -> EOp ext (OAdd STF32) (EPair ext a b)
- STF64 -> EOp ext (OAdd STF64) (EPair ext a b)
- STBool -> ENil ext
-plus STAccum{} _ _ = error "Accumulators not allowed in input program"
-
-plusSparse :: STy a
- -> Ex env (TMaybe a) -> Ex env (TMaybe a)
- -> Ex (a : a : env) a
- -> Ex env (TMaybe a)
-plusSparse t a b adder =
+ (ELInr ext (fromSMTy t1) (plus t2 (EVar ext (fromSMTy t2) (IS IZ)) (EVar ext (fromSMTy t2) IZ))))
+plus (SMTMaybe t) a b =
ELet ext b $
EMaybe ext
- (EVar ext (STMaybe t) IZ)
+ (EVar ext (STMaybe (fromSMTy t)) IZ)
(EJust ext
(EMaybe ext
- (EVar ext t IZ)
- (weakenExpr (WCopy (WCopy WSink)) adder)
- (EVar ext (STMaybe t) (IS IZ))))
+ (EVar ext (fromSMTy t) IZ)
+ (plus t (EVar ext (fromSMTy t) (IS IZ)) (EVar ext (fromSMTy t) IZ))
+ (EVar ext (STMaybe (fromSMTy t)) (IS IZ))))
(weakenExpr WSink a)
+plus (SMTArr _ t) a b =
+ ezipWith (plus t (EVar ext (fromSMTy t) (IS IZ)) (EVar ext (fromSMTy t) IZ))
+ a b
+plus (SMTScal t) a b = EOp ext (OAdd t) (EPair ext a b)
-onehot :: STy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env (D2 a) -> Ex env (D2 t)
+onehot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env a -> Ex env t
onehot typ topprj idx arg = case (typ, topprj) of
- (_, SAPHere) -> arg
+ (_, SAPHere) ->
+ ELet ext arg $
+ EVar ext (fromSMTy typ) IZ
- (STPair t1 t2, SAPFst prj) -> EJust ext (EPair ext (onehot t1 prj idx arg) (zero t2))
- (STPair t1 t2, SAPSnd prj) -> EJust ext (EPair ext (zero t1) (onehot t2 prj idx arg))
+ (SMTPair t1 t2, SAPFst prj) ->
+ ELet ext idx $
+ let tidx = typeOf idx in
+ ELet ext (onehot t1 prj (EFst ext (EVar ext (typeOf idx) IZ)) (weakenExpr WSink arg)) $
+ let toh = fromSMTy t1 in
+ EPair ext (EVar ext toh IZ)
+ (zero t2 (ESnd ext (EVar ext tidx (IS IZ))))
+
+ (SMTPair t1 t2, SAPSnd prj) ->
+ ELet ext idx $
+ let tidx = typeOf idx in
+ ELet ext (onehot t2 prj (ESnd ext (EVar ext (typeOf idx) IZ)) (weakenExpr WSink arg)) $
+ let toh = fromSMTy t2 in
+ EPair ext (zero t1 (EFst ext (EVar ext tidx (IS IZ))))
+ (EVar ext toh IZ)
- (STEither t1 t2, SAPLeft prj) -> EJust ext (EInl ext (d2 t2) (onehot t1 prj idx arg))
- (STEither t1 t2, SAPRight prj) -> EJust ext (EInr ext (d2 t1) (onehot t2 prj idx arg))
+ (SMTLEither t1 t2, SAPLeft prj) ->
+ ELInl ext (fromSMTy t2) (onehot t1 prj idx arg)
+ (SMTLEither t1 t2, SAPRight prj) ->
+ ELInr ext (fromSMTy t1) (onehot t2 prj idx arg)
- (STMaybe t1, SAPJust prj) -> EJust ext (onehot t1 prj idx arg)
+ (SMTMaybe t1, SAPJust prj) ->
+ EJust ext (onehot t1 prj idx arg)
- (STArr n t1, SAPArrIdx prj _) ->
+ (SMTArr n t1, SAPArrIdx prj) ->
let tidx = tTup (sreplicate n tIx)
in ELet ext idx $
- EJust ext $
- EBuild ext n (ESnd ext (EFst ext (EVar ext (typeOf idx) IZ))) $
- eif (eidxEq n (EVar ext tidx IZ) (EFst ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))))
- (onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg))
- (zero t1)
+ EBuild ext n (EShape ext (ESnd ext (EFst ext (EVar ext (typeOf idx) IZ)))) $
+ eif (eidxEq n (EVar ext tidx IZ) (EFst ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))))
+ (onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg))
+ (ELet ext (EIdx ext (ESnd ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))) (EVar ext tidx IZ)) $
+ zero t1 (EVar ext (tZeroInfo t1) IZ))