diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-04-27 23:34:59 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-04-27 23:34:59 +0200 |
commit | b1664532eaebdf0409ab6d93fc0ba2ef8dfbf372 (patch) | |
tree | a40c16fd082bbe4183e7b4194b8cea1408cec379 /src/CHAD | |
parent | c750f8f9f1275d49ff74297e6648e1bfc1c6d918 (diff) |
WIP revamp accumulators again: explicit monoid types
No more D2 in accumulators! Paving the way for configurable sparsity of
products and arrays. The idea is to make separate monoid types for a
"product cotangent" and an "array cotangent" that can be lowered to
either a sparse monoid or a non-sparse monoid. Downsides of this
approach: lots of API duplication.
Diffstat (limited to 'src/CHAD')
-rw-r--r-- | src/CHAD/Accum.hs | 4 | ||||
-rw-r--r-- | src/CHAD/Top.hs | 3 | ||||
-rw-r--r-- | src/CHAD/Types.hs | 44 |
3 files changed, 31 insertions, 20 deletions
diff --git a/src/CHAD/Accum.hs b/src/CHAD/Accum.hs index b61b5ff..d8a71b5 100644 --- a/src/CHAD/Accum.hs +++ b/src/CHAD/Accum.hs @@ -10,9 +10,9 @@ import Data makeAccumulators :: SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro)) makeAccumulators SNil e = e -makeAccumulators (t `SCons` envpro) e = +makeAccumulators (t `SCons` envpro) e | Refl <- lemZeroInfoD2 t = makeAccumulators envpro $ - EWith ext t (EZero ext t) e + EWith ext (d2M t) (EZero ext (d2M t) (ENil ext)) e uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list)) uninvertTup SNil _ e = EPair ext e (ENil ext) diff --git a/src/CHAD/Top.hs b/src/CHAD/Top.hs index 2c01178..9e7e7f5 100644 --- a/src/CHAD/Top.hs +++ b/src/CHAD/Top.hs @@ -53,6 +53,7 @@ d1Identity = \case STArr _ t | Refl <- d1Identity t -> Refl STScal _ -> Refl STAccum{} -> error "Accumulators not allowed in input program" + STLEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl d1eIdentity :: SList STy env -> D1E env :~: env d1eIdentity SNil = Refl @@ -72,7 +73,7 @@ reassembleD2E (des `DPush` (_, _, SMerge)) e = EPair ext (reassembleD2E des (EPair ext (EFst ext (EVar ext (typeOf e) IZ)) (EFst ext (ESnd ext (EVar ext (typeOf e) IZ))))) (ESnd ext (ESnd ext (EVar ext (typeOf e) IZ))) -reassembleD2E (des `DPush` (t, _, SDiscr)) e = EPair ext (reassembleD2E des e) (EZero ext t) +reassembleD2E (des `DPush` (t, _, SDiscr)) e = EPair ext (reassembleD2E des e) (ezeroD2 t) chad :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : D1E env) (TPair (D1 t) (Tup (D2E env))) chad config env (term :: Ex env t) diff --git a/src/CHAD/Types.hs b/src/CHAD/Types.hs index 7f49cef..74e7dbd 100644 --- a/src/CHAD/Types.hs +++ b/src/CHAD/Types.hs @@ -14,14 +14,16 @@ type family D1 t where D1 (TMaybe a) = TMaybe (D1 a) D1 (TArr n t) = TArr n (D1 t) D1 (TScal t) = TScal t + D1 (TLEither a b) = TLEither (D1 a) (D1 b) type family D2 t where D2 TNil = TNil D2 (TPair a b) = TMaybe (TPair (D2 a) (D2 b)) - D2 (TEither a b) = TMaybe (TEither (D2 a) (D2 b)) + D2 (TEither a b) = TLEither (D2 a) (D2 b) D2 (TMaybe t) = TMaybe (D2 t) D2 (TArr n t) = TMaybe (TArr n (D2 t)) D2 (TScal t) = D2s t + D2 (TLEither a b) = TLEither (D2 a) (D2 b) type family D2s t where D2s TI32 = TNil @@ -40,7 +42,7 @@ type family D2E env where type family D2AcE env where D2AcE '[] = '[] - D2AcE (t : env) = TAccum t : D2AcE env + D2AcE (t : env) = TAccum (D2 t) : D2AcE env d1 :: STy t -> STy (D1 t) d1 STNil = STNil @@ -50,32 +52,40 @@ d1 (STMaybe t) = STMaybe (d1 t) d1 (STArr n t) = STArr n (d1 t) d1 (STScal t) = STScal t d1 STAccum{} = error "Accumulators not allowed in input program" +d1 (STLEither a b) = STLEither (d1 a) (d1 b) d1e :: SList STy env -> SList STy (D1E env) d1e SNil = SNil d1e (t `SCons` env) = d1 t `SCons` d1e env +d2M :: STy t -> SMTy (D2 t) +d2M STNil = SMTNil +d2M (STPair a b) = SMTMaybe (SMTPair (d2M a) (d2M b)) +d2M (STEither a b) = SMTLEither (d2M a) (d2M b) +d2M (STMaybe t) = SMTMaybe (d2M t) +d2M (STArr n t) = SMTMaybe (SMTArr n (d2M t)) +d2M (STScal t) = case t of + STI32 -> SMTNil + STI64 -> SMTNil + STF32 -> SMTScal STF32 + STF64 -> SMTScal STF64 + STBool -> SMTNil +d2M STAccum{} = error "Accumulators not allowed in input program" +d2M (STLEither a b) = SMTLEither (d2M a) (d2M b) + d2 :: STy t -> STy (D2 t) -d2 STNil = STNil -d2 (STPair a b) = STMaybe (STPair (d2 a) (d2 b)) -d2 (STEither a b) = STMaybe (STEither (d2 a) (d2 b)) -d2 (STMaybe t) = STMaybe (d2 t) -d2 (STArr n t) = STMaybe (STArr n (d2 t)) -d2 (STScal t) = case t of - STI32 -> STNil - STI64 -> STNil - STF32 -> STScal STF32 - STF64 -> STScal STF64 - STBool -> STNil -d2 STAccum{} = error "Accumulators not allowed in input program" +d2 = fromSMTy . d2M + +d2eM :: SList STy env -> SList SMTy (D2E env) +d2eM SNil = SNil +d2eM (t `SCons` ts) = d2M t `SCons` d2eM ts d2e :: SList STy env -> SList STy (D2E env) -d2e SNil = SNil -d2e (t `SCons` ts) = d2 t `SCons` d2e ts +d2e = slistMap fromSMTy . d2eM d2ace :: SList STy env -> SList STy (D2AcE env) d2ace SNil = SNil -d2ace (t `SCons` ts) = STAccum t `SCons` d2ace ts +d2ace (t `SCons` ts) = STAccum (d2M t) `SCons` d2ace ts data CHADConfig = CHADConfig |