diff options
Diffstat (limited to 'src/CHAD.hs')
| -rw-r--r-- | src/CHAD.hs | 1124 | 
1 files changed, 677 insertions, 447 deletions
| diff --git a/src/CHAD.hs b/src/CHAD.hs index ac308ac..cfae98d 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -3,6 +3,7 @@  {-# LANGUAGE GADTs #-}  {-# LANGUAGE LambdaCase #-}  {-# LANGUAGE ImplicitParams #-} +{-# LANGUAGE ImpredicativeTypes #-}  {-# LANGUAGE OverloadedLabels #-}  {-# LANGUAGE PolyKinds #-}  {-# LANGUAGE QuantifiedConstraints #-} @@ -11,6 +12,7 @@  {-# LANGUAGE StandaloneDeriving #-}  {-# LANGUAGE StandaloneKindSignatures #-}  {-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeData #-}  {-# LANGUAGE TypeFamilies #-}  {-# LANGUAGE TypeOperators #-}  {-# LANGUAGE UndecidableInstances #-} @@ -33,15 +35,14 @@ module CHAD (  import Data.Functor.Const  import Data.Some -import Data.Type.Bool (If)  import Data.Type.Equality (type (==), testEquality) -import GHC.Stack (HasCallStack)  import Analysis.Identity (ValId(..), validSplitEither)  import AST  import AST.Bindings  import AST.Count  import AST.Env +import AST.Sparse  import AST.Weaken.Auto  import CHAD.Accum  import CHAD.EnvDescr @@ -62,14 +63,20 @@ tapeTy :: SList STy binds -> STy (Tape binds)  tapeTy SNil = STNil  tapeTy (SCons t ts) = STPair t (tapeTy ts) -bindingsCollect :: Bindings f env binds -> Subenv binds tapebinds -                -> Append binds env :> env2 -> Ex env2 (Tape tapebinds) -bindingsCollect BTop SETop _ = ENil ext -bindingsCollect (BPush binds (t, _)) (SEYes sub) w = +bindingsCollectTape :: SList STy binds -> Subenv binds tapebinds +                    -> binds :> env2 -> Ex env2 (Tape tapebinds) +bindingsCollectTape SNil SETop _ = ENil ext +bindingsCollectTape (t `SCons` binds) (SEYesR sub) w =    EPair ext (EVar ext t (w @> IZ)) -            (bindingsCollect binds sub (w .> WSink)) -bindingsCollect (BPush binds _) (SENo sub) w = -  bindingsCollect binds sub (w .> WSink) +            (bindingsCollectTape binds sub (w .> WSink)) +bindingsCollectTape (_ `SCons` binds) (SENo sub) w = +  bindingsCollectTape binds sub (w .> WSink) + +-- bindingsCollectTape' :: forall f env binds tapebinds env2. Bindings f env binds -> Subenv binds tapebinds +--                      -> Append binds env :> env2 -> Ex env2 (Tape tapebinds) +-- bindingsCollectTape' binds sub w +--   | Refl <- lemAppendNil @binds +--   = bindingsCollectTape (bindingsBinds binds) sub (w .> wCopies @_ @_ @'[] (bindingsBinds binds) (WClosed @env))  -- In order from large to small: i.e. in reverse order from what we want,  -- because in a Bindings, the head of the list is the bottom-most entry. @@ -140,7 +147,7 @@ growRecon t ts (Reconstructor unfbs bs)        -- Add a 'fst' at the bottom of the builder stack.        -- First we have to weaken most of 'bs' to skip one more binding in the        -- unfolder stack above it. -      (BPush (fst (weakenBindings weakenExpr +      (BPush (fst (weakenBindingsE                        (wCopies (sappend (sreverse (stapeUnfoldings ts)) (SCons (tapeTy ts) SNil))                                 (WSink :: env :> (Tape (t : ts) : env))) bs))               (t @@ -190,14 +197,14 @@ buildReconstructor (SCons t ts) = growRecon t ts (buildReconstructor ts)  -- incidentally also add a bunch of additional bindings, namely 'Reverse  -- (TapeUnfoldings binds)', so the calling code just has to skip those in  -- whatever it wants to do. -reconstructBindings :: SList STy binds -> Idx env (Tape binds) -                    -> (Bindings Ex env (Append binds (Reverse (TapeUnfoldings binds))) +reconstructBindings :: SList STy binds +                    -> (forall env. Idx env (Tape binds) -> Bindings Ex env (Append binds (Reverse (TapeUnfoldings binds)))                         ,SList STy (Reverse (TapeUnfoldings binds))) -reconstructBindings binds tape = -  let Reconstructor unf build = buildReconstructor binds -  in (fst $ weakenBindings weakenExpr (WIdx tape) -             (bconcat (mapBindings fromUnfExpr unf) build) -     ,sreverse (stapeUnfoldings binds)) +reconstructBindings binds = +  (\tape -> let Reconstructor unf build = buildReconstructor binds +            in fst $ weakenBindingsE (WIdx tape) +                      (bconcat (mapBindings fromUnfExpr unf) build) +  ,sreverse (stapeUnfoldings binds))  ---------------------------------- DERIVATIVES --------------------------------- @@ -227,26 +234,37 @@ data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a))  d2op :: SOp a t -> D2Op a t  d2op op = case op of -  OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EJust ext (EPair ext d d) +  OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EPair ext d d    OMul t -> d2opBinArrangeInt t $ Nonlinear $ \e d -> -    EJust ext (EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d)) -                         (EOp ext (OMul t) (EPair ext (EFst ext e) d))) +    EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d)) +              (EOp ext (OMul t) (EPair ext (EFst ext e) d))    ONeg t -> d2opUnArrangeInt t $ Linear $ \d -> EOp ext (ONeg t) d -  OLt t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t))) -  OLe t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t))) -  OEq t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t))) +  OLt t -> Linear $ \_ -> pairZero t +  OLe t -> Linear $ \_ -> pairZero t +  OEq t -> Linear $ \_ -> pairZero t    ONot -> Linear $ \_ -> ENil ext -  OAnd -> Linear $ \_ -> ENothing ext (STPair STNil STNil) -  OOr -> Linear $ \_ -> ENothing ext (STPair STNil STNil) +  OAnd -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) +  OOr -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext)    OIf -> Linear $ \_ -> ENil ext -  ORound64 -> Linear $ \_ -> EConst ext STF64 0.0 +  ORound64 -> Linear $ \_ -> EZero ext (SMTScal STF64) (ENil ext)    OToFl64 -> Linear $ \_ -> ENil ext    ORecip t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ONeg t) (EOp ext (ORecip t) (EOp ext (OMul t) (EPair ext e e)))) d)    OExp t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (OExp t) e) d)    OLog t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ORecip t) e) d) -  OIDiv t -> integralD2 t $ Linear $ \_ -> ENothing ext (STPair STNil STNil) -  OMod t -> integralD2 t $ Linear $ \_ -> ENothing ext (STPair STNil STNil) +  OIDiv t -> integralD2 t $ Linear $ \_ -> EPair ext (ENil ext) (ENil ext) +  OMod t -> integralD2 t $ Linear $ \_ -> EPair ext (ENil ext) (ENil ext)    where +    pairZero :: SScalTy a -> Ex env (D2 (TPair (TScal a) (TScal a))) +    pairZero t = ziNil t $ EPair ext (EZero ext (d2M (STScal t)) (ENil ext)) +                                     (EZero ext (d2M (STScal t)) (ENil ext)) +      where +        ziNil :: SScalTy a -> (ZeroInfo (D2s a) ~ TNil => r) -> r +        ziNil STI32 k = k +        ziNil STI64 k = k +        ziNil STF32 k = k +        ziNil STF64 k = k +        ziNil STBool k = k +      d2opUnArrangeInt :: SScalTy a                       -> (D2s a ~ TScal a => D2Op (TScal a) t)                       -> D2Op (TScal a) t @@ -261,11 +279,11 @@ d2op op = case op of                        -> (D2s a ~ TScal a => D2Op (TPair (TScal a) (TScal a)) t)                        -> D2Op (TPair (TScal a) (TScal a)) t      d2opBinArrangeInt ty float = case ty of -      STI32 -> Linear $ \_ -> ENothing ext (STPair STNil STNil) -      STI64 -> Linear $ \_ -> ENothing ext (STPair STNil STNil) +      STI32 -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) +      STI64 -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext)        STF32 -> float        STF64 -> float -      STBool -> Linear $ \_ -> ENothing ext (STPair STNil STNil) +      STBool -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext)      floatingD2 :: ScalIsFloating a ~ True                 => SScalTy a -> ((D2s a ~ TScal a, ScalIsNumeric a ~ True) => r) -> r @@ -293,7 +311,7 @@ conv1Idx (IS i) = IS (conv1Idx i)  data Idx2 env sto t    = Idx2Ac (Idx (D2AcE (Select env sto "accum")) (TAccum (D2 t))) -  | Idx2Me (Idx (Select env sto "merge") t) +  | Idx2Me (Idx (D2E (Select env sto "merge")) (D2 t))    | Idx2Di (Idx (Select env sto "discr") t)  conv2Idx :: Descr env sto -> Idx env t -> Idx2 env sto t @@ -314,67 +332,158 @@ conv2Idx (DPush des (_, _, SDiscr)) (IS i) =                           Idx2Di j -> Idx2Di (IS j)  conv2Idx DTop i = case i of {} +opt2UnSparse :: SOp a b -> Sparse (D2 b) b' -> Ex env b' -> Ex env (D2 b) +opt2UnSparse = go . opt2 +  where +    go :: STy b -> Sparse (D2 b) b' -> Ex env b' -> Ex env (D2 b) +    go (STScal STI32) SpAbsent = \_ -> ENil ext +    go (STScal STI64) SpAbsent = \_ -> ENil ext +    go (STScal STF32) SpAbsent = \_ -> EZero ext (SMTScal STF32) (ENil ext) +    go (STScal STF64) SpAbsent = \_ -> EZero ext (SMTScal STF64) (ENil ext) +    go (STScal STBool) SpAbsent = \_ -> ENil ext +    go (STScal STF32) SpScal = id +    go (STScal STF64) SpScal = id +    go STNil _ = \_ -> ENil ext +    go (STPair t1 t2) (SpPair s1 s2) = \e -> eunPair e $ \_ e1 e2 -> EPair ext (go t1 s1 e1) (go t2 s2 e2) +    go t _ = error $ "Primitive operations that return " ++ show t ++ " are scary" ------------------------------------- MONOIDS ----------------------------------- - -zeroTup :: SList STy env0 -> Ex env (Tup (D2E env0)) -zeroTup SNil = ENil ext -zeroTup (t `SCons` env) = EPair ext (zeroTup env) (ezeroD2 t) +----------------------------------- SPARSITY ----------------------------------- ------------------------------------- SUBENVS ----------------------------------- +expandSparse :: STy a -> Sparse (D2 a) b -> Ex env (D1 a) -> Ex env b -> Ex env (D2 a) +expandSparse t sp _ e | Just Refl <- isDense (d2M t) sp = e +expandSparse t (SpSparse sp) epr e = +  EMaybe ext +    (EZero ext (d2M t) (d2zeroInfo t epr)) +    (expandSparse t sp (weakenExpr WSink epr) (EVar ext (applySparse sp (d2 t)) IZ)) +    e +expandSparse t SpAbsent epr _ = EZero ext (d2M t) (d2zeroInfo t epr) +expandSparse (STPair t1 t2) (SpPair s1 s2) epr e = +  eunPair epr $ \w1 epr1 epr2 -> +  eunPair (weakenExpr w1 e) $ \w2 e1 e2 -> +    EPair ext (expandSparse t1 s1 (weakenExpr w2 epr1) e1) +              (expandSparse t2 s2 (weakenExpr w2 epr2) e2) +expandSparse (STEither t1 t2) (SpLEither s1 s2) epr e = +  ELCase ext e +    (EZero ext (d2M (STEither t1 t2)) (ENil ext)) +    (ECase ext (weakenExpr WSink epr) +       (ELInl ext (d2 t2) (expandSparse t1 s1 (EVar ext (d1 t1) IZ) (EVar ext (applySparse s1 (d2 t1)) (IS IZ)))) +       (EError ext (d2 (STEither t1 t2)) "expspa r<-dl")) +    (ECase ext (weakenExpr WSink epr) +       (EError ext (d2 (STEither t1 t2)) "expspa l<-dr") +       (ELInr ext (d2 t1) (expandSparse t2 s2 (EVar ext (d1 t2) IZ) (EVar ext (applySparse s2 (d2 t2)) (IS IZ))))) +expandSparse (STLEither t1 t2) (SpLEither s1 s2) epr e = +  ELCase ext e +    (EZero ext (d2M (STEither t1 t2)) (ENil ext)) +    (ELCase ext (weakenExpr WSink epr) +       (EError ext (d2 (STEither t1 t2)) "expspa ln<-dl") +       (ELInl ext (d2 t2) (expandSparse t1 s1 (EVar ext (d1 t1) IZ) (EVar ext (applySparse s1 (d2 t1)) (IS IZ)))) +       (EError ext (d2 (STEither t1 t2)) "expspa lr<-dl")) +    (ELCase ext (weakenExpr WSink epr) +       (EError ext (d2 (STEither t1 t2)) "expspa ln<-dr") +       (EError ext (d2 (STEither t1 t2)) "expspa ll<-dr") +       (ELInr ext (d2 t1) (expandSparse t2 s2 (EVar ext (d1 t2) IZ) (EVar ext (applySparse s2 (d2 t2)) (IS IZ))))) +expandSparse (STMaybe t) (SpMaybe s) epr e = +  EMaybe ext +    (ENothing ext (d2 t)) +    (let epr' = EMaybe ext (EError ext (d1 t) "expspa n<-dj") (EVar ext (d1 t) IZ) epr +     in EJust ext (expandSparse t s (weakenExpr WSink epr') (EVar ext (applySparse s (d2 t)) IZ))) +    e +expandSparse (STArr _ t) (SpArr s) epr e = +  ezipWith (expandSparse t s (EVar ext (d1 t) (IS IZ)) (EVar ext (applySparse s (d2 t)) IZ)) epr e +expandSparse (STScal STF32) SpScal _ e = e +expandSparse (STScal STF64) SpScal _ e = e +expandSparse (STAccum{}) _ _ _ = error "accumulators not allowed in source program" -subenvPlus :: SList STy env -           -> Subenv env env1 -> Subenv env env2 -           -> (forall env3. Subenv env env3 -                         -> Subenv env3 env1 -                         -> Subenv env3 env2 -                         -> (Ex exenv (Tup (D2E env1)) -                             -> Ex exenv (Tup (D2E env2)) -                             -> Ex exenv (Tup (D2E env3))) +subenvPlus :: SBool req1 -> SBool req2 +           -> SList SMTy env +           -> SubenvS env env1 -> SubenvS env env2 +           -> (forall env3. SubenvS env env3 +                         -> Injection req1 (Tup env1) (Tup env3) +                         -> Injection req2 (Tup env2) (Tup env3) +                         -> (forall e. Ex e (Tup env1) -> Ex e (Tup env2) -> Ex e (Tup env3))                           -> r)             -> r -subenvPlus SNil SETop SETop k = k SETop SETop SETop (\_ _ -> ENil ext) -subenvPlus (SCons _ env) (SENo sub1) (SENo sub2) k = -  subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> +-- don't destroy effects! +subenvPlus _ _ SNil SETop SETop k = k SETop (Inj id) (Inj id) (\a b -> use a $ use b $ ENil ext) + +subenvPlus req1 req2 (SCons _ env) (SENo sub1) (SENo sub2) k = +  subenvPlus req1 req2 env sub1 sub2 $ \sub3 s31 s32 pl ->      k (SENo sub3) s31 s32 pl -subenvPlus (SCons _ env) (SEYes sub1) (SENo sub2) k = -  subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> -    k (SEYes sub3) (SEYes s31) (SENo s32) $ \e1 e2 -> -      ELet ext e1 $ -        EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ)) -                      (weakenExpr WSink e2)) -                  (ESnd ext (EVar ext (typeOf e1) IZ)) -subenvPlus (SCons _ env) (SENo sub1) (SEYes sub2) k = -  subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> -    k (SEYes sub3) (SENo s31) (SEYes s32) $ \e1 e2 -> -      ELet ext e2 $ -        EPair ext (pl (weakenExpr WSink e1) -                      (EFst ext (EVar ext (typeOf e2) IZ))) -                  (ESnd ext (EVar ext (typeOf e2) IZ)) -subenvPlus (SCons t env) (SEYes sub1) (SEYes sub2) k = -  subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> -    k (SEYes sub3) (SEYes s31) (SEYes s32) $ \e1 e2 -> -      ELet ext e1 $ -      ELet ext (weakenExpr WSink e2) $ -        EPair ext (pl (EFst ext (EVar ext (typeOf e1) (IS IZ))) -                      (EFst ext (EVar ext (typeOf e2) IZ))) -                  (EPlus ext (d2M t) -                    (ESnd ext (EVar ext (typeOf e1) (IS IZ))) -                    (ESnd ext (EVar ext (typeOf e2) IZ))) -expandSubenvZeros :: SList STy env0 -> Subenv env0 env0Merge -> Ex env (Tup (D2E env0Merge)) -> Ex env (Tup (D2E env0)) -expandSubenvZeros _ SETop _ = ENil ext -expandSubenvZeros (SCons t ts) (SEYes sub) e = -  ELet ext e $ -    let var = EVar ext (STPair (tTup (d2e (subList ts sub))) (d2 t)) IZ -    in EPair ext (expandSubenvZeros ts sub (EFst ext var)) (ESnd ext var) -expandSubenvZeros (SCons t ts) (SENo sub) e = EPair ext (expandSubenvZeros ts sub e) (ezeroD2 t) +subenvPlus req1 SF (SCons _ env) (SEYes sp1 sub1) (SENo sub2) k = +  subenvPlus req1 SF env sub1 sub2 $ \sub3 minj13 _ pl -> +    k (SEYes sp1 sub3) +      (withInj minj13 $ \inj13 -> +        \e1 -> eunPair e1 $ \_ e1a e1b -> +          EPair ext (inj13 e1a) e1b) +      Noinj +      (\e1 e2 -> +        ELet ext e1 $ +          EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ)) +                        (weakenExpr WSink e2)) +                    (ESnd ext (EVar ext (typeOf e1) IZ))) +subenvPlus req1 ST (SCons t env) (SEYes sp1 sub1) (SENo sub2) k +  | Just zero1 <- cheapZero (applySparse sp1 t) = +      subenvPlus req1 ST env sub1 sub2 $ \sub3 minj13 (Inj inj23) pl -> +        k (SEYes sp1 sub3) +          (withInj minj13 $ \inj13 -> +            \e1 -> eunPair e1 $ \_ e1a e1b -> +              EPair ext (inj13 e1a) e1b) +          (Inj $ \e2 -> EPair ext (inj23 e2) zero1) +          (\e1 e2 -> +            ELet ext e1 $ +              EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ)) +                            (weakenExpr WSink e2)) +                        (ESnd ext (EVar ext (typeOf e1) IZ))) +  | otherwise = +      subenvPlus req1 ST env sub1 sub2 $ \sub3 minj13 (Inj inj23) pl -> +        k (SEYes (SpSparse sp1) sub3) +          (withInj minj13 $ \inj13 -> +            \e1 -> eunPair e1 $ \_ e1a e1b -> +              EPair ext (inj13 e1a) (EJust ext e1b)) +          (Inj $ \e2 -> EPair ext (inj23 e2) (ENothing ext (applySparse sp1 (fromSMTy t)))) +          (\e1 e2 -> +            ELet ext e1 $ +              EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ)) +                            (weakenExpr WSink e2)) +                        (EJust ext (ESnd ext (EVar ext (typeOf e1) IZ)))) + +subenvPlus req1 req2 (SCons t env) sub1@SENo{} sub2@SEYes{} k = +  subenvPlus req2 req1 (SCons t env) sub2 sub1 $ \sub3 minj23 minj13 pl -> +    k sub3 minj13 minj23 (flip pl) + +subenvPlus req1 req2 (SCons t env) (SEYes sp1 sub1) (SEYes sp2 sub2) k = +  subenvPlus req1 req2 env sub1 sub2 $ \sub3 minj13 minj23 pl -> +  sparsePlusS req1 req2 t sp1 sp2 $ \sp3 mTinj13 mTinj23 plus -> +    k (SEYes sp3 sub3) +      (withInj2 minj13 mTinj13 $ \inj13 tinj13 -> +        \e1 -> eunPair e1 $ \_ e1a e1b -> +          EPair ext (inj13 e1a) (tinj13 e1b)) +      (withInj2 minj23 mTinj23 $ \inj23 tinj23 -> +        \e2 -> eunPair e2 $ \_ e2a e2b -> +          EPair ext (inj23 e2a) (tinj23 e2b)) +      (\e1 e2 -> +        ELet ext e1 $ +        ELet ext (weakenExpr WSink e2) $ +          EPair ext (pl (EFst ext (EVar ext (typeOf e1) (IS IZ))) +                        (EFst ext (EVar ext (typeOf e2) IZ))) +                    (plus +                      (ESnd ext (EVar ext (typeOf e1) (IS IZ))) +                      (ESnd ext (EVar ext (typeOf e2) IZ)))) -assertSubenvEmpty :: HasCallStack => Subenv env env' -> env' :~: '[] -assertSubenvEmpty (SENo sub) | Refl <- assertSubenvEmpty sub = Refl -assertSubenvEmpty SETop = Refl -assertSubenvEmpty SEYes{} = error "assertSubenvEmpty: not empty" +expandSubenvZeros :: D1E env0 :> env -> SList STy env0 -> SubenvS (D2E env0) contribs +                  -> Ex env (Tup contribs) -> Ex env (Tup (D2E env0)) +expandSubenvZeros _ SNil SETop _ = ENil ext +expandSubenvZeros w (SCons t ts) (SEYes sp sub) e = +  eunPair e $ \w1 e1 e2 -> +    EPair ext +      (expandSubenvZeros (w1 .> WPop w) ts sub e1) +      (expandSparse t sp (EVar ext (d1 t) (w1 .> w @> IZ)) e2) +expandSubenvZeros w (SCons t ts) (SENo sub) e = +  EPair ext +    (expandSubenvZeros (WPop w) ts sub e) +    (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ))))  --------------------------------- ACCUMULATORS --------------------------------- @@ -407,8 +516,8 @@ accumPromote :: forall dt env sto proxy r.                        -- accumulators.                   -> (forall shbinds.                              SList STy shbinds -                         -> (D2 dt : Append shbinds (D2AcE (Select env stoRepl "accum"))) -                            :> Append (D2AcE envPro) (D2 dt : Append shbinds (D2AcE (Select env sto "accum")))) +                         -> (dt : Append shbinds (D2AcE (Select env stoRepl "accum"))) +                            :> Append (D2AcE envPro) (dt : Append shbinds (D2AcE (Select env sto "accum"))))                        -- ^ A weakening that converts a computation in the                        -- revised environment to one in the original environment                        -- extended with some accumulators. @@ -422,14 +531,14 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of        k (storepl `DPush` (t, vid, SAccum))          envpro          prosub -        (SEYes accrevsub) +        (SEYesR accrevsub)          (VarMap.sink1 accumMap)          (\shbinds -> -          autoWeak (#pro (d2ace envpro) &. #d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum descr))) +          autoWeak (#pro (d2ace envpro) &. #d (auto1 @dt) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum descr)))                     (#acc :++: (#pro :++: #d :++: #shb :++: #tl))                     (#pro :++: #d :++: #shb :++: #acc :++: #tl)            .> WCopy (wf shbinds) -          .> autoWeak (#d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum storepl))) +          .> autoWeak (#d (auto1 @dt) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum storepl)))                        (#d :++: #shb :++: #acc :++: #tl)                        (#acc :++: (#d :++: #shb :++: #tl))) @@ -449,7 +558,7 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of        accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf ->          k (storepl `DPush` (t, vid, SAccum))            (t `SCons` envpro) -          (SEYes prosub) +          (SEYesR prosub)            (SENo accrevsub)            (let accumMap' = VarMap.sink1 accumMap             in case fromArrayValId vid of @@ -466,7 +575,7 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of              -- goal:                        |                                                                 ARE EQUAL  ||              --   D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum"))  :>  TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum")))              WCopy (wf shbinds) -            .> WPick @(TAccum (D2 t)) @(D2 dt : shbinds) (Const () `SCons` shbindsC) +            .> WPick @(TAccum (D2 t)) @(dt : shbinds) (Const () `SCons` shbindsC)                   (WId @(D2AcE (Select env1 stoRepl "accum"))))    -- Discrete values are left as-is, nothing to do @@ -484,6 +593,7 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of        STNil -> True        STPair a b -> isDiscrete a && isDiscrete b        STEither a b -> isDiscrete a && isDiscrete b +      STLEither a b -> isDiscrete a && isDiscrete b        STMaybe a -> isDiscrete a        STArr _ a -> isDiscrete a        STScal st -> case st of @@ -493,26 +603,45 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of          STF64 -> False          STBool -> True        STAccum{} -> False -      STLEither a b -> isDiscrete a && isDiscrete b  ---------------------------- RETURN TRIPLE FROM CHAD --------------------------- -data Ret env0 sto t = -  forall shbinds tapebinds env0Merge. +data Ret env0 sto sd t = +  forall shbinds tapebinds contribs.      Ret (Bindings Ex (D1E env0) shbinds)  -- shared binds          (Subenv shbinds tapebinds)          (Ex (Append shbinds (D1E env0)) (D1 t)) -        (Subenv (Select env0 sto "merge") env0Merge) -        (Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup (D2E env0Merge))) -deriving instance Show (Ret env0 sto t) +        (SubenvS (D2E (Select env0 sto "merge")) contribs) +        (Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs)) +deriving instance Show (Ret env0 sto sd t) -data RetPair env0 sto env shbinds tapebinds t = -  forall env0Merge. -    RetPair (Ex (Append shbinds env) (D1 t)) -            (Subenv (Select env0 sto "merge") env0Merge) -            (Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup (D2E env0Merge))) -deriving instance Show (RetPair env0 sto env shbinds tapebinds t) +type data TyTyPair = MkTyTyPair Ty Ty + +data SingleRet env0 sto (pair :: TyTyPair) = +  forall shbinds tapebinds. +    SingleRet +      (Bindings Ex (D1E env0) shbinds)  -- shared binds +      (Subenv shbinds tapebinds) +      (RetPair env0 sto (D1E env0) shbinds tapebinds pair) + +-- pattern Ret1 :: forall env0 sto Bindings Ex (D1E env0) shbinds +--              -> Subenv shbinds tapebinds +--              -> Ex (Append shbinds (D1E env0)) (D1 t) +--              -> SubenvS (D2E (Select env0 sto "merge")) contribs +--              -> Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs) +--              -> SingleRet env0 sto (MkTyTyPair sd t) +-- pattern Ret1 e0 subtape e1 sub e2 = SingleRet e0 subtape (RetPair e1 sub e2) +-- {-# COMPLETE Ret1 #-} + +data RetPair env0 sto env shbinds tapebinds (pair :: TyTyPair) where +  RetPair :: forall sd t contribs  -- existentials +                    env0 sto env shbinds tapebinds.  -- universals +             Ex (Append shbinds env) (D1 t) +          -> SubenvS (D2E (Select env0 sto "merge")) contribs +          -> Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs) +          -> RetPair env0 sto env shbinds tapebinds (MkTyTyPair sd t) +deriving instance Show (RetPair env0 sto env shbinds tapebinds pair)  data Rets env0 sto env list =    forall shbinds tapebinds. @@ -521,113 +650,149 @@ data Rets env0 sto env list =           (SList (RetPair env0 sto env shbinds tapebinds) list)  deriving instance Show (Rets env0 sto env list) +toSingleRet :: Ret env0 sto sd t -> SingleRet env0 sto (MkTyTyPair sd t) +toSingleRet (Ret e0 subtape e1 sub e2) = SingleRet e0 subtape (RetPair e1 sub e2) +  weakenRetPair :: SList STy shbinds -> env :> env' -              -> RetPair env0 sto env shbinds tapebinds t -> RetPair env0 sto env' shbinds tapebinds t +              -> RetPair env0 sto env shbinds tapebinds pair -> RetPair env0 sto env' shbinds tapebinds pair  weakenRetPair bindslist w (RetPair e1 sub e2) = RetPair (weakenExpr (weakenOver bindslist w) e1) sub e2  weakenRets :: env :> env' -> Rets env0 sto env list -> Rets env0 sto env' list  weakenRets w (Rets binds tapesub list) = -  let (binds', _) = weakenBindings weakenExpr w binds +  let (binds', _) = weakenBindingsE w binds    in Rets binds' tapesub (slistMap (weakenRetPair (bindingsBinds binds) w) list) -rebaseRetPair :: forall env b1 b2 tapebinds1 tapebinds2 env0 sto t f. +rebaseRetPair :: forall env b1 b2 tapebinds1 tapebinds2 env0 sto pair f.                   Descr env0 sto                -> SList f b1 -> SList f b2                -> Subenv b1 tapebinds1 -> Subenv b2 tapebinds2 -              -> RetPair env0 sto (Append b1 env) b2 tapebinds2 t -              -> RetPair env0 sto env (Append b2 b1) (Append tapebinds2 tapebinds1) t -rebaseRetPair descr b1 b2 subtape1 subtape2 (RetPair p sub d) +              -> RetPair env0 sto (Append b1 env) b2 tapebinds2 pair +              -> RetPair env0 sto env (Append b2 b1) (Append tapebinds2 tapebinds1) pair +rebaseRetPair descr b1 b2 subtape1 subtape2 (RetPair @sd e1 sub e2)    | Refl <- lemAppendAssoc @b2 @b1 @env = -      RetPair p sub (weakenExpr (autoWeak -                                  (#d (auto1 @(D2 t)) -                                   &. #t2 (subList b2 subtape2) -                                   &. #t1 (subList b1 subtape1) -                                   &. #tl (d2ace (select SAccum descr))) -                                  (#d :++: (#t2 :++: #tl)) -                                  (#d :++: ((#t2 :++: #t1) :++: #tl))) -                                d) +      RetPair e1 sub +              (weakenExpr (autoWeak +                            (#d (auto1 @sd) +                             &. #t2 (subList b2 subtape2) +                             &. #t1 (subList b1 subtape1) +                             &. #tl (d2ace (select SAccum descr))) +                            (#d :++: (#t2 :++: #tl)) +                            (#d :++: ((#t2 :++: #t1) :++: #tl))) +                e2) -retConcat :: forall env0 sto list. Descr env0 sto -> SList (Ret env0 sto) list -> Rets env0 sto (D1E env0) list +retConcat :: forall env0 sto list. Descr env0 sto -> SList (SingleRet env0 sto) list -> Rets env0 sto (D1E env0) list  retConcat _ SNil = Rets BTop SETop SNil -retConcat descr (SCons (Ret (b :: Bindings _ _ shbinds1) (subtape :: Subenv _ tapebinds1) p sub d) list) +retConcat descr (SCons (SingleRet (e0 :: Bindings _ _ shbinds1) (subtape :: Subenv _ tapebinds1) (RetPair e1 sub e2)) list)    | Rets (binds :: Bindings _ _ shbinds2) (subtape2 :: Subenv _ tapebinds2) pairs -      <- weakenRets (sinkWithBindings b) (retConcat descr list) +      <- weakenRets (sinkWithBindings e0) (retConcat descr list)    , Refl <- lemAppendAssoc @shbinds2 @shbinds1 @(D1E env0)    , Refl <- lemAppendAssoc @tapebinds2 @tapebinds1 @(D2AcE (Select env0 sto "accum")) -  = Rets (bconcat b binds) +  = Rets (bconcat e0 binds)           (subenvConcat subtape subtape2) -         (SCons (RetPair (weakenExpr (sinkWithBindings binds) p) +         (SCons (RetPair (weakenExpr (sinkWithBindings binds) e1)                           sub -                         (weakenExpr (WCopy (sinkWithSubenv subtape2)) d)) -                (slistMap (rebaseRetPair descr (bindingsBinds b) (bindingsBinds binds) +                         (weakenExpr (WCopy (sinkWithSubenv subtape2)) e2)) +                (slistMap (rebaseRetPair descr (bindingsBinds e0) (bindingsBinds binds)                                                 subtape subtape2)                            pairs))  freezeRet :: Descr env sto -          -> Ret env sto t +          -> Ret env sto (D2 t) t            -> Ex (D2 t : Append (D2AcE (Select env sto "accum")) (D1E env)) (TPair (D1 t) (Tup (D2E (Select env sto "merge")))) -freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ t) = -  let (e0', wInsertD2Ac) = weakenBindings weakenExpr (WSink .> wSinks (d2ace (select SAccum descr))) e0 +freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ _ t) = +  let (e0', wInsertD2Ac) = weakenBindingsE (WSink .> wSinks (d2ace (select SAccum descr))) e0        e2' = weakenExpr (WCopy (wCopies (subList (bindingsBinds e0) subtape) (wRaiseAbove (d2ace (select SAccum descr)) (desD1E descr)))) e2 +      tContribs = tTup (slistMap fromSMTy (subList (d2eM (select SMerge descr)) sub)) +      library = #d (auto1 @(D2 t)) +                &. #tape (subList (bindingsBinds e0) subtape) +                &. #shbinds (bindingsBinds e0) +                &. #d2ace (d2ace (select SAccum descr)) +                &. #tl (desD1E descr) +                &. #contribs (SCons tContribs SNil)    in letBinds e0' $         EPair ext           (weakenExpr wInsertD2Ac e1) -         (ELet ext (weakenExpr (autoWeak (#d (auto1 @(D2 t)) -                                          &. #tape (subList (bindingsBinds e0) subtape) -                                          &. #shbinds (bindingsBinds e0) -                                          &. #d2ace (d2ace (select SAccum descr)) -                                          &. #tl (desD1E descr)) +         (ELet ext (weakenExpr (autoWeak library                                           (#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: #d2ace :++: #tl)                                           (#shbinds :++: #d :++: #d2ace :++: #tl))                        e2') $ -          expandSubenvZeros (select SMerge descr) sub (EVar ext (tTup (d2e (subList (select SMerge descr) sub))) IZ)) +          expandSubenvZeros +            (autoWeak library #tl (#contribs :++: #shbinds :++: #d :++: #d2ace :++: #tl) +             .> wUndoSubenv (subenvD1E (selectSub SMerge descr))) +            (select SMerge descr) sub (EVar ext tContribs IZ))  ---------------------------- THE CHAD TRANSFORMATION --------------------------- -drev :: forall env sto t. +drev :: forall env sto sd t.          (?config :: CHADConfig)       => Descr env sto -> VarMap Int (D2AcE (Select env sto "accum")) -     -> Expr ValId env t -> Ret env sto t -drev des accumMap = \case +     -> Sparse (D2 t) sd +     -> Expr ValId env t -> Ret env sto sd t +drev des _ sd | isAbsent sd = +  \e -> +    Ret BTop +        SETop +        (drevPrimal des e) +        (subenvNone (d2e (select SMerge des))) +        (ENil ext) +drev _ _ SpAbsent = error "Absent should be isAbsent" + +drev des accumMap (SpSparse sd) = +  \e -> +    case drev des accumMap sd e of { Ret e0 subtape e1 sub e2 -> +    subenvPlus ST ST (d2eM (select SMerge des)) sub (subenvNone (d2e (select SMerge des))) $ \sub' (Inj inj1) (Inj inj2) _ -> +      Ret e0 +          subtape +          e1 +          sub' +          (emaybe (EVar ext (STMaybe (applySparse sd (d2 (typeOf e)))) IZ) +            (inj2 (ENil ext)) +            (inj1 (weakenExpr (WCopy WSink) e2))) +    } + +drev des accumMap sd = \case    EVar _ t i ->      case conv2Idx des i of        Idx2Ac accI ->          Ret BTop              SETop              (EVar ext (d1 t) (conv1Idx i)) -            (subenvNone (select SMerge des)) -            (EAccum ext (d2M t) SAPHere (ENil ext) (EVar ext (d2 t) IZ) (EVar ext (STAccum (d2M t)) (IS accI))) +            (subenvNone (d2e (select SMerge des))) +            (let ty = applySparse sd (d2M t) +             in EAccum ext (d2M t) SAPHere (ENil ext) sd (EVar ext (fromSMTy ty) IZ) (EVar ext (STAccum (d2M t)) (IS accI)))        Idx2Me tupI ->          Ret BTop              SETop              (EVar ext (d1 t) (conv1Idx i)) -            (subenvOnehot (select SMerge des) tupI) -            (EPair ext (ENil ext) (EVar ext (d2 t) IZ)) +            (subenvOnehot (d2e (select SMerge des)) tupI sd) +            (EPair ext (ENil ext) (EVar ext (applySparse sd (d2 t)) IZ))        Idx2Di _ ->          Ret BTop              SETop              (EVar ext (d1 t) (conv1Idx i)) -            (subenvNone (select SMerge des)) +            (subenvNone (d2e (select SMerge des)))              (ENil ext)    ELet _ (rhs :: Expr _ _ a) body -    | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 <- drev des accumMap rhs -    , ChosenStorage storage <- if chcLetArrayAccum ?config && hasArrays (typeOf rhs) then ChosenStorage SAccum else ChosenStorage SMerge -    , RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody body2 <- drevScoped des accumMap (typeOf rhs) storage (Just (extOf rhs)) body -    , let (body0', wbody0') = weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0 -    , Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D1E env) -    , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) -> -    subenvPlus (select SMerge des) subRHS subBody $ \subBoth _ _ plus_RHS_Body -> -    let bodyResType = STPair (tTup (d2e (subList (select SMerge des) subBody))) (d2 (typeOf rhs)) in -    Ret (bconcat (rhs0 `BPush` (d1 (typeOf rhs), rhs1)) body0') -        (subenvConcat (SENo @d1_a subtapeRHS) subtapeBody) +    | ChosenStorage (storage :: Storage s) <- if chcLetArrayAccum ?config && typeHasArrays (typeOf rhs) then ChosenStorage SAccum else ChosenStorage SMerge +    , RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody sdBody body2 <- drevScoped des accumMap (typeOf rhs) storage (Just (extOf rhs)) sd body +    , Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) rhs1 subRHS rhs2 <- drev des accumMap sdBody rhs +    , let (body0', wbody0') = weakenBindingsE (WCopy (sinkWithBindings rhs0)) body0 +    , Refl <- lemAppendAssoc @body_shbinds @'[D1 a] @rhs_shbinds +    , Refl <- lemAppendAssoc @body_shbinds @(D1 a : rhs_shbinds) @(D1E env) +    , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) +    -> +    subenvPlus SF SF (d2eM (select SMerge des)) subRHS subBody $ \subBoth _ _ plus_RHS_Body -> +    let bodyResType = STPair (contribTupTy des subBody) (applySparse sdBody (d2 (typeOf rhs))) in +    Ret (bconcat (rhs0 `bpush` rhs1) body0') +        (subenvConcat subtapeRHS subtapeBody)          (weakenExpr wbody0' body1)          subBoth -        (ELet ext (weakenExpr (autoWeak (#d (auto1 @(D2 t)) -                                         &. #body (subList (bindingsBinds body0) subtapeBody) +        (ELet ext (weakenExpr (autoWeak (#d (auto1 @sd) +                                         &. #body (subList (bindingsBinds body0 `sappend` SCons (d1 (typeOf rhs)) SNil) subtapeBody)                                           &. #rhs (subList (bindingsBinds rhs0) subtapeRHS)                                           &. #tl (d2ace (select SAccum des)))                                          (#d :++: #body :++: #tl) @@ -637,317 +802,359 @@ drev des accumMap = \case             (ELet ext (ESnd ext (EVar ext bodyResType IZ)) $               weakenExpr (WCopy (wSinks' @[_,_] .> sinkWithSubenv subtapeBody)) rhs2) $           plus_RHS_Body -           (EVar ext (tTup (d2e (subList (select SMerge des) subRHS))) IZ) +           (EVar ext (contribTupTy des subRHS) IZ)             (EFst ext (EVar ext bodyResType (IS IZ))))    EPair _ a b -    | Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil) -        <- retConcat des $ drev des accumMap a `SCons` drev des accumMap b `SCons` SNil -    , let dt = STPair (d2 (typeOf a)) (d2 (typeOf b)) -> -    subenvPlus (select SMerge des) subA subB $ \subBoth _ _ plus_A_B -> +    | SpPair sd1 sd2 <- sd +    , Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil) +        <- retConcat des $ toSingleRet (drev des accumMap sd1 a) `SCons` toSingleRet (drev des accumMap sd2 b) `SCons` SNil +    , let dt = STPair (applySparse sd1 (d2 (typeOf a))) (applySparse sd2 (d2 (typeOf b))) -> +    subenvPlus SF SF (d2eM (select SMerge des)) subA subB $ \subBoth _ _ plus_A_B ->      Ret binds          subtape          (EPair ext a1 b1)          subBoth -        (EMaybe ext -           (zeroTup (subList (select SMerge des) subBoth)) -           (ELet ext (ELet ext (EFst ext (EVar ext dt IZ)) -                      (weakenExpr (WCopy (wSinks' @[_,_])) a2)) $ -            ELet ext (ELet ext (ESnd ext (EVar ext dt (IS IZ))) -                      (weakenExpr (WCopy (wSinks' @[_,_,_])) b2)) $ -            plus_A_B -             (EVar ext (tTup (d2e (subList (select SMerge des) subA))) (IS IZ)) -             (EVar ext (tTup (d2e (subList (select SMerge des) subB))) IZ)) -           (EVar ext (STMaybe (STPair (d2 (typeOf a)) (d2 (typeOf b)))) IZ)) +        (ELet ext (ELet ext (EFst ext (EVar ext dt IZ)) +                   (weakenExpr (WCopy WSink) a2)) $ +         ELet ext (ELet ext (ESnd ext (EVar ext dt (IS IZ))) +                   (weakenExpr (WCopy (WSink .> WSink)) b2)) $ +         plus_A_B +           (EVar ext (contribTupTy des subA) (IS IZ)) +           (EVar ext (contribTupTy des subB) IZ))    EFst _ e -    | Ret e0 subtape e1 sub e2 <- drev des accumMap e -    , STPair t1 t2 <- typeOf e -> +    | Ret e0 subtape e1 sub e2 <- drev des accumMap (SpPair sd SpAbsent) e +    , STPair t1 _ <- typeOf e ->      Ret e0          subtape          (EFst ext e1)          sub -        (ELet ext (EJust ext (EPair ext (EVar ext (d2 t1) IZ) (ezeroD2 t2))) $ +        (ELet ext (EPair ext (EVar ext (applySparse sd (d2 t1)) IZ) (ENil ext)) $             weakenExpr (WCopy WSink) e2)    ESnd _ e -    | Ret e0 subtape e1 sub e2 <- drev des accumMap e -    , STPair t1 t2 <- typeOf e -> +    | Ret e0 subtape e1 sub e2 <- drev des accumMap (SpPair SpAbsent sd) e +    , STPair _ t2 <- typeOf e ->      Ret e0          subtape          (ESnd ext e1)          sub -        (ELet ext (EJust ext (EPair ext (ezeroD2 t1) (EVar ext (d2 t2) IZ))) $ +        (ELet ext (EPair ext (ENil ext) (EVar ext (applySparse sd (d2 t2)) IZ)) $             weakenExpr (WCopy WSink) e2) -  ENil _ -> Ret BTop SETop (ENil ext) (subenvNone (select SMerge des)) (ENil ext) +  -- Don't need to handle ENil, because its cotangent is always absent! +  -- ENil _ -> Ret BTop SETop (ENil ext) (subenvNone (d2e (select SMerge des))) (ENil ext)    EInl _ t2 e -    | Ret e0 subtape e1 sub e2 <- drev des accumMap e -> +    | SpLEither sd1 sd2 <- sd +    , Ret e0 subtape e1 sub e2 <- drev des accumMap sd1 e -> +    subenvPlus ST ST (d2eM (select SMerge des)) sub (subenvNone (d2e (select SMerge des))) $ \sub' (Inj inj1) (Inj inj2) _ ->      Ret e0          subtape          (EInl ext (d1 t2) e1) -        sub +        sub'          (ELCase ext -           (EVar ext (STLEither (d2 (typeOf e)) (d2 t2)) IZ) -           (zeroTup (subList (select SMerge des) sub)) -              (weakenExpr (WCopy WSink) e2) -              (EError ext (tTup (d2e (subList (select SMerge des) sub))) "inl<-dinr")) +           (EVar ext (STLEither (applySparse sd1 (d2 (typeOf e))) (applySparse sd2 (d2 t2))) IZ) +           (inj2 $ ENil ext) +           (inj1 $ weakenExpr (WCopy WSink) e2) +           (EError ext (contribTupTy des sub') "inl<-dinr"))    EInr _ t1 e -    | Ret e0 subtape e1 sub e2 <- drev des accumMap e -> +    | SpLEither sd1 sd2 <- sd +    , Ret e0 subtape e1 sub e2 <- drev des accumMap sd2 e -> +    subenvPlus ST ST (d2eM (select SMerge des)) sub (subenvNone (d2e (select SMerge des))) $ \sub' (Inj inj1) (Inj inj2) _ ->      Ret e0          subtape          (EInr ext (d1 t1) e1) -        sub +        sub'          (ELCase ext -           (EVar ext (STLEither (d2 t1) (d2 (typeOf e))) IZ) -           (zeroTup (subList (select SMerge des) sub)) -           (EError ext (tTup (d2e (subList (select SMerge des) sub))) "inr<-dinl") -           (weakenExpr (WCopy WSink) e2)) +           (EVar ext (STLEither (applySparse sd1 (d2 t1)) (applySparse sd2 (d2 (typeOf e)))) IZ) +           (inj2 $ ENil ext) +           (EError ext (contribTupTy des sub') "inr<-dinl") +           (inj1 $ weakenExpr (WCopy WSink) e2))    ECase _ e (a :: Expr _ _ t) b -    | STEither t1 t2 <- typeOf e -    , Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 subE e2 <- drev des accumMap e -    , ChosenStorage storage1 <- if chcCaseArrayAccum ?config && hasArrays t1 then ChosenStorage SAccum else ChosenStorage SMerge -    , ChosenStorage storage2 <- if chcCaseArrayAccum ?config && hasArrays t2 then ChosenStorage SAccum else ChosenStorage SMerge +    | STEither (t1 :: STy a) (t2 :: STy b) <- typeOf e +    , ChosenStorage storage1 <- if chcCaseArrayAccum ?config && typeHasArrays t1 then ChosenStorage SAccum else ChosenStorage SMerge +    , ChosenStorage storage2 <- if chcCaseArrayAccum ?config && typeHasArrays t2 then ChosenStorage SAccum else ChosenStorage SMerge      , let (bindids1, bindids2) = validSplitEither (extOf e) -    , RetScoped (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA a2 <- drevScoped des accumMap t1 storage1 bindids1 a -    , RetScoped (b0 :: Bindings _ _ rhs_b_binds) (subtapeB :: Subenv _ rhs_b_tape) b1 subB b2 <- drevScoped des accumMap t2 storage2 bindids2 b +    , RetScoped (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA sd1 a2 +          <- drevScoped des accumMap t1 storage1 bindids1 sd a +    , RetScoped (b0 :: Bindings _ _ rhs_b_binds) (subtapeB :: Subenv _ rhs_b_tape) b1 subB sd2 b2 +          <- drevScoped des accumMap t2 storage2 bindids2 sd b +    , Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 subE e2 <- drev des accumMap (SpLEither sd1 sd2) e      , Refl <- lemAppendAssoc @(Append rhs_a_binds (Reverse (TapeUnfoldings rhs_a_binds))) @(Tape rhs_a_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum"))      , Refl <- lemAppendAssoc @(Append rhs_b_binds (Reverse (TapeUnfoldings rhs_b_binds))) @(Tape rhs_b_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum")) -    , let tapeA = tapeTy (subList (bindingsBinds a0) subtapeA) -    , let tapeB = tapeTy (subList (bindingsBinds b0) subtapeB) -    , let collectA = bindingsCollect a0 subtapeA -    , let collectB = bindingsCollect b0 subtapeB +    , let subtapeListA = subList (sappend (bindingsBinds a0) (d1 t1 `SCons` SNil)) subtapeA +    , let subtapeListB = subList (sappend (bindingsBinds b0) (d1 t2 `SCons` SNil)) subtapeB +    , let tapeA = tapeTy subtapeListA +    , let tapeB = tapeTy subtapeListB +    , let collectA = bindingsCollectTape @_ @_ @(Append rhs_a_binds (D1 a : Append e_binds (D1E env))) +                                         (sappend (bindingsBinds a0) (d1 t1 `SCons` SNil)) subtapeA +    , let collectB = bindingsCollectTape @_ @_ @(Append rhs_b_binds (D1 b : Append e_binds (D1E env))) +                                         (sappend (bindingsBinds b0) (d1 t2 `SCons` SNil)) subtapeB      , (tPrimal :: STy t_primal_ty) <- STPair (d1 (typeOf a)) (STEither tapeA tapeB) -    , let (a0', wa0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) a0 -    , let (b0', wb0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) b0 +    , let (a0', wa0') = weakenBindingsE (WCopy (sinkWithBindings e0)) a0 +    , let (b0', wb0') = weakenBindingsE (WCopy (sinkWithBindings e0)) b0 +    , Refl <- lemAppendNil @(Append rhs_a_binds '[D1 a]) +    , Refl <- lemAppendNil @(Append rhs_b_binds '[D1 b]) +    , Refl <- lemAppendAssoc @rhs_a_binds @'[D1 a] @(D1E env) +    , Refl <- lemAppendAssoc @rhs_b_binds @'[D1 b] @(D1E env) +    , let wa0'' = wa0' .> wCopies (sappend (bindingsBinds a0) (d1 t1 `SCons` SNil)) (WClosed @(D1E env)) +    , let wb0'' = wb0' .> wCopies (sappend (bindingsBinds b0) (d1 t2 `SCons` SNil)) (WClosed @(D1E env))      -> -    subenvPlus (select SMerge des) subA subB $ \subAB sAB_A sAB_B _ -> -    subenvPlus (select SMerge des) subAB subE $ \subOut _ _ plus_AB_E -> -    let tCaseRet = STPair (tTup (d2e (subList (select SMerge des) subAB))) (STLEither (d2 t1) (d2 t2)) in -    Ret (e0 `BPush` -         (tPrimal, -            ECase ext e1 -              (letBinds a0' (EPair ext (weakenExpr wa0' a1) (EInl ext tapeB (collectA wa0')))) -              (letBinds b0' (EPair ext (weakenExpr wb0' b1) (EInr ext tapeA (collectB wb0')))))) -        (SEYes subtapeE) +    subenvPlus ST ST (d2eM (select SMerge des)) subA subB $ \subAB (Inj sAB_A) (Inj sAB_B) _ -> +    subenvPlus SF SF (d2eM (select SMerge des)) subAB subE $ \subOut _ _ plus_AB_E -> +    Ret (e0 `bpush` ECase ext e1 +                (letBinds a0' (EPair ext (weakenExpr wa0' a1) (EInl ext tapeB (collectA wa0'')))) +                (letBinds b0' (EPair ext (weakenExpr wb0' b1) (EInr ext tapeA (collectB wb0''))))) +        (SEYesR subtapeE)          (EFst ext (EVar ext tPrimal IZ))          subOut -        (ELet ext +        (elet             (ECase ext (ESnd ext (EVar ext tPrimal (IS IZ))) -              (let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds a0) subtapeA) IZ -               in letBinds rebinds $ +              (let (rebinds, prerebinds) = reconstructBindings subtapeListA +               in letBinds (rebinds IZ) $                      ELet ext -                      (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_a_tape : D2 t : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend (subList (bindingsBinds a0) subtapeA) prerebinds) @> IS IZ)) $ -                    ELet ext -                      (weakenExpr (autoWeak (#d (auto1 @(D2 t)) -                                             &. #ta0 (subList (bindingsBinds a0) subtapeA) +                      (EVar ext (applySparse sd (d2 (typeOf a))) (wSinks @(Tape rhs_a_tape : sd : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend subtapeListA prerebinds) @> IS IZ)) $ +                    elet +                      (weakenExpr (autoWeak (#d (auto1 @sd) +                                             &. #ta0 subtapeListA                                               &. #prea0 prerebinds -                                             &. #recon (tapeA `SCons` d2 (typeOf a) `SCons` SNil) +                                             &. #recon (tapeA `SCons` applySparse sd (d2 (typeOf a)) `SCons` SNil)                                               &. #binds (tPrimal `SCons` subList (bindingsBinds e0) subtapeE)                                               &. #tl (d2ace (select SAccum des)))                                              (#d :++: #ta0 :++: #tl)                                              (#d :++: (#ta0 :++: #prea0) :++: #recon :++: #binds :++: #tl))                                    a2) $ -                    EPair ext -                     (expandSubenvZeros (subList (select SMerge des) subAB) sAB_A $ -                        EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA))) (d2 t1)) IZ)) -                     (ELInl ext (d2 t2) -                       (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA))) (d2 t1)) IZ)))) -              (let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds b0) subtapeB) IZ -               in letBinds rebinds $ -                    ELet ext -                      (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_b_tape : D2 t : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend (subList (bindingsBinds b0) subtapeB) prerebinds) @> IS IZ)) $ +                    EPair ext (sAB_A $ EFst ext (evar IZ)) +                              (ELInl ext (applySparse sd2 (d2 t2)) (ESnd ext (evar IZ)))) +              (let (rebinds, prerebinds) = reconstructBindings subtapeListB +               in letBinds (rebinds IZ) $                      ELet ext -                      (weakenExpr (autoWeak (#d (auto1 @(D2 t)) -                                             &. #tb0 (subList (bindingsBinds b0) subtapeB) +                      (EVar ext (applySparse sd (d2 (typeOf a))) (wSinks @(Tape rhs_b_tape : sd : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend subtapeListB prerebinds) @> IS IZ)) $ +                    elet +                      (weakenExpr (autoWeak (#d (auto1 @sd) +                                             &. #tb0 subtapeListB                                               &. #preb0 prerebinds -                                             &. #recon (tapeB `SCons` d2 (typeOf a) `SCons` SNil) +                                             &. #recon (tapeB `SCons` applySparse sd (d2 (typeOf a)) `SCons` SNil)                                               &. #binds (tPrimal `SCons` subList (bindingsBinds e0) subtapeE)                                               &. #tl (d2ace (select SAccum des)))                                              (#d :++: #tb0 :++: #tl)                                              (#d :++: (#tb0 :++: #preb0) :++: #recon :++: #binds :++: #tl))                                    b2) $ -                    EPair ext -                      (expandSubenvZeros (subList (select SMerge des) subAB) sAB_B $ -                         EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB))) (d2 t2)) IZ)) -                      (ELInr ext (d2 t1) -                        (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB))) (d2 t2)) IZ))))) $ -         ELet ext -           (ELet ext (ESnd ext (EVar ext tCaseRet IZ)) $ -              weakenExpr (WCopy (wSinks' @[_,_,_])) e2) $ +                    EPair ext (sAB_B $ EFst ext (evar IZ)) +                              (ELInr ext (applySparse sd1 (d2 t1)) (ESnd ext (evar IZ))))) $           plus_AB_E -           (EFst ext (EVar ext tCaseRet (IS IZ))) -           (EVar ext (tTup (d2e (subList (select SMerge des) subE))) IZ)) +           (EFst ext (evar IZ)) +           (ELet ext (ESnd ext (evar IZ)) $ +              weakenExpr (WCopy (wSinks' @[_,_,_])) e2))    EConst _ t val ->      Ret BTop          SETop          (EConst ext t val) -        (subenvNone (select SMerge des)) +        (subenvNone (d2e (select SMerge des)))          (ENil ext)    EOp _ op e -    | Ret e0 subtape e1 sub e2 <- drev des accumMap e -> +    | Ret e0 subtape e1 sub e2 <- drev des accumMap (spDense (d2M (opt1 op))) e ->      case d2op op of        Linear d2opfun ->          Ret e0              subtape              (d1op op e1)              sub -            (ELet ext (d2opfun (EVar ext (d2 (opt2 op)) IZ)) +            (ELet ext (d2opfun (opt2UnSparse op sd (EVar ext (applySparse sd (d2 (opt2 op))) IZ)))                 (weakenExpr (WCopy WSink) e2))        Nonlinear d2opfun -> -        Ret (e0 `BPush` (d1 (typeOf e), e1)) -            (SEYes subtape) +        Ret (e0 `bpush` e1) +            (SEYesR subtape)              (d1op op $ EVar ext (d1 (typeOf e)) IZ)              sub              (ELet ext (d2opfun (EVar ext (d1 (typeOf e)) (IS IZ)) -                               (EVar ext (d2 (opt2 op)) IZ)) +                               (opt2UnSparse op sd (EVar ext (applySparse sd (d2 (opt2 op))) IZ)))                 (weakenExpr (WCopy (wSinks' @[_,_])) e2)) -  ECustom _ _ _ storety _ pr du a b +  ECustom _ _ tb _ srce pr du a b      -- allowed to ignore a2 because 'a' is the part of the input that is inactive -    | Rets binds subtape (RetPair a1 _ _ `SCons` RetPair b1 bsub b2 `SCons` SNil) -        <- retConcat des $ drev des accumMap a `SCons` drev des accumMap b `SCons` SNil -> -    Ret (binds `BPush` (typeOf a1, a1) -               `BPush` (typeOf b1, weakenExpr WSink b1) -               `BPush` (typeOf pr, weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr)) -               `BPush` (storety, ESnd ext (EVar ext (typeOf pr) IZ))) -        (SEYes (SENo (SENo (SENo subtape)))) -        (EFst ext (EVar ext (typeOf pr) (IS IZ))) -        bsub -        (ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $ -           weakenExpr (WCopy (WSink .> WSink)) b2) +    | Ret b0 bsubtape b1 bsub b2 <- drev des accumMap (spDense (d2M tb)) b -> +    case isDense (d2M (typeOf srce)) sd of +      Just Refl -> +        Ret (b0 `bpush` weakenExpr (sinkWithBindings b0) (drevPrimal des a) +                `bpush` weakenExpr WSink b1 +                `bpush` weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr) +                `bpush` ESnd ext (EVar ext (typeOf pr) IZ)) +            (SEYesR (SENo (SENo (SENo bsubtape)))) +            (EFst ext (EVar ext (typeOf pr) (IS IZ))) +            bsub +            (ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $ +               weakenExpr (WCopy (WSink .> WSink)) b2) + +      Nothing -> +        Ret (b0 `bpush` weakenExpr (sinkWithBindings b0) (drevPrimal des a) +                `bpush` weakenExpr WSink b1 +                `bpush` weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr)) +            (SEYesR (SENo (SENo bsubtape))) +            (EFst ext (EVar ext (typeOf pr) IZ)) +            bsub +            (ELet ext (ESnd ext (EVar ext (typeOf pr) (IS IZ))) $  -- tape +             ELet ext (expandSparse (typeOf srce) sd  -- expanded incoming cotangent +                                    (EFst ext (EVar ext (typeOf pr) (IS (IS IZ)))) +                                    (EVar ext (applySparse sd (d2 (typeOf srce))) (IS IZ))) $ +             ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $ +               weakenExpr (WCopy (WSink .> WSink .> WSink .> WSink)) b2) + +  ERecompute _ e -> +    deleteUnused (descrList des) (occCountAll e) $ \usedSub -> +    let smallE = unsafeWeakenWithSubenv usedSub e in +    subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed -> +    case drev usedDes (VarMap.subMap subAccumUsed accumMap) sd smallE of { Ret e0 subtape _ sub e2 -> +    let subMergeUsed' = subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E subMergeUsed) in +    Ret (collectBindings (desD1E des) subD1eUsed) +        (subenvAll (desD1E usedDes)) +        (weakenExpr (wSinks (desD1E usedDes)) $ drevPrimal des e) +        (subenvCompose subMergeUsed' sub) +        (letBinds (fst (weakenBindingsE (WSink .> wRaiseAbove (desD1E usedDes) (d2ace (select SAccum des))) e0)) $ +           weakenExpr +             (autoWeak (#d (auto1 @sd) +                        &. #shbinds (bindingsBinds e0) +                        &. #tape (subList (bindingsBinds e0) subtape) +                        &. #d1env (desD1E usedDes) +                        &. #tl' (d2ace (select SAccum usedDes)) +                        &. #tl (d2ace (select SAccum des))) +                       (#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: LPreW #tl' #tl (wUndoSubenv subAccumUsed)) +                       (#shbinds :++: #d :++: #d1env :++: #tl)) +             e2) +    }    EError _ t s ->      Ret BTop          SETop          (EError ext (d1 t) s) -        (subenvNone (select SMerge des)) +        (subenvNone (d2e (select SMerge des)))          (ENil ext)    EConstArr _ n t val ->      Ret BTop          SETop          (EConstArr ext n t val) -        (subenvNone (select SMerge des)) +        (subenvNone (d2e (select SMerge des)))          (ENil ext)    EBuild _ (ndim :: SNat ndim) she (orige :: Expr _ _ eltty) -    | Ret (she0 :: Bindings _ _ she_binds) _ she1 _ _ <- drev des accumMap she  -- allowed to ignore she2 here because she has a discrete result +    | SpArr @_ @sdElt sdElt <- sd      , let eltty = typeOf orige      , shty :: STy shty <- tTup (sreplicate ndim tIx)      , Refl <- indexTupD1Id ndim -> -    deleteUnused (descrList des) (occEnvPop (occCountAll orige)) $ \(usedSub :: Subenv env env') -> -    let e = unsafeWeakenWithSubenv (SEYes usedSub) orige in -    subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed -> -    accumPromote eltty usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro -> +    deleteUnused (descrList des) (occEnvPopSome (occCountAll orige)) $ \(usedSub :: Subenv env env') -> +    let e = unsafeWeakenWithSubenv (SEYesR usedSub) orige in +    subDescr des usedSub $ \(usedDes :: Descr env' _) subMergeUsed subAccumUsed subD1eUsed -> +    accumPromote sdElt usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro ->      let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in -    case drev (prodes `DPush` (shty, Nothing, SDiscr)) accumMapPro e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 sub e2 -> -    case assertSubenvEmpty sub of { Refl -> +    case drev (prodes `DPush` (shty, Nothing, SDiscr)) accumMapPro sdElt e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 SETop e2 -> +    case lemAppendNil @e_binds of { Refl ->      let tapety = tapeTy (subList (bindingsBinds e0) subtapeE) in -    let collectexpr = bindingsCollect e0 subtapeE in -    Ret (BTop `BPush` (shty, letBinds she0 she1) -              `BPush` (STArr ndim (STPair (d1 eltty) tapety) -                      ,EBuild ext ndim -                         (EVar ext shty IZ) -                         (letBinds (fst (weakenBindings weakenExpr (autoWeak (#ix (shty `SCons` SNil) -                                                                              &. #sh (shty `SCons` SNil) -                                                                              &. #d1env (desD1E des) -                                                                              &. #d1env' (desD1E usedDes)) -                                                                             (#ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) -                                                                             (#ix :++: #sh :++: #d1env)) -                                                                   e0)) $ -                            let w = autoWeak (#ix (shty `SCons` SNil) -                                              &. #sh (shty `SCons` SNil) -                                              &. #e0 (bindingsBinds e0) -                                              &. #d1env (desD1E des) -                                              &. #d1env' (desD1E usedDes)) -                                             (#e0 :++: #ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) -                                             (#e0 :++: #ix :++: #sh :++: #d1env) -                            in EPair ext (weakenExpr w e1) (collectexpr w))) -              `BPush` (STArr ndim tapety, emap (ESnd ext (EVar ext (STPair (d1 eltty) tapety) IZ)) -                                               (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) IZ))) -        (SEYes (SENo (SEYes SETop))) -        (emap (EFst ext (EVar ext (STPair (d1 eltty) tapety) IZ)) -              (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) (IS IZ))) -        (subenvCompose subMergeUsed proSub) -        (let sinkOverEnvPro = wSinks @(TArr ndim (D2 eltty) : D2 t : TArr ndim (Tape e_tape) : Tup (Replicate ndim TIx) : D2AcE (Select env sto "accum")) (d2ace envPro) in -         EMaybe ext -           (zeroTup envPro) -           (ESnd ext $ -              uninvertTup (d2e envPro) (STArr ndim STNil) $ -                makeAccumulators @_ @_ @(TArr ndim TNil) envPro $ -                  EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS (IS IZ)))) $ -                    -- the cotangent for this element -                    ELet ext (EIdx ext (EVar ext (STArr ndim (d2 eltty)) (WSink .> sinkOverEnvPro @> IZ)) -                                       (EVar ext shty IZ)) $ -                    -- the tape for this element -                    ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS (IS IZ))) -                                       (EVar ext shty (IS IZ))) $ -                    let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds e0) subtapeE) IZ -                    in letBinds rebinds $ -                         weakenExpr (autoWeak (#d (auto1 @(D2 eltty)) -                                               &. #pro (d2ace envPro) -                                               &. #etape (subList (bindingsBinds e0) subtapeE) -                                               &. #prerebinds prerebinds -                                               &. #tape (auto1 @(Tape e_tape)) -                                               &. #ix (auto1 @shty) -                                               &. #darr (auto1 @(TArr ndim (D2 eltty))) -                                               &. #mdarr (auto1 @(TMaybe (TArr ndim (D2 eltty)))) -                                               &. #tapearr (auto1 @(TArr ndim (Tape e_tape))) -                                               &. #sh (auto1 @shty) -                                               &. #d2acUsed (d2ace (select SAccum usedDes)) -                                               &. #d2acEnv (d2ace (select SAccum des))) -                                              (#pro :++: #d :++: #etape :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) -                                              ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #mdarr :++: #tapearr :++: #sh :++: #d2acEnv) -                                     .> wPro (subList (bindingsBinds e0) subtapeE)) -                                    e2) -           (EVar ext (d2 (STArr ndim eltty)) IZ)) +    let collectexpr = bindingsCollectTape (bindingsBinds e0) subtapeE in +    let mergePrimalSub = subenvD1E (selectSub SMerge des `subenvCompose` subMergeUsed `subenvCompose` proSub) in +    let mergePrimalBindings = collectBindings (d1e (descrList des)) mergePrimalSub in +    Ret (mergePrimalBindings +          `bpush` weakenExpr (wSinks (d1e envPro)) (drevPrimal des she) +          `bpush` EBuild ext ndim +                    (EVar ext shty IZ) +                    (letBinds (fst (weakenBindingsE (autoWeak (#ix (shty `SCons` SNil) +                                                               &. #sh (shty `SCons` SNil) +                                                               &. #propr (d1e envPro) +                                                               &. #d1env (desD1E des) +                                                               &. #d1env' (desD1E usedDes)) +                                                              (#ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) +                                                              (#ix :++: #sh :++: #propr :++: #d1env)) +                                                    e0)) $ +                       let w = autoWeak (#ix (shty `SCons` SNil) +                                         &. #sh (shty `SCons` SNil) +                                         &. #e0 (bindingsBinds e0) +                                         &. #propr (d1e envPro) +                                         &. #d1env (desD1E des) +                                         &. #d1env' (desD1E usedDes)) +                                        (#e0 :++: #ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) +                                        (#e0 :++: #ix :++: #sh :++: #propr :++: #d1env) +                           w' = w .> wCopies (bindingsBinds e0) (WClosed @(shty : D1E env')) +                       in EPair ext (weakenExpr w e1) (collectexpr w')) +          `bpush` emap (ESnd ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) IZ)) +        (SEYesR (SENo (SEYesR (subenvAll (d1e envPro))))) +        (emap (EFst ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) (IS IZ))) +        (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E (subenvCompose subMergeUsed proSub))) +        (let sinkOverEnvPro = wSinks @(sd : TArr ndim (Tape e_tape) : Tup (Replicate ndim TIx) : Append (D1E envPro) (D2AcE (Select env sto "accum"))) (d2ace envPro) in +         ESnd ext $ +           uninvertTup (d2e envPro) (STArr ndim STNil) $ +             makeAccumulators @_ @_ @(TArr ndim TNil) (WSink .> WSink .> WSink .> wRaiseAbove (d1e envPro) (d2ace (select SAccum des))) envPro $ +               EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS IZ))) $ +                 -- the cotangent for this element +                 ELet ext (EIdx ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (WSink .> sinkOverEnvPro @> IZ)) +                                    (EVar ext shty IZ)) $ +                 -- the tape for this element +                 ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS IZ)) +                                    (EVar ext shty (IS IZ))) $ +                 let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds e0) subtapeE) +                 in letBinds (rebinds IZ) $ +                      weakenExpr (autoWeak (#d (auto1 @sdElt) +                                            &. #pro (d2ace envPro) +                                            &. #etape (subList (bindingsBinds e0) subtapeE) +                                            &. #prerebinds prerebinds +                                            &. #tape (auto1 @(Tape e_tape)) +                                            &. #ix (auto1 @shty) +                                            &. #darr (auto1 @(TArr ndim sdElt)) +                                            &. #tapearr (auto1 @(TArr ndim (Tape e_tape))) +                                            &. #sh (auto1 @shty) +                                            &. #propr (d1e envPro) +                                            &. #d2acUsed (d2ace (select SAccum usedDes)) +                                            &. #d2acEnv (d2ace (select SAccum des))) +                                           (#pro :++: #d :++: #etape :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) +                                           ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #sh :++: #propr :++: #d2acEnv) +                                  .> wPro (subList (bindingsBinds e0) subtapeE)) +                                 e2)      }}    EUnit _ e -    | Ret e0 subtape e1 sub e2 <- drev des accumMap e -> +    | SpArr sdElt <- sd +    , Ret e0 subtape e1 sub e2 <- drev des accumMap sdElt e ->      Ret e0          subtape          (EUnit ext e1)          sub -        (EMaybe ext -          (zeroTup (subList (select SMerge des) sub)) -          (ELet ext (EIdx0 ext (EVar ext (STArr SZ (d2 (typeOf e))) IZ)) $ -             weakenExpr (WCopy (WSink .> WSink)) e2) -          (EVar ext (STMaybe (STArr SZ (d2 (typeOf e)))) IZ)) +        (ELet ext (EIdx0 ext (EVar ext (STArr SZ (applySparse sdElt (d2 (typeOf e)))) IZ)) $ +           weakenExpr (WCopy WSink) e2)    EReplicate1Inner _ en e -    -- We're allowed to ignore en2 here because the output of 'ei' is discrete. -    | Rets binds subtape (RetPair en1 _ _ `SCons` RetPair e1 sub e2 `SCons` SNil) -        <- retConcat des $ drev des accumMap en `SCons` drev des accumMap e `SCons` SNil +    -- We're allowed to differentiate 'en' as primal-only here because its output is discrete. +    | SpArr sdElt <- sd      , let STArr ndim eltty = typeOf e -> -    Ret binds -        subtape -        (EReplicate1Inner ext en1 e1) -        sub -        (EMaybe ext -          (zeroTup (subList (select SMerge des) sub)) -          (ELet ext (EJust ext (EFold1Inner ext Commut -                        (EPlus ext (d2M eltty) (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ)) -                        (ezeroD2 eltty) -                        (EVar ext (STArr (SS ndim) (d2 eltty)) IZ))) $ -            weakenExpr (WCopy (WSink .> WSink)) e2) -          (EVar ext (d2 (STArr (SS ndim) eltty)) IZ)) +    -- This pessimistic sparsity union is because the array might have been empty, in which case we need to generate a zero. +    sparsePlusS ST ST (d2M eltty) sdElt SpAbsent $ \sdElt' (Inj inj1) (Inj inj2) _ -> +    case drev des accumMap (SpArr sdElt') e of { Ret binds subtape e1 sub e2 -> +      Ret binds +          subtape +          (EReplicate1Inner ext (weakenExpr (wSinks (bindingsBinds binds)) (drevPrimal des en)) e1) +          sub +          (ELet ext (EFold1Inner ext Commut +                         (sparsePlus (d2M eltty) sdElt' +                            (EVar ext (applySparse sdElt' (d2 eltty)) (IS IZ)) +                            (EVar ext (applySparse sdElt' (d2 eltty)) IZ)) +                         (inj2 (ENil ext)) +                         (emap (inj1 (evar IZ)) $ EVar ext (STArr (SS ndim) (applySparse sdElt (d2 eltty))) IZ)) $ +             weakenExpr (WCopy WSink) e2) +    }    EIdx0 _ e -    | Ret e0 subtape e1 sub e2 <- drev des accumMap e +    | Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd) e      , STArr _ t <- typeOf e ->      Ret e0          subtape          (EIdx0 ext e1)          sub -        (ELet ext (EJust ext (EUnit ext (EVar ext (d2 t) IZ))) $ -         weakenExpr (WCopy WSink) e2) +        (ELet ext (EUnit ext (EVar ext (applySparse sd (d2 t)) IZ)) $ +           weakenExpr (WCopy WSink) e2)    EIdx1{} -> error "CHAD of EIdx1: Please use EIdx instead"    {- @@ -956,9 +1163,9 @@ drev des accumMap = \case      | Rets binds subtape (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil)          <- retConcat des $ drev des accumMap e `SCons` drev des accumMap ei `SCons` SNil      , STArr (SS n) eltty <- typeOf e -> -    Ret (binds `BPush` (STArr (SS n) (d1 eltty), e1) -               `BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) (d1 eltty)) IZ))) -        (SEYes (SENo subtape)) +    Ret (binds `bpush` e1 +               `bpush` EShape ext (EVar ext (STArr (SS n) (d1 eltty)) IZ)) +        (SEYesR (SENo subtape))          (EIdx1 ext (EVar ext (STArr (SS n) (d1 eltty)) (IS IZ))                     (weakenExpr (WSink .> WSink) ei1))          sub @@ -969,57 +1176,58 @@ drev des accumMap = \case    -}    EIdx _ e ei -    -- We're allowed to ignore ei2 here because the output of 'ei' is discrete. -    | Rets binds subtape (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil) -        <- retConcat des $ drev des accumMap e `SCons` drev des accumMap ei `SCons` SNil -    , STArr n eltty <- typeOf e +    -- We're allowed to differentiate ei as primal because its output is discrete. +    | STArr n eltty <- typeOf e      , Refl <- indexTupD1Id n -    , Refl <- lemZeroInfoD2 eltty -    , let tIxN = tTup (sreplicate n tIx)  -> -    Ret (binds `BPush` (STArr n (d1 eltty), e1) -               `BPush` (tIxN, EShape ext (EVar ext (typeOf e1) IZ)) -               `BPush` (tIxN, weakenExpr (WSink .> WSink) ei1)) -        (SEYes (SEYes (SENo subtape))) -        (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ))) -                  (EVar ext (tTup (sreplicate n tIx)) IZ)) -        sub -        (ELet ext (EOneHot ext (d2M (STArr n eltty)) (SAPJust (SAPArrIdx SAPHere)) -                             (EPair ext (EPair ext (EVar ext tIxN (IS IZ)) -                                                   (EBuild ext n (EVar ext tIxN (IS (IS IZ))) (ENil ext))) -                                        (ENil ext)) -                             (EVar ext (d2 eltty) IZ)) $ -         weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) +    , let tIxN = tTup (sreplicate n tIx) -> +    sparsePlusS ST ST (d2M eltty) sd SpAbsent $ \sd' (Inj inj1) (Inj inj2) _ -> +    case drev des accumMap (SpArr sd') e of { Ret binds subtape e1 sub e2 -> +      Ret (binds `bpush` e1 +                 `bpush` EShape ext (EVar ext (typeOf e1) IZ) +                 `bpush` weakenExpr (WSink .> WSink .> wSinks (bindingsBinds binds)) (drevPrimal des ei)) +          (SEYesR (SEYesR (SENo subtape))) +          (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ))) +                    (EVar ext (tTup (sreplicate n tIx)) IZ)) +          sub +          (ELet ext +            (EOneHot ext (SMTArr n (applySparse sd' (d2M eltty))) +                         (SAPArrIdx SAPHere) +                         (EPair ext +                           (EPair ext (EVar ext tIxN (IS IZ)) +                                      (EBuild ext n (EVar ext tIxN (IS (IS IZ))) $ +                                         makeZeroInfo (applySparse sd' (d2M eltty)) (inj2 (ENil ext)))) +                           (ENil ext)) +                         (inj1 $ EVar ext (applySparse sd (d2 eltty)) IZ)) $ +           weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) +    }    EShape _ e -    -- Allowed to ignore e2 here because the output of EShape is discrete, -    -- hence we'd be passing a zero cotangent to e2 anyway. -    | Ret e0 subtape e1 _ _ <- drev des accumMap e -    , STArr n _ <- typeOf e +    -- Allowed to differentiate e as primal because the output of EShape is +    -- discrete, hence we'd be passing a zero cotangent to e anyway. +    | STArr n _ <- typeOf e      , Refl <- indexTupD1Id n -> -    Ret e0 -        subtape -        (EShape ext e1) -        (subenvNone (select SMerge des)) +    Ret BTop +        SETop +        (EShape ext (drevPrimal des e)) +        (subenvNone (d2eM (select SMerge des)))          (ENil ext)    ESum1Inner _ e -    | Ret e0 subtape e1 sub e2 <- drev des accumMap e +    | SpArr sd' <- sd +    , Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd') e      , STArr (SS n) t <- typeOf e -> -    Ret (e0 `BPush` (STArr (SS n) t, e1) -            `BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) t) IZ))) -        (SEYes (SENo subtape)) +    Ret (e0 `bpush` e1 +            `bpush` EShape ext (EVar ext (STArr (SS n) t) IZ)) +        (SEYesR (SENo subtape))          (ESum1Inner ext (EVar ext (STArr (SS n) t) (IS IZ)))          sub -        (EMaybe ext -          (zeroTup (subList (select SMerge des) sub)) -          (ELet ext (EJust ext (EReplicate1Inner ext -                                  (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS (IS IZ)))) -                                  (EVar ext (STArr n (d2 t)) IZ))) $ -           weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) -          (EVar ext (d2 (STArr n t)) IZ)) +        (ELet ext (EReplicate1Inner ext +                    (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ))) +                    (EVar ext (STArr n (applySparse sd' (d2 t))) IZ)) $ +         weakenExpr (WCopy (WSink .> WSink)) e2) -  EMaximum1Inner _ e -> deriv_extremum (EMaximum1Inner ext) e -  EMinimum1Inner _ e -> deriv_extremum (EMinimum1Inner ext) e +  EMaximum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMaximum1Inner ext) des accumMap sd' e +  EMinimum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMinimum1Inner ext) des accumMap sd' e    -- These should be the next to be implemented, I think    EFold1Inner{} -> err_unsupported "EFold1Inner" @@ -1033,8 +1241,8 @@ drev des accumMap = \case    ELCase{} -> err_unsupported "ELCase"    EWith{} -> err_accum -  EAccum{} -> err_accum    EZero{} -> err_monoid +  EDeepZero{} -> err_monoid    EPlus{} -> err_monoid    EOneHot{} -> err_monoid @@ -1043,94 +1251,116 @@ drev des accumMap = \case      err_monoid = error "Monoid operations unsupported in the source program"      err_unsupported s = error $ "CHAD: unsupported " ++ s -    deriv_extremum :: ScalIsNumeric t' ~ True -                   => (forall env'. Ex env' (TArr (S n) (TScal t')) -> Ex env' (TArr n (TScal t'))) -                   -> Expr ValId env (TArr (S n) (TScal t')) -> Ret env sto (TArr n (TScal t')) -    deriv_extremum extremum e -      | Ret e0 subtape e1 sub e2 <- drev des accumMap e -      , at@(STArr (SS n) t@(STScal st)) <- typeOf e -      , let at' = STArr n t -      , let tIxN = tTup (sreplicate (SS n) tIx) = -      Ret (e0 `BPush` (at, e1) -              `BPush` (at', extremum (EVar ext at IZ))) -          (SEYes (SEYes subtape)) -          (EVar ext at' IZ) -          sub -          (EMaybe ext -            (zeroTup (subList (select SMerge des) sub)) -            (ELet ext (EJust ext -                        (EBuild ext (SS n) (EShape ext (EVar ext at (IS (IS (IS IZ))))) $ -                           eif (EOp ext (OEq st) (EPair ext -                                        (EIdx ext (EVar ext at (IS (IS (IS (IS IZ))))) (EVar ext tIxN IZ)) -                                        (EIdx ext (EVar ext at' (IS (IS (IS IZ)))) (EFst ext (EVar ext tIxN IZ))))) -                             (EIdx ext (EVar ext (STArr n (d2 t)) (IS IZ)) (EFst ext (EVar ext tIxN IZ))) -                             (ezeroD2 t))) $ -              weakenExpr (WCopy (WSink .> WSink .> WSink .> WSink)) e2) -            (EVar ext (d2 at') IZ)) +    contribTupTy :: Descr env sto -> SubenvS (D2E (Select env sto "merge")) contribs -> STy (Tup contribs) +    contribTupTy des' sub = tTup (slistMap fromSMTy (subList (d2eM (select SMerge des')) sub)) + +deriv_extremum :: (?config :: CHADConfig, ScalIsNumeric t ~ True) +               => (forall env'. Ex env' (TArr (S n) (TScal t)) -> Ex env' (TArr n (TScal t))) +               -> Descr env sto -> VarMap Int (D2AcE (Select env sto "accum")) +               -> Sparse (D2s t) sd +               -> Expr ValId env (TArr (S n) (TScal t)) -> Ret env sto (TArr n sd) (TArr n (TScal t)) +deriv_extremum extremum des accumMap sd e +  | at@(STArr (SS n) t@(STScal st)) <- typeOf e +  , let at' = STArr n t +  , let tIxN = tTup (sreplicate (SS n) tIx) = +  sparsePlusS ST ST (d2M t) sd SpAbsent $ \sd' (Inj inj1) (Inj inj2) _ -> +  case drev des accumMap (SpArr sd') e of { Ret e0 subtape e1 sub e2 -> +    Ret (e0 `bpush` e1 +            `bpush` extremum (EVar ext at IZ)) +        (SEYesR (SEYesR subtape)) +        (EVar ext at' IZ) +        sub +        (ELet ext +           (EBuild ext (SS n) (EShape ext (EVar ext at (IS (IS IZ)))) $ +              eif (EOp ext (OEq st) (EPair ext +                           (EIdx ext (EVar ext at (IS (IS (IS IZ)))) (EVar ext tIxN IZ)) +                           (EIdx ext (EVar ext at' (IS (IS IZ))) (EFst ext (EVar ext tIxN IZ))))) +                (inj1 $ EIdx ext (EVar ext (STArr n (applySparse sd (d2 t))) (IS IZ)) (EFst ext (EVar ext tIxN IZ))) +                (inj2 (ENil ext))) $ +         weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) +  }  data ChosenStorage = forall s. ((s == "discr") ~ False) => ChosenStorage (Storage s) -data RetScoped env0 sto a s t = -  forall shbinds tapebinds env0Merge. +data RetScoped env0 sto a s sd t = +  forall shbinds tapebinds contribs sa.      RetScoped          (Bindings Ex (D1E (a : env0)) shbinds)  -- shared binds -        (Subenv shbinds tapebinds) +        (Subenv (Append shbinds '[D1 a]) tapebinds)          (Ex (Append shbinds (D1E (a : env0))) (D1 t)) -        (Subenv (Select env0 sto "merge") env0Merge) +        (SubenvS (D2E (Select env0 sto "merge")) contribs)             -- ^ merge contributions to the _enclosing_ merge environment -        (Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum"))) -           (If (s == "discr") (Tup (D2E env0Merge)) -                              (TPair (Tup (D2E env0Merge)) (D2 a)))) +        (Sparse (D2 a) sa) +           -- ^ contribution to the argument +        (Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) +            (If (s == "discr") (Tup contribs) +                               (TPair (Tup contribs) sa)))            -- ^ the merge contributions, plus the cotangent to the argument            -- (if there is any) -deriving instance Show (RetScoped env0 sto a s t) +deriving instance Show (RetScoped env0 sto a s sd t) -drevScoped :: forall a s env sto t. +drevScoped :: forall a s env sto sd t.                (?config :: CHADConfig)             => Descr env sto -> VarMap Int (D2AcE (Select env sto "accum"))             -> STy a -> Storage s -> Maybe (ValId a) +           -> Sparse (D2 t) sd             -> Expr ValId (a : env) t -           -> RetScoped env sto a s t -drevScoped des accumMap argty argsto argids expr = case argsto of +           -> RetScoped env sto a s sd t +drevScoped des accumMap argty argsto argids sd expr = case argsto of    SMerge -    | Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap expr -> +    | Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap sd expr +    , Refl <- lemAppendNil @tapebinds ->          case sub of -          SEYes sub' -> RetScoped e0 subtape e1 sub' e2 -          SENo sub' -> RetScoped e0 subtape e1 sub' (EPair ext e2 (ezeroD2 argty)) +          SEYes sp sub' -> RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub' sp e2 +          SENo sub' -> RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub' SpAbsent (EPair ext e2 (ENil ext))    SAccum -    | Just (VIArr i _) <- argids +    | chcSmartWith ?config +    , Just (VIArr i _) <- argids      , Just (Some (VarMap.TypedIdx foundTy idx)) <- VarMap.lookup i accumMap      , Just Refl <- testEquality foundTy (STAccum (d2M argty)) -    , Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) (VarMap.sink1 accumMap) expr -> -        RetScoped e0 subtape e1 sub $ +    , Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) (VarMap.sink1 accumMap) sd expr +    , Refl <- lemAppendNil @tapebinds -> +        -- Our contribution to the binding's cotangent _here_ is zero (absent), +        -- because we're contributing to an earlier binding of the same value +        -- instead. +        RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub SpAbsent $            let wtapebinds = wSinks (subList (bindingsBinds e0) subtape) in            ELet ext (EVar ext (STAccum (d2M argty)) (WSink .> wtapebinds @> idx)) $ -            weakenExpr (autoWeak (#d (auto1 @(D2 t)) +            weakenExpr (autoWeak (#d (auto1 @sd)                                      &. #body (subList (bindingsBinds e0) subtape)                                      &. #ac (auto1 @(TAccum (D2 a)))                                      &. #tl (d2ace (select SAccum des)))                                     (#d :++: #body :++: #ac :++: #tl)                                     (#ac :++: #d :++: #body :++: #tl)) -                       -- Our contribution to the binding's cotangent _here_ is -                       -- zero, because we're contributing to an earlier binding -                       -- of the same value instead. -                       (EPair ext e2 (ezeroD2 argty)) +                       (EPair ext e2 (ENil ext))      | let accumMap' = case argids of                          Just (VIArr i _) -> VarMap.insert i (STAccum (d2M argty)) IZ (VarMap.sink1 accumMap)                          _ -> VarMap.sink1 accumMap -    , Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap' expr -> -        RetScoped e0 subtape e1 sub $ -          EWith ext (d2M argty) (ezeroD2 argty) $ -            weakenExpr (autoWeak (#d (auto1 @(D2 t)) -                                  &. #body (subList (bindingsBinds e0) subtape) -                                  &. #ac (auto1 @(TAccum (D2 a))) -                                  &. #tl (d2ace (select SAccum des))) +    , Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap' sd expr -> +        let library = #d (auto1 @sd) +                      &. #p (auto1 @(D1 a)) +                      &. #body (subList (bindingsBinds e0) subtape) +                      &. #ac (auto1 @(TAccum (D2 a))) +                      &. #tl (d2ace (select SAccum des)) +        in +        RetScoped e0 (subenvConcat (SEYesR @_ @_ @(D1 a) SETop) subtape) e1 sub (spDense (d2M argty)) $ +          let primalIdx = autoWeak library #p (#d :++: (#body :++: #p) :++: #tl) @> IZ in +          EWith ext (d2M argty) (EDeepZero ext (d2M argty) (d2deepZeroInfo argty (EVar ext (d1 argty) primalIdx))) $ +            weakenExpr (autoWeak library                                   (#d :++: #body :++: #ac :++: #tl) -                                 (#ac :++: #d :++: #body :++: #tl)) +                                 (#ac :++: #d :++: (#body :++: #p) :++: #tl))                         e2    SDiscr -    | Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap expr -> -        RetScoped e0 subtape e1 sub e2 +    | Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap sd expr +    , Refl <- lemAppendNil @tapebinds -> +        RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub SpAbsent e2 + +-- TODO: proper primal-only transform that doesn't depend on D1 = Id +drevPrimal :: Descr env sto -> Expr x env t -> Ex (D1E env) (D1 t) +drevPrimal des e +  | Refl <- d1Identity (typeOf e) +  , Refl <- d1eIdentity (descrList des) +  = mapExt (const ext) e | 
