aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD/Drev/Accum.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-11-27 21:30:17 +0100
committerTom Smeding <tom@tomsmeding.com>2025-11-27 21:30:17 +0100
commit20f7d7be13cd7869b338f98d1ab3fd33e8bbfb3e (patch)
treea21c90034a02cdeb7240563dbbab355e49622d0a /src/CHAD/Drev/Accum.hs
parentae634c056b500a568b2d89b7f8e225404a2c0c62 (diff)
WIP user-specified custom typesuser-types
The big roadblock encountered is that accumulation wants addition of monoids to be elementwise float addition; this fundamentally clashes with the concept of a user type with a custom zero and plus.
Diffstat (limited to 'src/CHAD/Drev/Accum.hs')
-rw-r--r--src/CHAD/Drev/Accum.hs2
1 files changed, 2 insertions, 0 deletions
diff --git a/src/CHAD/Drev/Accum.hs b/src/CHAD/Drev/Accum.hs
index 6f25f11..43305e6 100644
--- a/src/CHAD/Drev/Accum.hs
+++ b/src/CHAD/Drev/Accum.hs
@@ -21,6 +21,7 @@ d2zeroInfo STMaybe{} _ = ENil ext
d2zeroInfo (STArr _ t) e = emap (d2zeroInfo t (EVar ext (d1 t) IZ)) e
d2zeroInfo (STScal t) _ | Refl <- lemZeroInfoScal t = ENil ext
d2zeroInfo STAccum{} _ = error "accumulators not allowed in source program"
+d2zeroInfo (STUser t) e = euserD2ZeroInfo t (EUnUser ext e)
d2deepZeroInfo :: STy t -> Ex env (D1 t) -> Ex env (DeepZeroInfo (D2 t))
d2deepZeroInfo STNil _ = ENil ext
@@ -43,6 +44,7 @@ d2deepZeroInfo (STMaybe a) e =
d2deepZeroInfo (STArr _ t) e = emap (d2deepZeroInfo t (EVar ext (d1 t) IZ)) e
d2deepZeroInfo (STScal t) _ | Refl <- lemDeepZeroInfoScal t = ENil ext
d2deepZeroInfo STAccum{} _ = error "accumulators not allowed in source program"
+d2deepZeroInfo (STUser t) e = euserD2DeepZeroInfo t (EUnUser ext e)
-- The weakening is necessary because we need to initialise the created
-- accumulators with zeros. Those zeros are deep and need full primals. This