diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/CHAD.hs | 21 | ||||
| -rw-r--r-- | src/Simplify.hs | 10 | 
2 files changed, 29 insertions, 2 deletions
| diff --git a/src/CHAD.hs b/src/CHAD.hs index 8db0410..6ab4cfb 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -32,8 +32,9 @@ module CHAD (  ) where  import Data.Functor.Const +import Data.Some  import Data.Type.Bool (If) -import Data.Type.Equality (type (==)) +import Data.Type.Equality (type (==), testEquality)  import GHC.Stack (HasCallStack)  import Analysis.Identity (ValId(..), validSplitEither) @@ -1091,6 +1092,24 @@ drevScoped des accumMap argty argsto argids expr = case argsto of            SENo sub' -> RetScoped e0 subtape e1 sub' (EPair ext e2 (EZero ext argty))    SAccum +    | Just (VIArr i _) <- argids +    , Just (Some (VarMap.TypedIdx foundTy idx)) <- VarMap.lookup i accumMap +    , Just Refl <- testEquality foundTy (STAccum argty) +    , Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) (VarMap.sink1 accumMap) expr -> +        RetScoped e0 subtape e1 sub $ +          let wtapebinds = wSinks (subList (bindingsBinds e0) subtape) in +          ELet ext (EVar ext (STAccum argty) (WSink .> wtapebinds @> idx)) $ +            weakenExpr (autoWeak (#d (auto1 @(D2 t)) +                                    &. #body (subList (bindingsBinds e0) subtape) +                                    &. #ac (auto1 @(TAccum a)) +                                    &. #tl (d2ace (select SAccum des))) +                                   (#d :++: #body :++: #ac :++: #tl) +                                   (#ac :++: #d :++: #body :++: #tl)) +                       -- Our contribution to the binding's cotangent _here_ is +                       -- zero, because we're contributing to an earlier binding +                       -- of the same value instead. +                       (EPair ext e2 (EZero ext argty)) +      | let accumMap' = case argids of                          Just (VIArr i _) -> VarMap.insert i (STAccum argty) IZ (VarMap.sink1 accumMap)                          _ -> VarMap.sink1 accumMap diff --git a/src/Simplify.hs b/src/Simplify.hs index f5b7d15..cb835aa 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -100,12 +100,20 @@ simplify' = \case    EMaybe _ e1 _ ENothing{} -> acted $ simplify' e1    EMaybe _ _ e1 (EJust _ e2) -> acted $ simplify' $ ELet ext e2 e1 -  -- let floating to facilitate beta reduction +  -- let floating    EFst _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (EFst ext body))    ESnd _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (ESnd ext body))    ECase _ (ELet _ rhs body) e1 e2 -> acted $ simplify' (ELet ext rhs (ECase ext body (weakenExpr (WCopy WSink) e1) (weakenExpr (WCopy WSink) e2)))    EIdx0 _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (EIdx0 ext body))    EIdx1 _ (ELet _ rhs body) e -> acted $ simplify' (ELet ext rhs (EIdx1 ext body (weakenExpr WSink e))) +  EAccum _ t p e1 (ELet _ rhs body) acc -> +    acted $ simplify' $ +      ELet ext rhs $ +        EAccum ext t p (weakenExpr WSink e1) body (weakenExpr WSink acc) + +  -- let () = e in ()  ~>  e +  ELet _ e1 (ENil _) | STNil <- typeOf e1 -> +    acted $ simplify' e1    -- projection down-commuting    EFst _ (ECase _ e1 e2 e3) -> | 
