aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD/Drev.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2026-01-26 23:37:55 +0100
committerTom Smeding <tom@tomsmeding.com>2026-02-12 20:44:06 +0100
commitc2831ef0f8be71f2a72ee4eee446e2ac473fb638 (patch)
treefd5d8b38aa3920174689f14f9817f8fdaf9e3cb6 /src/CHAD/Drev.hs
parentf5b1b405fa4ba63bdffe0f2998f655f0b06534bf (diff)
Multihot cotangents WIP (doesn't work)
The idea is sound but for a smaller source language. Notes also in Obsidian, but the theory so far is that dropping support for nested arrays makes this possible, although making the result type-safe (i.e. not have partial functions in a bunch of places) would require making the lack of nested array support explicit in the embedded type system, i.e. have Accelerate-like stratification. The point is that multihots can be added heterogeneously using plusSparseS but not homogeneously with EPlus or plusSparse, because the indices might differ between the summands. Thus as long as we never need to homogeneously sum multihot cotangents, we're golden. Now the crucial observation is that we only need plus to be homogeneous on array elements. So if array elements cannot themselves be arrays, i.e. we drop support for nested arrays, no homogeneous plus of multihot array cotangents is needed, and we can have static multihots.
Diffstat (limited to 'src/CHAD/Drev.hs')
-rw-r--r--src/CHAD/Drev.hs3
1 files changed, 1 insertions, 2 deletions
diff --git a/src/CHAD/Drev.hs b/src/CHAD/Drev.hs
index 9f2921c..33dece7 100644
--- a/src/CHAD/Drev.hs
+++ b/src/CHAD/Drev.hs
@@ -1009,11 +1009,10 @@ drev des accumMap sd = \case
let smallE = unsafeWeakenWithSubenv usedSub e in
subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed ->
case drev usedDes (VarMap.subMap subAccumUsed accumMap) sd smallE of { Ret e0 subtape _ sub e2 ->
- let subMergeUsed' = subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E subMergeUsed) in
Ret (collectBindings (desD1E des) subD1eUsed)
(subenvAll (desD1E usedDes))
(weakenExpr (wSinks (desD1E usedDes)) $ drevPrimal des e)
- (subenvCompose subMergeUsed' sub)
+ (subenvCompose (subenvD2E subMergeUsed) sub)
(letBinds (fst (weakenBindingsE (WSink .> wRaiseAbove (desD1E usedDes) (d2ace (select SAccum des))) e0)) $
weakenExpr
(autoWeak (#d (auto1 @sd)