summaryrefslogtreecommitdiff
path: root/src/CHAD
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-04-27 23:34:59 +0200
committerTom Smeding <tom@tomsmeding.com>2025-04-27 23:34:59 +0200
commitb1664532eaebdf0409ab6d93fc0ba2ef8dfbf372 (patch)
treea40c16fd082bbe4183e7b4194b8cea1408cec379 /src/CHAD
parentc750f8f9f1275d49ff74297e6648e1bfc1c6d918 (diff)
WIP revamp accumulators again: explicit monoid types
No more D2 in accumulators! Paving the way for configurable sparsity of products and arrays. The idea is to make separate monoid types for a "product cotangent" and an "array cotangent" that can be lowered to either a sparse monoid or a non-sparse monoid. Downsides of this approach: lots of API duplication.
Diffstat (limited to 'src/CHAD')
-rw-r--r--src/CHAD/Accum.hs4
-rw-r--r--src/CHAD/Top.hs3
-rw-r--r--src/CHAD/Types.hs44
3 files changed, 31 insertions, 20 deletions
diff --git a/src/CHAD/Accum.hs b/src/CHAD/Accum.hs
index b61b5ff..d8a71b5 100644
--- a/src/CHAD/Accum.hs
+++ b/src/CHAD/Accum.hs
@@ -10,9 +10,9 @@ import Data
makeAccumulators :: SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro))
makeAccumulators SNil e = e
-makeAccumulators (t `SCons` envpro) e =
+makeAccumulators (t `SCons` envpro) e | Refl <- lemZeroInfoD2 t =
makeAccumulators envpro $
- EWith ext t (EZero ext t) e
+ EWith ext (d2M t) (EZero ext (d2M t) (ENil ext)) e
uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list))
uninvertTup SNil _ e = EPair ext e (ENil ext)
diff --git a/src/CHAD/Top.hs b/src/CHAD/Top.hs
index 2c01178..9e7e7f5 100644
--- a/src/CHAD/Top.hs
+++ b/src/CHAD/Top.hs
@@ -53,6 +53,7 @@ d1Identity = \case
STArr _ t | Refl <- d1Identity t -> Refl
STScal _ -> Refl
STAccum{} -> error "Accumulators not allowed in input program"
+ STLEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl
d1eIdentity :: SList STy env -> D1E env :~: env
d1eIdentity SNil = Refl
@@ -72,7 +73,7 @@ reassembleD2E (des `DPush` (_, _, SMerge)) e =
EPair ext (reassembleD2E des (EPair ext (EFst ext (EVar ext (typeOf e) IZ))
(EFst ext (ESnd ext (EVar ext (typeOf e) IZ)))))
(ESnd ext (ESnd ext (EVar ext (typeOf e) IZ)))
-reassembleD2E (des `DPush` (t, _, SDiscr)) e = EPair ext (reassembleD2E des e) (EZero ext t)
+reassembleD2E (des `DPush` (t, _, SDiscr)) e = EPair ext (reassembleD2E des e) (ezeroD2 t)
chad :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : D1E env) (TPair (D1 t) (Tup (D2E env)))
chad config env (term :: Ex env t)
diff --git a/src/CHAD/Types.hs b/src/CHAD/Types.hs
index 7f49cef..74e7dbd 100644
--- a/src/CHAD/Types.hs
+++ b/src/CHAD/Types.hs
@@ -14,14 +14,16 @@ type family D1 t where
D1 (TMaybe a) = TMaybe (D1 a)
D1 (TArr n t) = TArr n (D1 t)
D1 (TScal t) = TScal t
+ D1 (TLEither a b) = TLEither (D1 a) (D1 b)
type family D2 t where
D2 TNil = TNil
D2 (TPair a b) = TMaybe (TPair (D2 a) (D2 b))
- D2 (TEither a b) = TMaybe (TEither (D2 a) (D2 b))
+ D2 (TEither a b) = TLEither (D2 a) (D2 b)
D2 (TMaybe t) = TMaybe (D2 t)
D2 (TArr n t) = TMaybe (TArr n (D2 t))
D2 (TScal t) = D2s t
+ D2 (TLEither a b) = TLEither (D2 a) (D2 b)
type family D2s t where
D2s TI32 = TNil
@@ -40,7 +42,7 @@ type family D2E env where
type family D2AcE env where
D2AcE '[] = '[]
- D2AcE (t : env) = TAccum t : D2AcE env
+ D2AcE (t : env) = TAccum (D2 t) : D2AcE env
d1 :: STy t -> STy (D1 t)
d1 STNil = STNil
@@ -50,32 +52,40 @@ d1 (STMaybe t) = STMaybe (d1 t)
d1 (STArr n t) = STArr n (d1 t)
d1 (STScal t) = STScal t
d1 STAccum{} = error "Accumulators not allowed in input program"
+d1 (STLEither a b) = STLEither (d1 a) (d1 b)
d1e :: SList STy env -> SList STy (D1E env)
d1e SNil = SNil
d1e (t `SCons` env) = d1 t `SCons` d1e env
+d2M :: STy t -> SMTy (D2 t)
+d2M STNil = SMTNil
+d2M (STPair a b) = SMTMaybe (SMTPair (d2M a) (d2M b))
+d2M (STEither a b) = SMTLEither (d2M a) (d2M b)
+d2M (STMaybe t) = SMTMaybe (d2M t)
+d2M (STArr n t) = SMTMaybe (SMTArr n (d2M t))
+d2M (STScal t) = case t of
+ STI32 -> SMTNil
+ STI64 -> SMTNil
+ STF32 -> SMTScal STF32
+ STF64 -> SMTScal STF64
+ STBool -> SMTNil
+d2M STAccum{} = error "Accumulators not allowed in input program"
+d2M (STLEither a b) = SMTLEither (d2M a) (d2M b)
+
d2 :: STy t -> STy (D2 t)
-d2 STNil = STNil
-d2 (STPair a b) = STMaybe (STPair (d2 a) (d2 b))
-d2 (STEither a b) = STMaybe (STEither (d2 a) (d2 b))
-d2 (STMaybe t) = STMaybe (d2 t)
-d2 (STArr n t) = STMaybe (STArr n (d2 t))
-d2 (STScal t) = case t of
- STI32 -> STNil
- STI64 -> STNil
- STF32 -> STScal STF32
- STF64 -> STScal STF64
- STBool -> STNil
-d2 STAccum{} = error "Accumulators not allowed in input program"
+d2 = fromSMTy . d2M
+
+d2eM :: SList STy env -> SList SMTy (D2E env)
+d2eM SNil = SNil
+d2eM (t `SCons` ts) = d2M t `SCons` d2eM ts
d2e :: SList STy env -> SList STy (D2E env)
-d2e SNil = SNil
-d2e (t `SCons` ts) = d2 t `SCons` d2e ts
+d2e = slistMap fromSMTy . d2eM
d2ace :: SList STy env -> SList STy (D2AcE env)
d2ace SNil = SNil
-d2ace (t `SCons` ts) = STAccum t `SCons` d2ace ts
+d2ace (t `SCons` ts) = STAccum (d2M t) `SCons` d2ace ts
data CHADConfig = CHADConfig