summaryrefslogtreecommitdiff
path: root/src/ForwardAD
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-04-27 23:34:59 +0200
committerTom Smeding <tom@tomsmeding.com>2025-04-27 23:34:59 +0200
commitb1664532eaebdf0409ab6d93fc0ba2ef8dfbf372 (patch)
treea40c16fd082bbe4183e7b4194b8cea1408cec379 /src/ForwardAD
parentc750f8f9f1275d49ff74297e6648e1bfc1c6d918 (diff)
WIP revamp accumulators again: explicit monoid types
No more D2 in accumulators! Paving the way for configurable sparsity of products and arrays. The idea is to make separate monoid types for a "product cotangent" and an "array cotangent" that can be lowered to either a sparse monoid or a non-sparse monoid. Downsides of this approach: lots of API duplication.
Diffstat (limited to 'src/ForwardAD')
-rw-r--r--src/ForwardAD/DualNumbers.hs4
-rw-r--r--src/ForwardAD/DualNumbers/Types.hs2
2 files changed, 6 insertions, 0 deletions
diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs
index 2f94076..ebc70d7 100644
--- a/src/ForwardAD/DualNumbers.hs
+++ b/src/ForwardAD/DualNumbers.hs
@@ -143,6 +143,10 @@ dfwdDN = \case
ENothing _ t -> ENothing ext (dn t)
EJust _ e -> EJust ext (dfwdDN e)
EMaybe _ e a b -> EMaybe ext (dfwdDN e) (dfwdDN a) (dfwdDN b)
+ ELNil _ t1 t2 -> ELNil ext (dn t1) (dn t2)
+ ELInl _ t e -> ELInl ext (dn t) (dfwdDN e)
+ ELInr _ t e -> ELInr ext (dn t) (dfwdDN e)
+ ELCase _ e a b c -> ELCase ext (dfwdDN e) (dfwdDN a) (dfwdDN b) (dfwdDN c)
EConstArr _ n t x -> scalTyCase t
(emap (EPair ext (EVar ext (STScal t) IZ) (EConst ext t 0.0))
(EConstArr ext n t x))
diff --git a/src/ForwardAD/DualNumbers/Types.hs b/src/ForwardAD/DualNumbers/Types.hs
index fba92d0..3c76cbe 100644
--- a/src/ForwardAD/DualNumbers/Types.hs
+++ b/src/ForwardAD/DualNumbers/Types.hs
@@ -15,6 +15,7 @@ type family DN t where
DN (TMaybe t) = TMaybe (DN t)
DN (TArr n t) = TArr n (DN t)
DN (TScal t) = DNS t
+ DN (TLEither a b) = TLEither (DN a) (DN b)
type family DNS t where
DNS TF32 = TPair (TScal TF32) (TScal TF32)
@@ -40,6 +41,7 @@ dn (STScal t) = case t of
STI64 -> STScal STI64
STBool -> STScal STBool
dn STAccum{} = error "Accum in source program"
+dn (STLEither a b) = STLEither (dn a) (dn b)
dne :: SList STy env -> SList STy (DNE env)
dne SNil = SNil