diff options
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r-- | src/CHAD.hs | 21 |
1 files changed, 20 insertions, 1 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs index 8db0410..6ab4cfb 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -32,8 +32,9 @@ module CHAD ( ) where import Data.Functor.Const +import Data.Some import Data.Type.Bool (If) -import Data.Type.Equality (type (==)) +import Data.Type.Equality (type (==), testEquality) import GHC.Stack (HasCallStack) import Analysis.Identity (ValId(..), validSplitEither) @@ -1091,6 +1092,24 @@ drevScoped des accumMap argty argsto argids expr = case argsto of SENo sub' -> RetScoped e0 subtape e1 sub' (EPair ext e2 (EZero ext argty)) SAccum + | Just (VIArr i _) <- argids + , Just (Some (VarMap.TypedIdx foundTy idx)) <- VarMap.lookup i accumMap + , Just Refl <- testEquality foundTy (STAccum argty) + , Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) (VarMap.sink1 accumMap) expr -> + RetScoped e0 subtape e1 sub $ + let wtapebinds = wSinks (subList (bindingsBinds e0) subtape) in + ELet ext (EVar ext (STAccum argty) (WSink .> wtapebinds @> idx)) $ + weakenExpr (autoWeak (#d (auto1 @(D2 t)) + &. #body (subList (bindingsBinds e0) subtape) + &. #ac (auto1 @(TAccum a)) + &. #tl (d2ace (select SAccum des))) + (#d :++: #body :++: #ac :++: #tl) + (#ac :++: #d :++: #body :++: #tl)) + -- Our contribution to the binding's cotangent _here_ is + -- zero, because we're contributing to an earlier binding + -- of the same value instead. + (EPair ext e2 (EZero ext argty)) + | let accumMap' = case argids of Just (VIArr i _) -> VarMap.insert i (STAccum argty) IZ (VarMap.sink1 accumMap) _ -> VarMap.sink1 accumMap |