From a624136738fb1ad3bf801723b9afbf1132fad7f0 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Fri, 18 Apr 2025 13:32:52 +0200 Subject: Some progress with accumMap --- src/CHAD.hs | 21 ++++++++++++++++++++- src/Simplify.hs | 10 +++++++++- 2 files changed, 29 insertions(+), 2 deletions(-) (limited to 'src') diff --git a/src/CHAD.hs b/src/CHAD.hs index 8db0410..6ab4cfb 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -32,8 +32,9 @@ module CHAD ( ) where import Data.Functor.Const +import Data.Some import Data.Type.Bool (If) -import Data.Type.Equality (type (==)) +import Data.Type.Equality (type (==), testEquality) import GHC.Stack (HasCallStack) import Analysis.Identity (ValId(..), validSplitEither) @@ -1091,6 +1092,24 @@ drevScoped des accumMap argty argsto argids expr = case argsto of SENo sub' -> RetScoped e0 subtape e1 sub' (EPair ext e2 (EZero ext argty)) SAccum + | Just (VIArr i _) <- argids + , Just (Some (VarMap.TypedIdx foundTy idx)) <- VarMap.lookup i accumMap + , Just Refl <- testEquality foundTy (STAccum argty) + , Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) (VarMap.sink1 accumMap) expr -> + RetScoped e0 subtape e1 sub $ + let wtapebinds = wSinks (subList (bindingsBinds e0) subtape) in + ELet ext (EVar ext (STAccum argty) (WSink .> wtapebinds @> idx)) $ + weakenExpr (autoWeak (#d (auto1 @(D2 t)) + &. #body (subList (bindingsBinds e0) subtape) + &. #ac (auto1 @(TAccum a)) + &. #tl (d2ace (select SAccum des))) + (#d :++: #body :++: #ac :++: #tl) + (#ac :++: #d :++: #body :++: #tl)) + -- Our contribution to the binding's cotangent _here_ is + -- zero, because we're contributing to an earlier binding + -- of the same value instead. + (EPair ext e2 (EZero ext argty)) + | let accumMap' = case argids of Just (VIArr i _) -> VarMap.insert i (STAccum argty) IZ (VarMap.sink1 accumMap) _ -> VarMap.sink1 accumMap diff --git a/src/Simplify.hs b/src/Simplify.hs index f5b7d15..cb835aa 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -100,12 +100,20 @@ simplify' = \case EMaybe _ e1 _ ENothing{} -> acted $ simplify' e1 EMaybe _ _ e1 (EJust _ e2) -> acted $ simplify' $ ELet ext e2 e1 - -- let floating to facilitate beta reduction + -- let floating EFst _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (EFst ext body)) ESnd _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (ESnd ext body)) ECase _ (ELet _ rhs body) e1 e2 -> acted $ simplify' (ELet ext rhs (ECase ext body (weakenExpr (WCopy WSink) e1) (weakenExpr (WCopy WSink) e2))) EIdx0 _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (EIdx0 ext body)) EIdx1 _ (ELet _ rhs body) e -> acted $ simplify' (ELet ext rhs (EIdx1 ext body (weakenExpr WSink e))) + EAccum _ t p e1 (ELet _ rhs body) acc -> + acted $ simplify' $ + ELet ext rhs $ + EAccum ext t p (weakenExpr WSink e1) body (weakenExpr WSink acc) + + -- let () = e in () ~> e + ELet _ e1 (ENil _) | STNil <- typeOf e1 -> + acted $ simplify' e1 -- projection down-commuting EFst _ (ECase _ e1 e2 e3) -> -- cgit v1.2.3-70-g09d2