summaryrefslogtreecommitdiff
path: root/src/CHAD.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r--src/CHAD.hs21
1 files changed, 20 insertions, 1 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 8db0410..6ab4cfb 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -32,8 +32,9 @@ module CHAD (
) where
import Data.Functor.Const
+import Data.Some
import Data.Type.Bool (If)
-import Data.Type.Equality (type (==))
+import Data.Type.Equality (type (==), testEquality)
import GHC.Stack (HasCallStack)
import Analysis.Identity (ValId(..), validSplitEither)
@@ -1091,6 +1092,24 @@ drevScoped des accumMap argty argsto argids expr = case argsto of
SENo sub' -> RetScoped e0 subtape e1 sub' (EPair ext e2 (EZero ext argty))
SAccum
+ | Just (VIArr i _) <- argids
+ , Just (Some (VarMap.TypedIdx foundTy idx)) <- VarMap.lookup i accumMap
+ , Just Refl <- testEquality foundTy (STAccum argty)
+ , Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) (VarMap.sink1 accumMap) expr ->
+ RetScoped e0 subtape e1 sub $
+ let wtapebinds = wSinks (subList (bindingsBinds e0) subtape) in
+ ELet ext (EVar ext (STAccum argty) (WSink .> wtapebinds @> idx)) $
+ weakenExpr (autoWeak (#d (auto1 @(D2 t))
+ &. #body (subList (bindingsBinds e0) subtape)
+ &. #ac (auto1 @(TAccum a))
+ &. #tl (d2ace (select SAccum des)))
+ (#d :++: #body :++: #ac :++: #tl)
+ (#ac :++: #d :++: #body :++: #tl))
+ -- Our contribution to the binding's cotangent _here_ is
+ -- zero, because we're contributing to an earlier binding
+ -- of the same value instead.
+ (EPair ext e2 (EZero ext argty))
+
| let accumMap' = case argids of
Just (VIArr i _) -> VarMap.insert i (STAccum argty) IZ (VarMap.sink1 accumMap)
_ -> VarMap.sink1 accumMap