summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/CHAD.hs21
-rw-r--r--src/Simplify.hs10
2 files changed, 29 insertions, 2 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 8db0410..6ab4cfb 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -32,8 +32,9 @@ module CHAD (
) where
import Data.Functor.Const
+import Data.Some
import Data.Type.Bool (If)
-import Data.Type.Equality (type (==))
+import Data.Type.Equality (type (==), testEquality)
import GHC.Stack (HasCallStack)
import Analysis.Identity (ValId(..), validSplitEither)
@@ -1091,6 +1092,24 @@ drevScoped des accumMap argty argsto argids expr = case argsto of
SENo sub' -> RetScoped e0 subtape e1 sub' (EPair ext e2 (EZero ext argty))
SAccum
+ | Just (VIArr i _) <- argids
+ , Just (Some (VarMap.TypedIdx foundTy idx)) <- VarMap.lookup i accumMap
+ , Just Refl <- testEquality foundTy (STAccum argty)
+ , Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) (VarMap.sink1 accumMap) expr ->
+ RetScoped e0 subtape e1 sub $
+ let wtapebinds = wSinks (subList (bindingsBinds e0) subtape) in
+ ELet ext (EVar ext (STAccum argty) (WSink .> wtapebinds @> idx)) $
+ weakenExpr (autoWeak (#d (auto1 @(D2 t))
+ &. #body (subList (bindingsBinds e0) subtape)
+ &. #ac (auto1 @(TAccum a))
+ &. #tl (d2ace (select SAccum des)))
+ (#d :++: #body :++: #ac :++: #tl)
+ (#ac :++: #d :++: #body :++: #tl))
+ -- Our contribution to the binding's cotangent _here_ is
+ -- zero, because we're contributing to an earlier binding
+ -- of the same value instead.
+ (EPair ext e2 (EZero ext argty))
+
| let accumMap' = case argids of
Just (VIArr i _) -> VarMap.insert i (STAccum argty) IZ (VarMap.sink1 accumMap)
_ -> VarMap.sink1 accumMap
diff --git a/src/Simplify.hs b/src/Simplify.hs
index f5b7d15..cb835aa 100644
--- a/src/Simplify.hs
+++ b/src/Simplify.hs
@@ -100,12 +100,20 @@ simplify' = \case
EMaybe _ e1 _ ENothing{} -> acted $ simplify' e1
EMaybe _ _ e1 (EJust _ e2) -> acted $ simplify' $ ELet ext e2 e1
- -- let floating to facilitate beta reduction
+ -- let floating
EFst _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (EFst ext body))
ESnd _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (ESnd ext body))
ECase _ (ELet _ rhs body) e1 e2 -> acted $ simplify' (ELet ext rhs (ECase ext body (weakenExpr (WCopy WSink) e1) (weakenExpr (WCopy WSink) e2)))
EIdx0 _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (EIdx0 ext body))
EIdx1 _ (ELet _ rhs body) e -> acted $ simplify' (ELet ext rhs (EIdx1 ext body (weakenExpr WSink e)))
+ EAccum _ t p e1 (ELet _ rhs body) acc ->
+ acted $ simplify' $
+ ELet ext rhs $
+ EAccum ext t p (weakenExpr WSink e1) body (weakenExpr WSink acc)
+
+ -- let () = e in () ~> e
+ ELet _ e1 (ENil _) | STNil <- typeOf e1 ->
+ acted $ simplify' e1
-- projection down-commuting
EFst _ (ECase _ e1 e2 e3) ->