diff options
| author | Tom Smeding <t.j.smeding@uu.nl> | 2025-04-18 13:32:52 +0200 | 
|---|---|---|
| committer | Tom Smeding <t.j.smeding@uu.nl> | 2025-04-18 13:32:52 +0200 | 
| commit | a624136738fb1ad3bf801723b9afbf1132fad7f0 (patch) | |
| tree | 010969bfb2cf21ed9ba19b234f132f52e5275e3b /src/CHAD.hs | |
| parent | 55fca0c5c533625262c103be1b673011ec41f2d7 (diff) | |
Some progress with accumMap
Diffstat (limited to 'src/CHAD.hs')
| -rw-r--r-- | src/CHAD.hs | 21 | 
1 files changed, 20 insertions, 1 deletions
| diff --git a/src/CHAD.hs b/src/CHAD.hs index 8db0410..6ab4cfb 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -32,8 +32,9 @@ module CHAD (  ) where  import Data.Functor.Const +import Data.Some  import Data.Type.Bool (If) -import Data.Type.Equality (type (==)) +import Data.Type.Equality (type (==), testEquality)  import GHC.Stack (HasCallStack)  import Analysis.Identity (ValId(..), validSplitEither) @@ -1091,6 +1092,24 @@ drevScoped des accumMap argty argsto argids expr = case argsto of            SENo sub' -> RetScoped e0 subtape e1 sub' (EPair ext e2 (EZero ext argty))    SAccum +    | Just (VIArr i _) <- argids +    , Just (Some (VarMap.TypedIdx foundTy idx)) <- VarMap.lookup i accumMap +    , Just Refl <- testEquality foundTy (STAccum argty) +    , Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) (VarMap.sink1 accumMap) expr -> +        RetScoped e0 subtape e1 sub $ +          let wtapebinds = wSinks (subList (bindingsBinds e0) subtape) in +          ELet ext (EVar ext (STAccum argty) (WSink .> wtapebinds @> idx)) $ +            weakenExpr (autoWeak (#d (auto1 @(D2 t)) +                                    &. #body (subList (bindingsBinds e0) subtape) +                                    &. #ac (auto1 @(TAccum a)) +                                    &. #tl (d2ace (select SAccum des))) +                                   (#d :++: #body :++: #ac :++: #tl) +                                   (#ac :++: #d :++: #body :++: #tl)) +                       -- Our contribution to the binding's cotangent _here_ is +                       -- zero, because we're contributing to an earlier binding +                       -- of the same value instead. +                       (EPair ext e2 (EZero ext argty)) +      | let accumMap' = case argids of                          Just (VIArr i _) -> VarMap.insert i (STAccum argty) IZ (VarMap.sink1 accumMap)                          _ -> VarMap.sink1 accumMap | 
