aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD/Drev.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-11-10 21:49:45 +0100
committerTom Smeding <tom@tomsmeding.com>2025-11-10 21:50:25 +0100
commit174af2ba568de66e0d890825b8bda930b8e7bb96 (patch)
tree5a20f52662e87ff7cf6a6bef5db0713aa6c7884e /src/CHAD/Drev.hs
parent92bca235e3aaa287286b6af082d3fce585825a35 (diff)
Move module hierarchy under CHAD.
Diffstat (limited to 'src/CHAD/Drev.hs')
-rw-r--r--src/CHAD/Drev.hs1583
1 files changed, 1583 insertions, 0 deletions
diff --git a/src/CHAD/Drev.hs b/src/CHAD/Drev.hs
new file mode 100644
index 0000000..595d3c7
--- /dev/null
+++ b/src/CHAD/Drev.hs
@@ -0,0 +1,1583 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE EmptyCase #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE ImplicitParams #-}
+{-# LANGUAGE ImpredicativeTypes #-}
+{-# LANGUAGE OverloadedLabels #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE QuantifiedConstraints #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeData #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+
+-- I want to bring various type variables in scope using type annotations in
+-- patterns, but I don't want to have to mention all the other type parameters
+-- of the types in question as well then. Partial type signatures (with '_') are
+-- useful here.
+{-# LANGUAGE PartialTypeSignatures #-}
+{-# OPTIONS -Wno-partial-type-signatures #-}
+module CHAD.Drev (
+ drev,
+ freezeRet,
+ CHADConfig(..),
+ defaultConfig,
+ Storage(..),
+ Descr(..),
+ Select,
+) where
+
+import Data.Functor.Const
+import Data.Some
+import Data.Type.Equality (type (==), testEquality)
+
+import CHAD.Analysis.Identity (ValId(..), validSplitEither)
+import CHAD.AST
+import CHAD.AST.Bindings
+import CHAD.AST.Count
+import CHAD.AST.Env
+import CHAD.AST.Sparse
+import CHAD.AST.Weaken.Auto
+import CHAD.Data
+import qualified CHAD.Data.VarMap as VarMap
+import CHAD.Data.VarMap (VarMap)
+import CHAD.Drev.Accum
+import CHAD.Drev.EnvDescr
+import CHAD.Drev.Types
+import CHAD.Lemmas
+
+
+------------------------------ TAPES AND BINDINGS ------------------------------
+
+type family Tape binds where
+ Tape '[] = TNil
+ Tape (t : ts) = TPair t (Tape ts)
+
+tapeTy :: SList STy binds -> STy (Tape binds)
+tapeTy SNil = STNil
+tapeTy (SCons t ts) = STPair t (tapeTy ts)
+
+bindingsCollectTape :: SList STy binds -> Subenv binds tapebinds
+ -> binds :> env2 -> Ex env2 (Tape tapebinds)
+bindingsCollectTape SNil SETop _ = ENil ext
+bindingsCollectTape (t `SCons` binds) (SEYesR sub) w =
+ EPair ext (EVar ext t (w @> IZ))
+ (bindingsCollectTape binds sub (w .> WSink))
+bindingsCollectTape (_ `SCons` binds) (SENo sub) w =
+ bindingsCollectTape binds sub (w .> WSink)
+
+-- bindingsCollectTape' :: forall f env binds tapebinds env2. Bindings f env binds -> Subenv binds tapebinds
+-- -> Append binds env :> env2 -> Ex env2 (Tape tapebinds)
+-- bindingsCollectTape' binds sub w
+-- | Refl <- lemAppendNil @binds
+-- = bindingsCollectTape (bindingsBinds binds) sub (w .> wCopies @_ @_ @'[] (bindingsBinds binds) (WClosed @env))
+
+-- In order from large to small: i.e. in reverse order from what we want,
+-- because in a Bindings, the head of the list is the bottom-most entry.
+type family TapeUnfoldings binds where
+ TapeUnfoldings '[] = '[]
+ TapeUnfoldings (t : ts) = Tape ts : TapeUnfoldings ts
+
+type family Reverse l where
+ Reverse '[] = '[]
+ Reverse (t : ts) = Append (Reverse ts) '[t]
+
+-- An expression that is always 'snd'
+data UnfExpr env t where
+ UnfExSnd :: STy s -> STy t -> UnfExpr (TPair s t : env) t
+
+fromUnfExpr :: UnfExpr env t -> Ex env t
+fromUnfExpr (UnfExSnd s t) = ESnd ext (EVar ext (STPair s t) IZ)
+
+-- - A bunch of 'snd' expressions taking us from knowing that there's a
+-- 'Tape ts' in the environment (for simplicity assume it's at IZ, we'll fix
+-- this in reconstructBindings), to having 'Reverse (TapeUnfoldings ts)' in
+-- the environment.
+-- - In the extended environment, another bunch of let bindings (these are
+-- 'fst' expressions, but no need to know that statically) that project the
+-- fsts out of what we introduced above, one for each type in 'ts'.
+data Reconstructor env ts =
+ Reconstructor
+ (Bindings UnfExpr (Tape ts : env) (Reverse (TapeUnfoldings ts)))
+ (Bindings Ex (Append (Reverse (TapeUnfoldings ts)) (Tape ts : env)) ts)
+
+ssnoc :: SList f ts -> f t -> SList f (Append ts '[t])
+ssnoc SNil a = SCons a SNil
+ssnoc (SCons t ts) a = SCons t (ssnoc ts a)
+
+sreverse :: SList f ts -> SList f (Reverse ts)
+sreverse SNil = SNil
+sreverse (SCons t ts) = ssnoc (sreverse ts) t
+
+stapeUnfoldings :: SList STy ts -> SList STy (TapeUnfoldings ts)
+stapeUnfoldings SNil = SNil
+stapeUnfoldings (SCons _ ts) = SCons (tapeTy ts) (stapeUnfoldings ts)
+
+-- Puts a 'snd' at the top of an unfolder stack and grows the context variable by one.
+shiftUnfolder
+ :: STy t
+ -> SList STy ts
+ -> Bindings UnfExpr (Tape ts : env) list
+ -> Bindings UnfExpr (Tape (t : ts) : env) (Append list '[Tape ts])
+shiftUnfolder newTy ts BTop = BPush BTop (tapeTy ts, UnfExSnd newTy (tapeTy ts))
+shiftUnfolder newTy ts (BPush b (t, UnfExSnd itemTy _)) =
+ -- Recurse on 'b', and retype the 'snd'. We need to unfold 'b' once in order
+ -- to expand an 'Append' in the types so that things simplify just enough.
+ -- We have an equality 'Append binds x1 ~ a : x2', where 'binds' is the list
+ -- of bindings produced by 'b'. We want to conclude from this that
+ -- 'binds ~ a : x3' for some 'x3', but GHC will only do that once we know
+ -- that 'binds ~ y : ys' so that the 'Append' can expand one step, after
+ -- which 'y ~ a' as desired. The 'case' unfolds 'b' one step.
+ BPush (shiftUnfolder newTy ts b) (t, case b of BTop -> UnfExSnd itemTy t
+ BPush{} -> UnfExSnd itemTy t)
+
+growRecon :: forall env t ts. STy t -> SList STy ts -> Reconstructor env ts -> Reconstructor env (t : ts)
+growRecon t ts (Reconstructor unfbs bs)
+ | Refl <- lemAppendNil @(Append (Reverse (TapeUnfoldings ts)) '[Tape ts])
+ , Refl <- lemAppendAssoc @ts @(Append (Reverse (TapeUnfoldings ts)) '[Tape ts]) @(Tape (t : ts) : env)
+ , Refl <- lemAppendAssoc @(Reverse (TapeUnfoldings ts)) @'[Tape ts] @env
+ = Reconstructor
+ (shiftUnfolder t ts unfbs)
+ -- Add a 'fst' at the bottom of the builder stack.
+ -- First we have to weaken most of 'bs' to skip one more binding in the
+ -- unfolder stack above it.
+ (BPush (fst (weakenBindingsE
+ (wCopies (sappend (sreverse (stapeUnfoldings ts)) (SCons (tapeTy ts) SNil))
+ (WSink :: env :> (Tape (t : ts) : env))) bs))
+ (t
+ ,EFst ext $ EVar ext (tapeTy (SCons t ts)) $
+ wSinks @(Tape (t : ts) : env)
+ (sappend ts
+ (sappend (sappend (sreverse (stapeUnfoldings ts))
+ (SCons (tapeTy ts) SNil))
+ SNil))
+ @> IZ))
+
+buildReconstructor :: SList STy ts -> Reconstructor env ts
+buildReconstructor SNil = Reconstructor BTop BTop
+buildReconstructor (SCons t ts) = growRecon t ts (buildReconstructor ts)
+
+-- STRATEGY FOR reconstructBindings
+--
+-- binds = []
+-- e : ()
+--
+-- binds = [c]
+-- e : (c, ())
+-- x0 = snd x1 : ()
+-- y1 = fst e : c
+--
+-- binds = [b, c]
+-- e : (b, (c, ()))
+-- x1 = snd e : (c, ())
+-- x0 = snd x1 : ()
+-- y1 = fst x1 : c
+-- y2 = fst x2 : b
+--
+-- binds = [a, b, c]
+-- e : (a, (b, (c, ())))
+-- x2 = snd e : (b, (c, ()))
+-- x1 = snd x2 : (c, ())
+-- x0 = snd x1 : ()
+-- y1 = fst x1 : c
+-- y2 = fst x2 : b
+-- y3 = fst x3 : a
+
+-- Given that in 'env' we can find a 'Tape binds', i.e. a tuple containing all
+-- the things in the list 'binds', we want to create a let stack that extracts
+-- all values from that tuple and in effect "restores" the environment
+-- described by 'binds'. The idea is that elsewhere, we took a slice of the
+-- environment and saved it all in a tuple to be restored later. We
+-- incidentally also add a bunch of additional bindings, namely 'Reverse
+-- (TapeUnfoldings binds)', so the calling code just has to skip those in
+-- whatever it wants to do.
+reconstructBindings :: SList STy binds
+ -> (forall env. Idx env (Tape binds) -> Bindings Ex env (Append binds (Reverse (TapeUnfoldings binds)))
+ ,SList STy (Reverse (TapeUnfoldings binds)))
+reconstructBindings binds =
+ (\tape -> let Reconstructor unf build = buildReconstructor binds
+ in fst $ weakenBindingsE (WIdx tape)
+ (bconcat (mapBindings fromUnfExpr unf) build)
+ ,sreverse (stapeUnfoldings binds))
+
+
+---------------------------------- DERIVATIVES ---------------------------------
+
+d1op :: SOp a t -> Ex env (D1 a) -> Ex env (D1 t)
+d1op (OAdd t) e = EOp ext (OAdd t) e
+d1op (OMul t) e = EOp ext (OMul t) e
+d1op (ONeg t) e = EOp ext (ONeg t) e
+d1op (OLt t) e = EOp ext (OLt t) e
+d1op (OLe t) e = EOp ext (OLe t) e
+d1op (OEq t) e = EOp ext (OEq t) e
+d1op ONot e = EOp ext ONot e
+d1op OAnd e = EOp ext OAnd e
+d1op OOr e = EOp ext OOr e
+d1op OIf e = EOp ext OIf e
+d1op ORound64 e = EOp ext ORound64 e
+d1op OToFl64 e = EOp ext OToFl64 e
+d1op (ORecip t) e = EOp ext (ORecip t) e
+d1op (OExp t) e = EOp ext (OExp t) e
+d1op (OLog t) e = EOp ext (OLog t) e
+d1op (OIDiv t) e = EOp ext (OIDiv t) e
+d1op (OMod t) e = EOp ext (OMod t) e
+
+-- | Both primal and dual must be duplicable expressions
+data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a))
+ | Nonlinear (forall env. Ex env (D1 a) -> Ex env (D2 t) -> Ex env (D2 a))
+
+d2op :: SOp a t -> D2Op a t
+d2op op = case op of
+ OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EPair ext d d
+ OMul t -> d2opBinArrangeInt t $ Nonlinear $ \e d ->
+ EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d))
+ (EOp ext (OMul t) (EPair ext (EFst ext e) d))
+ ONeg t -> d2opUnArrangeInt t $ Linear $ \d -> EOp ext (ONeg t) d
+ OLt t -> Linear $ \_ -> pairZero t
+ OLe t -> Linear $ \_ -> pairZero t
+ OEq t -> Linear $ \_ -> pairZero t
+ ONot -> Linear $ \_ -> ENil ext
+ OAnd -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
+ OOr -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
+ OIf -> Linear $ \_ -> ENil ext
+ ORound64 -> Linear $ \_ -> EZero ext (SMTScal STF64) (ENil ext)
+ OToFl64 -> Linear $ \_ -> ENil ext
+ ORecip t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ONeg t) (EOp ext (ORecip t) (EOp ext (OMul t) (EPair ext e e)))) d)
+ OExp t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (OExp t) e) d)
+ OLog t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ORecip t) e) d)
+ OIDiv t -> integralD2 t $ Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
+ OMod t -> integralD2 t $ Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
+ where
+ pairZero :: SScalTy a -> Ex env (D2 (TPair (TScal a) (TScal a)))
+ pairZero t = ziNil t $ EPair ext (EZero ext (d2M (STScal t)) (ENil ext))
+ (EZero ext (d2M (STScal t)) (ENil ext))
+ where
+ ziNil :: SScalTy a -> (ZeroInfo (D2s a) ~ TNil => r) -> r
+ ziNil STI32 k = k
+ ziNil STI64 k = k
+ ziNil STF32 k = k
+ ziNil STF64 k = k
+ ziNil STBool k = k
+
+ d2opUnArrangeInt :: SScalTy a
+ -> (D2s a ~ TScal a => D2Op (TScal a) t)
+ -> D2Op (TScal a) t
+ d2opUnArrangeInt ty float = case ty of
+ STI32 -> Linear $ \_ -> ENil ext
+ STI64 -> Linear $ \_ -> ENil ext
+ STF32 -> float
+ STF64 -> float
+ STBool -> Linear $ \_ -> ENil ext
+
+ d2opBinArrangeInt :: SScalTy a
+ -> (D2s a ~ TScal a => D2Op (TPair (TScal a) (TScal a)) t)
+ -> D2Op (TPair (TScal a) (TScal a)) t
+ d2opBinArrangeInt ty float = case ty of
+ STI32 -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
+ STI64 -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
+ STF32 -> float
+ STF64 -> float
+ STBool -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
+
+ floatingD2 :: ScalIsFloating a ~ True
+ => SScalTy a -> ((D2s a ~ TScal a, ScalIsNumeric a ~ True) => r) -> r
+ floatingD2 STF32 k = k
+ floatingD2 STF64 k = k
+
+ integralD2 :: ScalIsIntegral a ~ True
+ => SScalTy a -> ((D2s a ~ TNil, ScalIsNumeric a ~ True) => r) -> r
+ integralD2 STI32 k = k
+ integralD2 STI64 k = k
+
+desD1E :: Descr env sto -> SList STy (D1E env)
+desD1E = d1e . descrList
+
+-- d1W :: env :> env' -> D1E env :> D1E env'
+-- d1W WId = WId
+-- d1W WSink = WSink
+-- d1W (WCopy w) = WCopy (d1W w)
+-- d1W (WPop w) = WPop (d1W w)
+-- d1W (WThen u w) = WThen (d1W u) (d1W w)
+
+conv1Idx :: Idx env t -> Idx (D1E env) (D1 t)
+conv1Idx IZ = IZ
+conv1Idx (IS i) = IS (conv1Idx i)
+
+data Idx2 env sto t
+ = Idx2Ac (Idx (D2AcE (Select env sto "accum")) (TAccum (D2 t)))
+ | Idx2Me (Idx (D2E (Select env sto "merge")) (D2 t))
+ | Idx2Di (Idx (Select env sto "discr") t)
+
+conv2Idx :: Descr env sto -> Idx env t -> Idx2 env sto t
+conv2Idx (DPush _ (_, _, SAccum)) IZ = Idx2Ac IZ
+conv2Idx (DPush _ (_, _, SMerge)) IZ = Idx2Me IZ
+conv2Idx (DPush _ (_, _, SDiscr)) IZ = Idx2Di IZ
+conv2Idx (DPush des (_, _, SAccum)) (IS i) =
+ case conv2Idx des i of Idx2Ac j -> Idx2Ac (IS j)
+ Idx2Me j -> Idx2Me j
+ Idx2Di j -> Idx2Di j
+conv2Idx (DPush des (_, _, SMerge)) (IS i) =
+ case conv2Idx des i of Idx2Ac j -> Idx2Ac j
+ Idx2Me j -> Idx2Me (IS j)
+ Idx2Di j -> Idx2Di j
+conv2Idx (DPush des (_, _, SDiscr)) (IS i) =
+ case conv2Idx des i of Idx2Ac j -> Idx2Ac j
+ Idx2Me j -> Idx2Me j
+ Idx2Di j -> Idx2Di (IS j)
+conv2Idx DTop i = case i of {}
+
+opt2UnSparse :: SOp a b -> Sparse (D2 b) b' -> Ex env b' -> Ex env (D2 b)
+opt2UnSparse = go . opt2
+ where
+ go :: STy b -> Sparse (D2 b) b' -> Ex env b' -> Ex env (D2 b)
+ go (STScal STI32) SpAbsent = \_ -> ENil ext
+ go (STScal STI64) SpAbsent = \_ -> ENil ext
+ go (STScal STF32) SpAbsent = \_ -> EZero ext (SMTScal STF32) (ENil ext)
+ go (STScal STF64) SpAbsent = \_ -> EZero ext (SMTScal STF64) (ENil ext)
+ go (STScal STBool) SpAbsent = \_ -> ENil ext
+ go (STScal STF32) SpScal = id
+ go (STScal STF64) SpScal = id
+ go STNil _ = \_ -> ENil ext
+ go (STPair t1 t2) (SpPair s1 s2) = \e -> eunPair e $ \_ e1 e2 -> EPair ext (go t1 s1 e1) (go t2 s2 e2)
+ go t _ = error $ "Primitive operations that return " ++ show t ++ " are scary"
+
+
+----------------------------------- SPARSITY -----------------------------------
+
+expandSparse :: STy a -> Sparse (D2 a) b -> Ex env (D1 a) -> Ex env b -> Ex env (D2 a)
+expandSparse t sp _ e | Just Refl <- isDense (d2M t) sp = e
+expandSparse t (SpSparse sp) epr e =
+ EMaybe ext
+ (EZero ext (d2M t) (d2zeroInfo t epr))
+ (expandSparse t sp (weakenExpr WSink epr) (EVar ext (applySparse sp (d2 t)) IZ))
+ e
+expandSparse t SpAbsent epr _ = EZero ext (d2M t) (d2zeroInfo t epr)
+expandSparse (STPair t1 t2) (SpPair s1 s2) epr e =
+ eunPair epr $ \w1 epr1 epr2 ->
+ eunPair (weakenExpr w1 e) $ \w2 e1 e2 ->
+ EPair ext (expandSparse t1 s1 (weakenExpr w2 epr1) e1)
+ (expandSparse t2 s2 (weakenExpr w2 epr2) e2)
+expandSparse (STEither t1 t2) (SpLEither s1 s2) epr e =
+ ELCase ext e
+ (EZero ext (d2M (STEither t1 t2)) (ENil ext))
+ (ECase ext (weakenExpr WSink epr)
+ (ELInl ext (d2 t2) (expandSparse t1 s1 (EVar ext (d1 t1) IZ) (EVar ext (applySparse s1 (d2 t1)) (IS IZ))))
+ (EError ext (d2 (STEither t1 t2)) "expspa r<-dl"))
+ (ECase ext (weakenExpr WSink epr)
+ (EError ext (d2 (STEither t1 t2)) "expspa l<-dr")
+ (ELInr ext (d2 t1) (expandSparse t2 s2 (EVar ext (d1 t2) IZ) (EVar ext (applySparse s2 (d2 t2)) (IS IZ)))))
+expandSparse (STLEither t1 t2) (SpLEither s1 s2) epr e =
+ ELCase ext e
+ (EZero ext (d2M (STEither t1 t2)) (ENil ext))
+ (ELCase ext (weakenExpr WSink epr)
+ (EError ext (d2 (STEither t1 t2)) "expspa ln<-dl")
+ (ELInl ext (d2 t2) (expandSparse t1 s1 (EVar ext (d1 t1) IZ) (EVar ext (applySparse s1 (d2 t1)) (IS IZ))))
+ (EError ext (d2 (STEither t1 t2)) "expspa lr<-dl"))
+ (ELCase ext (weakenExpr WSink epr)
+ (EError ext (d2 (STEither t1 t2)) "expspa ln<-dr")
+ (EError ext (d2 (STEither t1 t2)) "expspa ll<-dr")
+ (ELInr ext (d2 t1) (expandSparse t2 s2 (EVar ext (d1 t2) IZ) (EVar ext (applySparse s2 (d2 t2)) (IS IZ)))))
+expandSparse (STMaybe t) (SpMaybe s) epr e =
+ EMaybe ext
+ (ENothing ext (d2 t))
+ (let epr' = EMaybe ext (EError ext (d1 t) "expspa n<-dj") (EVar ext (d1 t) IZ) epr
+ in EJust ext (expandSparse t s (weakenExpr WSink epr') (EVar ext (applySparse s (d2 t)) IZ)))
+ e
+expandSparse (STArr _ t) (SpArr s) epr e =
+ ezipWith (expandSparse t s (EVar ext (d1 t) (IS IZ)) (EVar ext (applySparse s (d2 t)) IZ)) epr e
+expandSparse (STScal STF32) SpScal _ e = e
+expandSparse (STScal STF64) SpScal _ e = e
+expandSparse (STAccum{}) _ _ _ = error "accumulators not allowed in source program"
+
+subenvPlus :: SBool req1 -> SBool req2
+ -> SList SMTy env
+ -> SubenvS env env1 -> SubenvS env env2
+ -> (forall env3. SubenvS env env3
+ -> Injection req1 (Tup env1) (Tup env3)
+ -> Injection req2 (Tup env2) (Tup env3)
+ -> (forall e. Ex e (Tup env1) -> Ex e (Tup env2) -> Ex e (Tup env3))
+ -> r)
+ -> r
+-- don't destroy effects!
+subenvPlus _ _ SNil SETop SETop k = k SETop (Inj id) (Inj id) (\a b -> use a $ use b $ ENil ext)
+
+subenvPlus req1 req2 (SCons _ env) (SENo sub1) (SENo sub2) k =
+ subenvPlus req1 req2 env sub1 sub2 $ \sub3 s31 s32 pl ->
+ k (SENo sub3) s31 s32 pl
+
+subenvPlus req1 SF (SCons _ env) (SEYes sp1 sub1) (SENo sub2) k =
+ subenvPlus req1 SF env sub1 sub2 $ \sub3 minj13 _ pl ->
+ k (SEYes sp1 sub3)
+ (withInj minj13 $ \inj13 ->
+ \e1 -> eunPair e1 $ \_ e1a e1b ->
+ EPair ext (inj13 e1a) e1b)
+ Noinj
+ (\e1 e2 ->
+ ELet ext e1 $
+ EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ))
+ (weakenExpr WSink e2))
+ (ESnd ext (EVar ext (typeOf e1) IZ)))
+subenvPlus req1 ST (SCons t env) (SEYes sp1 sub1) (SENo sub2) k
+ | Just zero1 <- cheapZero (applySparse sp1 t) =
+ subenvPlus req1 ST env sub1 sub2 $ \sub3 minj13 (Inj inj23) pl ->
+ k (SEYes sp1 sub3)
+ (withInj minj13 $ \inj13 ->
+ \e1 -> eunPair e1 $ \_ e1a e1b ->
+ EPair ext (inj13 e1a) e1b)
+ (Inj $ \e2 -> EPair ext (inj23 e2) zero1)
+ (\e1 e2 ->
+ ELet ext e1 $
+ EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ))
+ (weakenExpr WSink e2))
+ (ESnd ext (EVar ext (typeOf e1) IZ)))
+ | otherwise =
+ subenvPlus req1 ST env sub1 sub2 $ \sub3 minj13 (Inj inj23) pl ->
+ k (SEYes (SpSparse sp1) sub3)
+ (withInj minj13 $ \inj13 ->
+ \e1 -> eunPair e1 $ \_ e1a e1b ->
+ EPair ext (inj13 e1a) (EJust ext e1b))
+ (Inj $ \e2 -> EPair ext (inj23 e2) (ENothing ext (applySparse sp1 (fromSMTy t))))
+ (\e1 e2 ->
+ ELet ext e1 $
+ EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ))
+ (weakenExpr WSink e2))
+ (EJust ext (ESnd ext (EVar ext (typeOf e1) IZ))))
+
+subenvPlus req1 req2 (SCons t env) sub1@SENo{} sub2@SEYes{} k =
+ subenvPlus req2 req1 (SCons t env) sub2 sub1 $ \sub3 minj23 minj13 pl ->
+ k sub3 minj13 minj23 (flip pl)
+
+subenvPlus req1 req2 (SCons t env) (SEYes sp1 sub1) (SEYes sp2 sub2) k =
+ subenvPlus req1 req2 env sub1 sub2 $ \sub3 minj13 minj23 pl ->
+ sparsePlusS req1 req2 t sp1 sp2 $ \sp3 mTinj13 mTinj23 plus ->
+ k (SEYes sp3 sub3)
+ (withInj2 minj13 mTinj13 $ \inj13 tinj13 ->
+ \e1 -> eunPair e1 $ \_ e1a e1b ->
+ EPair ext (inj13 e1a) (tinj13 e1b))
+ (withInj2 minj23 mTinj23 $ \inj23 tinj23 ->
+ \e2 -> eunPair e2 $ \_ e2a e2b ->
+ EPair ext (inj23 e2a) (tinj23 e2b))
+ (\e1 e2 ->
+ ELet ext e1 $
+ ELet ext (weakenExpr WSink e2) $
+ EPair ext (pl (EFst ext (EVar ext (typeOf e1) (IS IZ)))
+ (EFst ext (EVar ext (typeOf e2) IZ)))
+ (plus
+ (ESnd ext (EVar ext (typeOf e1) (IS IZ)))
+ (ESnd ext (EVar ext (typeOf e2) IZ))))
+
+expandSubenvZeros :: D1E env0 :> env -> SList STy env0 -> SubenvS (D2E env0) contribs
+ -> Ex env (Tup contribs) -> Ex env (Tup (D2E env0))
+expandSubenvZeros _ SNil SETop _ = ENil ext
+expandSubenvZeros w (SCons t ts) (SEYes sp sub) e =
+ eunPair e $ \w1 e1 e2 ->
+ EPair ext
+ (expandSubenvZeros (w1 .> WPop w) ts sub e1)
+ (expandSparse t sp (EVar ext (d1 t) (w1 .> w @> IZ)) e2)
+expandSubenvZeros w (SCons t ts) (SENo sub) e =
+ EPair ext
+ (expandSubenvZeros (WPop w) ts sub e)
+ (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ))))
+
+
+--------------------------------- ACCUMULATORS ---------------------------------
+
+fromArrayValId :: Maybe (ValId t) -> Maybe Int
+fromArrayValId (Just (VIArr i _)) = Just i
+fromArrayValId _ = Nothing
+
+accumPromote :: forall dt env sto proxy r.
+ proxy dt
+ -> Descr env sto
+ -> (forall stoRepl envPro.
+ (Select env stoRepl "merge" ~ '[])
+ => Descr env stoRepl
+ -- ^ A revised environment description that switches
+ -- arrays (used in the OccEnv) that are currently on
+ -- "merge" storage, to "accum" storage.
+ -> SList STy envPro
+ -- ^ New entries on top of the original dual environment,
+ -- that house the accumulators for the promoted arrays in
+ -- the original environment.
+ -> Subenv (Select env sto "merge") envPro
+ -- ^ The promoted entries were merge entries in the
+ -- original environment.
+ -> Subenv (D2AcE (Select env stoRepl "accum")) (D2AcE (Select env sto "accum"))
+ -- ^ All entries that were accumulators are still
+ -- accumulators.
+ -> VarMap Int (D2AcE (Select env stoRepl "accum"))
+ -- ^ Accumulator map for _only_ the the newly allocated
+ -- accumulators.
+ -> (forall shbinds.
+ SList STy shbinds
+ -> (dt : Append shbinds (D2AcE (Select env stoRepl "accum")))
+ :> Append (D2AcE envPro) (dt : Append shbinds (D2AcE (Select env sto "accum"))))
+ -- ^ A weakening that converts a computation in the
+ -- revised environment to one in the original environment
+ -- extended with some accumulators.
+ -> r)
+ -> r
+accumPromote _ DTop k = k DTop SNil SETop SETop VarMap.empty (\_ -> WId)
+accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of
+ -- Accumulators are left as-is
+ SAccum ->
+ accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf ->
+ k (storepl `DPush` (t, vid, SAccum))
+ envpro
+ prosub
+ (SEYesR accrevsub)
+ (VarMap.sink1 accumMap)
+ (\shbinds ->
+ autoWeak (#pro (d2ace envpro) &. #d (auto1 @dt) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum descr)))
+ (#acc :++: (#pro :++: #d :++: #shb :++: #tl))
+ (#pro :++: #d :++: #shb :++: #acc :++: #tl)
+ .> WCopy (wf shbinds)
+ .> autoWeak (#d (auto1 @dt) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum storepl)))
+ (#d :++: #shb :++: #acc :++: #tl)
+ (#acc :++: (#d :++: #shb :++: #tl)))
+
+ SMerge -> case t of
+ -- Discrete values are left as-is
+ _ | isDiscrete t ->
+ accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap' wf ->
+ k (storepl `DPush` (t, vid, SDiscr))
+ envpro
+ (SENo prosub)
+ accrevsub
+ accumMap'
+ wf
+
+ -- Values with "merge" storage are promoted to an accumulator in envPro
+ _ ->
+ accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf ->
+ k (storepl `DPush` (t, vid, SAccum))
+ (t `SCons` envpro)
+ (SEYesR prosub)
+ (SENo accrevsub)
+ (let accumMap' = VarMap.sink1 accumMap
+ in case fromArrayValId vid of
+ Just i -> VarMap.insert i (STAccum (d2M t)) IZ accumMap'
+ Nothing -> accumMap')
+ (\(shbinds :: SList _ shbinds) ->
+ let shbindsC = slistMap (\_ -> Const ()) shbinds
+ in
+ -- wf:
+ -- D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum")))
+ -- WCopy wf:
+ -- TAccum n t3 : D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum")))
+ -- WPICK: ^ THESE TWO ||
+ -- goal: | ARE EQUAL ||
+ -- D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum")))
+ WCopy (wf shbinds)
+ .> WPick @(TAccum (D2 t)) @(dt : shbinds) (Const () `SCons` shbindsC)
+ (WId @(D2AcE (Select env1 stoRepl "accum"))))
+
+ -- Discrete values are left as-is, nothing to do
+ SDiscr ->
+ accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf ->
+ k (storepl `DPush` (t, vid, SDiscr))
+ envpro
+ prosub
+ accrevsub
+ accumMap
+ wf
+ where
+ isDiscrete :: STy t' -> Bool
+ isDiscrete = \case
+ STNil -> True
+ STPair a b -> isDiscrete a && isDiscrete b
+ STEither a b -> isDiscrete a && isDiscrete b
+ STLEither a b -> isDiscrete a && isDiscrete b
+ STMaybe a -> isDiscrete a
+ STArr _ a -> isDiscrete a
+ STScal st -> case st of
+ STI32 -> True
+ STI64 -> True
+ STF32 -> False
+ STF64 -> False
+ STBool -> True
+ STAccum{} -> False
+
+
+---------------------------- RETURN TRIPLE FROM CHAD ---------------------------
+
+data Ret env0 sto sd t =
+ forall shbinds tapebinds contribs.
+ Ret (Bindings Ex (D1E env0) shbinds) -- shared binds
+ (Subenv shbinds tapebinds)
+ (Ex (Append shbinds (D1E env0)) (D1 t))
+ (SubenvS (D2E (Select env0 sto "merge")) contribs)
+ (Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs))
+deriving instance Show (Ret env0 sto sd t)
+
+type data TyTyPair = MkTyTyPair Ty Ty
+
+data SingleRet env0 sto (pair :: TyTyPair) =
+ forall shbinds tapebinds.
+ SingleRet
+ (Bindings Ex (D1E env0) shbinds) -- shared binds
+ (Subenv shbinds tapebinds)
+ (RetPair env0 sto (D1E env0) shbinds tapebinds pair)
+
+-- pattern Ret1 :: forall env0 sto Bindings Ex (D1E env0) shbinds
+-- -> Subenv shbinds tapebinds
+-- -> Ex (Append shbinds (D1E env0)) (D1 t)
+-- -> SubenvS (D2E (Select env0 sto "merge")) contribs
+-- -> Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs)
+-- -> SingleRet env0 sto (MkTyTyPair sd t)
+-- pattern Ret1 e0 subtape e1 sub e2 = SingleRet e0 subtape (RetPair e1 sub e2)
+-- {-# COMPLETE Ret1 #-}
+
+data RetPair env0 sto env shbinds tapebinds (pair :: TyTyPair) where
+ RetPair :: forall sd t contribs -- existentials
+ env0 sto env shbinds tapebinds. -- universals
+ Ex (Append shbinds env) (D1 t)
+ -> SubenvS (D2E (Select env0 sto "merge")) contribs
+ -> Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs)
+ -> RetPair env0 sto env shbinds tapebinds (MkTyTyPair sd t)
+deriving instance Show (RetPair env0 sto env shbinds tapebinds pair)
+
+data Rets env0 sto env list =
+ forall shbinds tapebinds.
+ Rets (Bindings Ex env shbinds)
+ (Subenv shbinds tapebinds)
+ (SList (RetPair env0 sto env shbinds tapebinds) list)
+deriving instance Show (Rets env0 sto env list)
+
+toSingleRet :: Ret env0 sto sd t -> SingleRet env0 sto (MkTyTyPair sd t)
+toSingleRet (Ret e0 subtape e1 sub e2) = SingleRet e0 subtape (RetPair e1 sub e2)
+
+weakenRetPair :: SList STy shbinds -> env :> env'
+ -> RetPair env0 sto env shbinds tapebinds pair -> RetPair env0 sto env' shbinds tapebinds pair
+weakenRetPair bindslist w (RetPair e1 sub e2) = RetPair (weakenExpr (weakenOver bindslist w) e1) sub e2
+
+weakenRets :: env :> env' -> Rets env0 sto env list -> Rets env0 sto env' list
+weakenRets w (Rets binds tapesub list) =
+ let (binds', _) = weakenBindingsE w binds
+ in Rets binds' tapesub (slistMap (weakenRetPair (bindingsBinds binds) w) list)
+
+rebaseRetPair :: forall env b1 b2 tapebinds1 tapebinds2 env0 sto pair f.
+ Descr env0 sto
+ -> SList f b1 -> SList f b2
+ -> Subenv b1 tapebinds1 -> Subenv b2 tapebinds2
+ -> RetPair env0 sto (Append b1 env) b2 tapebinds2 pair
+ -> RetPair env0 sto env (Append b2 b1) (Append tapebinds2 tapebinds1) pair
+rebaseRetPair descr b1 b2 subtape1 subtape2 (RetPair @sd e1 sub e2)
+ | Refl <- lemAppendAssoc @b2 @b1 @env =
+ RetPair e1 sub
+ (weakenExpr (autoWeak
+ (#d (auto1 @sd)
+ &. #t2 (subList b2 subtape2)
+ &. #t1 (subList b1 subtape1)
+ &. #tl (d2ace (select SAccum descr)))
+ (#d :++: (#t2 :++: #tl))
+ (#d :++: ((#t2 :++: #t1) :++: #tl)))
+ e2)
+
+retConcat :: forall env0 sto list. Descr env0 sto -> SList (SingleRet env0 sto) list -> Rets env0 sto (D1E env0) list
+retConcat _ SNil = Rets BTop SETop SNil
+retConcat descr (SCons (SingleRet (e0 :: Bindings _ _ shbinds1) (subtape :: Subenv _ tapebinds1) (RetPair e1 sub e2)) list)
+ | Rets (binds :: Bindings _ _ shbinds2) (subtape2 :: Subenv _ tapebinds2) pairs
+ <- weakenRets (sinkWithBindings e0) (retConcat descr list)
+ , Refl <- lemAppendAssoc @shbinds2 @shbinds1 @(D1E env0)
+ , Refl <- lemAppendAssoc @tapebinds2 @tapebinds1 @(D2AcE (Select env0 sto "accum"))
+ = Rets (bconcat e0 binds)
+ (subenvConcat subtape subtape2)
+ (SCons (RetPair (weakenExpr (sinkWithBindings binds) e1)
+ sub
+ (weakenExpr (WCopy (sinkWithSubenv subtape2)) e2))
+ (slistMap (rebaseRetPair descr (bindingsBinds e0) (bindingsBinds binds)
+ subtape subtape2)
+ pairs))
+
+freezeRet :: Descr env sto
+ -> Ret env sto (D2 t) t
+ -> Ex (D2 t : Append (D2AcE (Select env sto "accum")) (D1E env)) (TPair (D1 t) (Tup (D2E (Select env sto "merge"))))
+freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ _ t) =
+ let (e0', wInsertD2Ac) = weakenBindingsE (WSink .> wSinks (d2ace (select SAccum descr))) e0
+ e2' = weakenExpr (WCopy (wCopies (subList (bindingsBinds e0) subtape) (wRaiseAbove (d2ace (select SAccum descr)) (desD1E descr)))) e2
+ tContribs = tTup (slistMap fromSMTy (subList (d2eM (select SMerge descr)) sub))
+ library = #d (auto1 @(D2 t))
+ &. #tape (subList (bindingsBinds e0) subtape)
+ &. #shbinds (bindingsBinds e0)
+ &. #d2ace (d2ace (select SAccum descr))
+ &. #tl (desD1E descr)
+ &. #contribs (SCons tContribs SNil)
+ in letBinds e0' $
+ EPair ext
+ (weakenExpr wInsertD2Ac e1)
+ (ELet ext (weakenExpr (autoWeak library
+ (#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: #d2ace :++: #tl)
+ (#shbinds :++: #d :++: #d2ace :++: #tl))
+ e2') $
+ expandSubenvZeros
+ (autoWeak library #tl (#contribs :++: #shbinds :++: #d :++: #d2ace :++: #tl)
+ .> wUndoSubenv (subenvD1E (selectSub SMerge descr)))
+ (select SMerge descr) sub (EVar ext tContribs IZ))
+
+
+---------------------------- THE CHAD TRANSFORMATION ---------------------------
+
+drev :: forall env sto sd t.
+ (?config :: CHADConfig)
+ => Descr env sto -> VarMap Int (D2AcE (Select env sto "accum"))
+ -> Sparse (D2 t) sd
+ -> Expr ValId env t -> Ret env sto sd t
+drev des _ sd | isAbsent sd =
+ \e ->
+ Ret BTop
+ SETop
+ (drevPrimal des e)
+ (subenvNone (d2e (select SMerge des)))
+ (ENil ext)
+drev _ _ SpAbsent = error "Absent should be isAbsent"
+
+drev des accumMap (SpSparse sd) =
+ \e ->
+ case drev des accumMap sd e of { Ret e0 subtape e1 sub e2 ->
+ subenvPlus ST ST (d2eM (select SMerge des)) sub (subenvNone (d2e (select SMerge des))) $ \sub' (Inj inj1) (Inj inj2) _ ->
+ Ret e0
+ subtape
+ e1
+ sub'
+ (emaybe (EVar ext (STMaybe (applySparse sd (d2 (typeOf e)))) IZ)
+ (inj2 (ENil ext))
+ (inj1 (weakenExpr (WCopy WSink) e2)))
+ }
+
+drev des accumMap sd = \case
+ EVar _ t i ->
+ case conv2Idx des i of
+ Idx2Ac accI ->
+ Ret BTop
+ SETop
+ (EVar ext (d1 t) (conv1Idx i))
+ (subenvNone (d2e (select SMerge des)))
+ (let ty = applySparse sd (d2M t)
+ in EAccum ext (d2M t) SAPHere (ENil ext) sd (EVar ext (fromSMTy ty) IZ) (EVar ext (STAccum (d2M t)) (IS accI)))
+
+ Idx2Me tupI ->
+ Ret BTop
+ SETop
+ (EVar ext (d1 t) (conv1Idx i))
+ (subenvOnehot (d2e (select SMerge des)) tupI sd)
+ (EPair ext (ENil ext) (EVar ext (applySparse sd (d2 t)) IZ))
+
+ Idx2Di _ ->
+ Ret BTop
+ SETop
+ (EVar ext (d1 t) (conv1Idx i))
+ (subenvNone (d2e (select SMerge des)))
+ (ENil ext)
+
+ ELet _ (rhs :: Expr _ _ a) body
+ | ChosenStorage (storage :: Storage s) <- if chcLetArrayAccum ?config && typeHasArrays (typeOf rhs) then ChosenStorage SAccum else ChosenStorage SMerge
+ , RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody sdBody body2 <- drevScoped des accumMap (typeOf rhs) storage (Just (extOf rhs)) sd body
+ , Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) rhs1 subRHS rhs2 <- drev des accumMap sdBody rhs
+ , let (body0', wbody0') = weakenBindingsE (WCopy (sinkWithBindings rhs0)) body0
+ , Refl <- lemAppendAssoc @body_shbinds @'[D1 a] @rhs_shbinds
+ , Refl <- lemAppendAssoc @body_shbinds @(D1 a : rhs_shbinds) @(D1E env)
+ , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum"))
+ ->
+ subenvPlus SF SF (d2eM (select SMerge des)) subRHS subBody $ \subBoth _ _ plus_RHS_Body ->
+ let bodyResType = STPair (contribTupTy des subBody) (applySparse sdBody (d2 (typeOf rhs))) in
+ Ret (bconcat (rhs0 `bpush` rhs1) body0')
+ (subenvConcat subtapeRHS subtapeBody)
+ (weakenExpr wbody0' body1)
+ subBoth
+ (ELet ext (weakenExpr (autoWeak (#d (auto1 @sd)
+ &. #body (subList (bindingsBinds body0 `sappend` SCons (d1 (typeOf rhs)) SNil) subtapeBody)
+ &. #rhs (subList (bindingsBinds rhs0) subtapeRHS)
+ &. #tl (d2ace (select SAccum des)))
+ (#d :++: #body :++: #tl)
+ (#d :++: (#body :++: #rhs) :++: #tl))
+ body2) $
+ ELet ext
+ (ELet ext (ESnd ext (EVar ext bodyResType IZ)) $
+ weakenExpr (WCopy (wSinks' @[_,_] .> sinkWithSubenv subtapeBody)) rhs2) $
+ plus_RHS_Body
+ (EVar ext (contribTupTy des subRHS) IZ)
+ (EFst ext (EVar ext bodyResType (IS IZ))))
+
+ EPair _ a b
+ | SpPair sd1 sd2 <- sd
+ , Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil)
+ <- retConcat des $ toSingleRet (drev des accumMap sd1 a) `SCons` toSingleRet (drev des accumMap sd2 b) `SCons` SNil
+ , let dt = STPair (applySparse sd1 (d2 (typeOf a))) (applySparse sd2 (d2 (typeOf b))) ->
+ subenvPlus SF SF (d2eM (select SMerge des)) subA subB $ \subBoth _ _ plus_A_B ->
+ Ret binds
+ subtape
+ (EPair ext a1 b1)
+ subBoth
+ (ELet ext (ELet ext (EFst ext (EVar ext dt IZ))
+ (weakenExpr (WCopy WSink) a2)) $
+ ELet ext (ELet ext (ESnd ext (EVar ext dt (IS IZ)))
+ (weakenExpr (WCopy (WSink .> WSink)) b2)) $
+ plus_A_B
+ (EVar ext (contribTupTy des subA) (IS IZ))
+ (EVar ext (contribTupTy des subB) IZ))
+
+ EFst _ e
+ | Ret e0 subtape e1 sub e2 <- drev des accumMap (SpPair sd SpAbsent) e
+ , STPair t1 _ <- typeOf e ->
+ Ret e0
+ subtape
+ (EFst ext e1)
+ sub
+ (ELet ext (EPair ext (EVar ext (applySparse sd (d2 t1)) IZ) (ENil ext)) $
+ weakenExpr (WCopy WSink) e2)
+
+ ESnd _ e
+ | Ret e0 subtape e1 sub e2 <- drev des accumMap (SpPair SpAbsent sd) e
+ , STPair _ t2 <- typeOf e ->
+ Ret e0
+ subtape
+ (ESnd ext e1)
+ sub
+ (ELet ext (EPair ext (ENil ext) (EVar ext (applySparse sd (d2 t2)) IZ)) $
+ weakenExpr (WCopy WSink) e2)
+
+ -- Don't need to handle ENil, because its cotangent is always absent!
+ -- ENil _ -> Ret BTop SETop (ENil ext) (subenvNone (d2e (select SMerge des))) (ENil ext)
+
+ EInl _ t2 e
+ | SpLEither sd1 sd2 <- sd
+ , Ret e0 subtape e1 sub e2 <- drev des accumMap sd1 e ->
+ subenvPlus ST ST (d2eM (select SMerge des)) sub (subenvNone (d2e (select SMerge des))) $ \sub' (Inj inj1) (Inj inj2) _ ->
+ Ret e0
+ subtape
+ (EInl ext (d1 t2) e1)
+ sub'
+ (ELCase ext
+ (EVar ext (STLEither (applySparse sd1 (d2 (typeOf e))) (applySparse sd2 (d2 t2))) IZ)
+ (inj2 $ ENil ext)
+ (inj1 $ weakenExpr (WCopy WSink) e2)
+ (EError ext (contribTupTy des sub') "inl<-dinr"))
+
+ EInr _ t1 e
+ | SpLEither sd1 sd2 <- sd
+ , Ret e0 subtape e1 sub e2 <- drev des accumMap sd2 e ->
+ subenvPlus ST ST (d2eM (select SMerge des)) sub (subenvNone (d2e (select SMerge des))) $ \sub' (Inj inj1) (Inj inj2) _ ->
+ Ret e0
+ subtape
+ (EInr ext (d1 t1) e1)
+ sub'
+ (ELCase ext
+ (EVar ext (STLEither (applySparse sd1 (d2 t1)) (applySparse sd2 (d2 (typeOf e)))) IZ)
+ (inj2 $ ENil ext)
+ (EError ext (contribTupTy des sub') "inr<-dinl")
+ (inj1 $ weakenExpr (WCopy WSink) e2))
+
+ ECase _ e (a :: Expr _ _ t) b
+ | STEither (t1 :: STy a) (t2 :: STy b) <- typeOf e
+ , ChosenStorage storage1 <- if chcCaseArrayAccum ?config && typeHasArrays t1 then ChosenStorage SAccum else ChosenStorage SMerge
+ , ChosenStorage storage2 <- if chcCaseArrayAccum ?config && typeHasArrays t2 then ChosenStorage SAccum else ChosenStorage SMerge
+ , let (bindids1, bindids2) = validSplitEither (extOf e)
+ , RetScoped (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA sd1 a2
+ <- drevScoped des accumMap t1 storage1 bindids1 sd a
+ , RetScoped (b0 :: Bindings _ _ rhs_b_binds) (subtapeB :: Subenv _ rhs_b_tape) b1 subB sd2 b2
+ <- drevScoped des accumMap t2 storage2 bindids2 sd b
+ , Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 subE e2 <- drev des accumMap (SpLEither sd1 sd2) e
+ , Refl <- lemAppendAssoc @(Append rhs_a_binds (Reverse (TapeUnfoldings rhs_a_binds))) @(Tape rhs_a_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum"))
+ , Refl <- lemAppendAssoc @(Append rhs_b_binds (Reverse (TapeUnfoldings rhs_b_binds))) @(Tape rhs_b_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum"))
+ , let subtapeListA = subList (sappend (bindingsBinds a0) (d1 t1 `SCons` SNil)) subtapeA
+ , let subtapeListB = subList (sappend (bindingsBinds b0) (d1 t2 `SCons` SNil)) subtapeB
+ , let tapeA = tapeTy subtapeListA
+ , let tapeB = tapeTy subtapeListB
+ , let collectA = bindingsCollectTape @_ @_ @(Append rhs_a_binds (D1 a : Append e_binds (D1E env)))
+ (sappend (bindingsBinds a0) (d1 t1 `SCons` SNil)) subtapeA
+ , let collectB = bindingsCollectTape @_ @_ @(Append rhs_b_binds (D1 b : Append e_binds (D1E env)))
+ (sappend (bindingsBinds b0) (d1 t2 `SCons` SNil)) subtapeB
+ , (tPrimal :: STy t_primal_ty) <- STPair (d1 (typeOf a)) (STEither tapeA tapeB)
+ , let (a0', wa0') = weakenBindingsE (WCopy (sinkWithBindings e0)) a0
+ , let (b0', wb0') = weakenBindingsE (WCopy (sinkWithBindings e0)) b0
+ , Refl <- lemAppendNil @(Append rhs_a_binds '[D1 a])
+ , Refl <- lemAppendNil @(Append rhs_b_binds '[D1 b])
+ , Refl <- lemAppendAssoc @rhs_a_binds @'[D1 a] @(D1E env)
+ , Refl <- lemAppendAssoc @rhs_b_binds @'[D1 b] @(D1E env)
+ , let wa0'' = wa0' .> wCopies (sappend (bindingsBinds a0) (d1 t1 `SCons` SNil)) (WClosed @(D1E env))
+ , let wb0'' = wb0' .> wCopies (sappend (bindingsBinds b0) (d1 t2 `SCons` SNil)) (WClosed @(D1E env))
+ ->
+ subenvPlus ST ST (d2eM (select SMerge des)) subA subB $ \subAB (Inj sAB_A) (Inj sAB_B) _ ->
+ subenvPlus SF SF (d2eM (select SMerge des)) subAB subE $ \subOut _ _ plus_AB_E ->
+ Ret (e0 `bpush` ECase ext e1
+ (letBinds a0' (EPair ext (weakenExpr wa0' a1) (EInl ext tapeB (collectA wa0''))))
+ (letBinds b0' (EPair ext (weakenExpr wb0' b1) (EInr ext tapeA (collectB wb0'')))))
+ (SEYesR subtapeE)
+ (EFst ext (EVar ext tPrimal IZ))
+ subOut
+ (elet
+ (ECase ext (ESnd ext (EVar ext tPrimal (IS IZ)))
+ (let (rebinds, prerebinds) = reconstructBindings subtapeListA
+ in letBinds (rebinds IZ) $
+ ELet ext
+ (EVar ext (applySparse sd (d2 (typeOf a))) (wSinks @(Tape rhs_a_tape : sd : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend subtapeListA prerebinds) @> IS IZ)) $
+ elet
+ (weakenExpr (autoWeak (#d (auto1 @sd)
+ &. #ta0 subtapeListA
+ &. #prea0 prerebinds
+ &. #recon (tapeA `SCons` applySparse sd (d2 (typeOf a)) `SCons` SNil)
+ &. #binds (tPrimal `SCons` subList (bindingsBinds e0) subtapeE)
+ &. #tl (d2ace (select SAccum des)))
+ (#d :++: #ta0 :++: #tl)
+ (#d :++: (#ta0 :++: #prea0) :++: #recon :++: #binds :++: #tl))
+ a2) $
+ EPair ext (sAB_A $ EFst ext (evar IZ))
+ (ELInl ext (applySparse sd2 (d2 t2)) (ESnd ext (evar IZ))))
+ (let (rebinds, prerebinds) = reconstructBindings subtapeListB
+ in letBinds (rebinds IZ) $
+ ELet ext
+ (EVar ext (applySparse sd (d2 (typeOf a))) (wSinks @(Tape rhs_b_tape : sd : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend subtapeListB prerebinds) @> IS IZ)) $
+ elet
+ (weakenExpr (autoWeak (#d (auto1 @sd)
+ &. #tb0 subtapeListB
+ &. #preb0 prerebinds
+ &. #recon (tapeB `SCons` applySparse sd (d2 (typeOf a)) `SCons` SNil)
+ &. #binds (tPrimal `SCons` subList (bindingsBinds e0) subtapeE)
+ &. #tl (d2ace (select SAccum des)))
+ (#d :++: #tb0 :++: #tl)
+ (#d :++: (#tb0 :++: #preb0) :++: #recon :++: #binds :++: #tl))
+ b2) $
+ EPair ext (sAB_B $ EFst ext (evar IZ))
+ (ELInr ext (applySparse sd1 (d2 t1)) (ESnd ext (evar IZ))))) $
+ plus_AB_E
+ (EFst ext (evar IZ))
+ (ELet ext (ESnd ext (evar IZ)) $
+ weakenExpr (WCopy (wSinks' @[_,_,_])) e2))
+
+ EConst _ t val ->
+ Ret BTop
+ SETop
+ (EConst ext t val)
+ (subenvNone (d2e (select SMerge des)))
+ (ENil ext)
+
+ EOp _ op e
+ | Ret e0 subtape e1 sub e2 <- drev des accumMap (spDense (d2M (opt1 op))) e ->
+ case d2op op of
+ Linear d2opfun ->
+ Ret e0
+ subtape
+ (d1op op e1)
+ sub
+ (ELet ext (d2opfun (opt2UnSparse op sd (EVar ext (applySparse sd (d2 (opt2 op))) IZ)))
+ (weakenExpr (WCopy WSink) e2))
+ Nonlinear d2opfun ->
+ Ret (e0 `bpush` e1)
+ (SEYesR subtape)
+ (d1op op $ EVar ext (d1 (typeOf e)) IZ)
+ sub
+ (ELet ext (d2opfun (EVar ext (d1 (typeOf e)) (IS IZ))
+ (opt2UnSparse op sd (EVar ext (applySparse sd (d2 (opt2 op))) IZ)))
+ (weakenExpr (WCopy (wSinks' @[_,_])) e2))
+
+ ECustom _ _ tb _ srce pr du a b
+ -- allowed to ignore a2 because 'a' is the part of the input that is inactive
+ | Ret b0 bsubtape b1 bsub b2 <- drev des accumMap (spDense (d2M tb)) b ->
+ case isDense (d2M (typeOf srce)) sd of
+ Just Refl ->
+ Ret (b0 `bpush` weakenExpr (sinkWithBindings b0) (drevPrimal des a)
+ `bpush` weakenExpr WSink b1
+ `bpush` weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr)
+ `bpush` ESnd ext (EVar ext (typeOf pr) IZ))
+ (SEYesR (SENo (SENo (SENo bsubtape))))
+ (EFst ext (EVar ext (typeOf pr) (IS IZ)))
+ bsub
+ (ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $
+ weakenExpr (WCopy (WSink .> WSink)) b2)
+
+ Nothing ->
+ Ret (b0 `bpush` weakenExpr (sinkWithBindings b0) (drevPrimal des a)
+ `bpush` weakenExpr WSink b1
+ `bpush` weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr))
+ (SEYesR (SENo (SENo bsubtape)))
+ (EFst ext (EVar ext (typeOf pr) IZ))
+ bsub
+ (ELet ext (ESnd ext (EVar ext (typeOf pr) (IS IZ))) $ -- tape
+ ELet ext (expandSparse (typeOf srce) sd -- expanded incoming cotangent
+ (EFst ext (EVar ext (typeOf pr) (IS (IS IZ))))
+ (EVar ext (applySparse sd (d2 (typeOf srce))) (IS IZ))) $
+ ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $
+ weakenExpr (WCopy (WSink .> WSink .> WSink .> WSink)) b2)
+
+ ERecompute _ e ->
+ deleteUnused (descrList des) (occCountAll e) $ \usedSub ->
+ let smallE = unsafeWeakenWithSubenv usedSub e in
+ subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed ->
+ case drev usedDes (VarMap.subMap subAccumUsed accumMap) sd smallE of { Ret e0 subtape _ sub e2 ->
+ let subMergeUsed' = subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E subMergeUsed) in
+ Ret (collectBindings (desD1E des) subD1eUsed)
+ (subenvAll (desD1E usedDes))
+ (weakenExpr (wSinks (desD1E usedDes)) $ drevPrimal des e)
+ (subenvCompose subMergeUsed' sub)
+ (letBinds (fst (weakenBindingsE (WSink .> wRaiseAbove (desD1E usedDes) (d2ace (select SAccum des))) e0)) $
+ weakenExpr
+ (autoWeak (#d (auto1 @sd)
+ &. #shbinds (bindingsBinds e0)
+ &. #tape (subList (bindingsBinds e0) subtape)
+ &. #d1env (desD1E usedDes)
+ &. #tl' (d2ace (select SAccum usedDes))
+ &. #tl (d2ace (select SAccum des)))
+ (#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: LPreW #tl' #tl (wUndoSubenv subAccumUsed))
+ (#shbinds :++: #d :++: #d1env :++: #tl))
+ e2)
+ }
+
+ EError _ t s ->
+ Ret BTop
+ SETop
+ (EError ext (d1 t) s)
+ (subenvNone (d2e (select SMerge des)))
+ (ENil ext)
+
+ EConstArr _ n t val ->
+ Ret BTop
+ SETop
+ (EConstArr ext n t val)
+ (subenvNone (d2e (select SMerge des)))
+ (ENil ext)
+
+ EBuild _ (ndim :: SNat ndim) she (ef :: Expr _ _ eltty)
+ | SpArr @_ @sdElt sdElt <- sd
+ , let eltty = typeOf ef
+ , shty :: STy shty <- tTup (sreplicate ndim tIx)
+ , Refl <- indexTupD1Id ndim ->
+ drevLambda des accumMap (shty, SDiscr) sdElt ef $ \(provars :: SList _ envPro) esub proPrimalBinds e0 e1 (e1tape :: Ex _ e_tape) _ wrapAccum e2 ->
+ let library = #ix (shty `SCons` SNil)
+ &. #e0 (bindingsBinds e0)
+ &. #propr (d1e provars)
+ &. #d1env (desD1E des)
+ &. #d (auto1 @sdElt)
+ &. #tape (auto1 @e_tape)
+ &. #pro (d2ace provars)
+ &. #d2acEnv (d2ace (select SAccum des))
+ &. #darr (auto1 @(TArr ndim sdElt))
+ &. #tapearr (auto1 @(TArr ndim e_tape)) in
+ Ret (proPrimalBinds
+ `bpush` weakenExpr (wSinks (d1e provars))
+ (EBuild ext ndim
+ (drevPrimal des she)
+ (letBinds e0 $
+ EPair ext e1 e1tape))
+ `bpush` emap (ESnd ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) (typeOf e1tape))) IZ))
+ (SEYesR (SENo (subenvAll (d1e provars))))
+ (emap (EFst ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) (typeOf e1tape))) (IS IZ)))
+ (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) esub)
+ (let sinkOverEnvPro = wSinks @(sd : TArr ndim e_tape : Append (D1E envPro) (D2AcE (Select env sto "accum"))) (d2ace provars) in
+ ESnd ext $
+ wrapAccum (WSink .> WSink .> wRaiseAbove (d1e provars) (d2ace (select SAccum des))) $
+ EBuild ext ndim (EShape ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (sinkOverEnvPro @> IZ))) $
+ -- the cotangent for this element
+ ELet ext (EIdx ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (WSink .> sinkOverEnvPro @> IZ))
+ (EVar ext shty IZ)) $
+ -- the tape for this element
+ ELet ext (EIdx ext (EVar ext (STArr ndim (typeOf e1tape)) (WSink .> WSink .> sinkOverEnvPro @> IS IZ))
+ (EVar ext shty (IS IZ))) $
+ weakenExpr (autoWeak library (#tape :++: #d :++: #pro :++: #d2acEnv)
+ (#tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #propr :++: #d2acEnv))
+ e2)
+
+ EMap _ ef (earr :: Expr _ _ (TArr n a))
+ | SpArr sdElt <- sd
+ , let STArr ndim t1 = typeOf earr
+ t2 = typeOf ef ->
+ drevLambda des accumMap (t1, SMerge) sdElt ef $ \provars efsub proPrimalBinds ef0 ef1 ef1tape spEf wrapAccum ef2 ->
+ case drev des accumMap (SpArr spEf) earr of { Ret ea0 easubtape ea1 easub ea2 ->
+ let (proPrimalBinds', _) = weakenBindingsE (sinkWithBindings ea0) proPrimalBinds
+ ttape = typeOf ef1tape
+ library = #d1env (desD1E des)
+ &. #a0 (bindingsBinds ea0)
+ &. #atapebinds (subList (bindingsBinds ea0) easubtape)
+ &. #propr (d1e provars)
+ &. #x (d1 t1 `SCons` SNil)
+ &. #parr (STArr ndim (d1 t1) `SCons` SNil)
+ &. #tapearr (STArr ndim ttape `SCons` SNil)
+ &. #darr (STArr ndim (applySparse sdElt (d2 t2)) `SCons` SNil)
+ &. #dy (applySparse sdElt (d2 t2) `SCons` SNil)
+ &. #tape (ttape `SCons` SNil)
+ &. #dytape (STPair (applySparse sdElt (d2 t2)) ttape `SCons` SNil)
+ &. #d2acEnv (d2ace (select SAccum des))
+ &. #pro (d2ace provars)
+ in
+ subenvPlus SF SF (d2eM (select SMerge des)) (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) efsub) easub $ \subfa _ _ plus_f_a ->
+ Ret (bconcat ea0 proPrimalBinds'
+ `bpush` weakenExpr (autoWeak library (#a0 :++: #d1env) ((#propr :++: #a0) :++: #d1env)) ea1
+ `bpush` emap (weakenExpr (autoWeak library (#x :++: #d1env) (#x :++: #parr :++: (#propr :++: #a0) :++: #d1env))
+ (letBinds ef0 $
+ EPair ext ef1 ef1tape))
+ (EVar ext (STArr ndim (d1 t1)) IZ)
+ `bpush` emap (ESnd ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 t2) ttape)) IZ))
+ (SEYesR (SENo (SENo (subenvConcat easubtape (subenvAll (d1e provars))))))
+ (emap (EFst ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 t2) ttape)) (IS IZ)))
+ subfa
+ (let layout = #darr :++: #tapearr :++: (#propr :++: #atapebinds) :++: #d2acEnv in
+ elet
+ (wrapAccum (autoWeak library #propr layout) $
+ emap (elet (EFst ext (EVar ext (STPair (applySparse sdElt (d2 t2)) ttape) IZ)) $
+ elet (ESnd ext (EVar ext (STPair (applySparse sdElt (d2 t2)) ttape) (IS IZ))) $
+ weakenExpr (autoWeak library (#tape :++: #dy :++: #pro :++: #d2acEnv)
+ (#tape :++: #dy :++: #dytape :++: #pro :++: layout))
+ ef2)
+ (ezip (EVar ext (STArr ndim (applySparse sdElt (d2 t2))) (autoWeak library #darr (#pro :++: layout) @> IZ))
+ (EVar ext (STArr ndim ttape) (autoWeak library #tapearr (#pro :++: layout) @> IZ)))) $
+ plus_f_a
+ (ESnd ext (evar IZ))
+ (weakenExpr (WCopy (autoWeak library (#atapebinds :++: #d2acEnv) layout))
+ (subst0 (EFst ext (EVar ext (STPair (STArr ndim (typeOf ef2)) (tTup (d2e provars))) IZ))
+ ea2)))
+ }
+
+ EFold1Inner _ commut origef ex₀ earr
+ | SpArr @_ @sdElt sdElt <- sd
+ , STArr (SS ndim) eltty :: STy (TArr (S n) elt) <- typeOf earr
+ , Rets bindsx₀a subtapex₀a (RetPair ex₀1 subx₀ ex₀2 `SCons` RetPair ea1 suba ea2 `SCons` SNil)
+ <- retConcat des $ toSingleRet (drev des accumMap (spDense (d2M eltty)) ex₀) `SCons` toSingleRet (drev des accumMap (spDense (SMTArr (SS ndim) (d2M eltty))) earr) `SCons` SNil ->
+ drevLambda des accumMap (STPair eltty eltty, SMerge) (spDense (d2M eltty)) origef $ \(provars :: SList _ envPro) efsub proPrimalBinds ef0 ef1 (ef1tape :: Ex _ ef_tape) spEf wrapAccum ef2 ->
+ let (proPrimalBinds', _) = weakenBindingsE (sinkWithBindings bindsx₀a) proPrimalBinds in
+ let bogEltTy = STPair (STPair (d1 eltty) (d1 eltty)) (typeOf ef1tape)
+ bogTy = STArr (SS ndim) bogEltTy
+ primalTy = STPair (STArr ndim (d1 eltty)) bogTy
+ library = #xy (STPair (d1 eltty) (d1 eltty) `SCons` SNil)
+ &. #parr (auto1 @(TArr (S n) (D1 elt)))
+ &. #px₀ (auto1 @(D1 elt))
+ &. #px (auto1 @(D1 elt))
+ &. #pzi (auto1 @(ZeroInfo (D2 elt)))
+ &. #primal (primalTy `SCons` SNil)
+ &. #darr (auto1 @(TArr n sdElt))
+ &. #d (auto1 @(D2 elt))
+ &. #x₀abinds (bindingsBinds bindsx₀a)
+ &. #fbinds (bindingsBinds ef0)
+ &. #x₀atapebinds (subList (bindingsBinds bindsx₀a) subtapex₀a)
+ &. #ftape (auto1 @ef_tape)
+ &. #bogelt (bogEltTy `SCons` SNil)
+ &. #propr (d1e provars)
+ &. #d1env (desD1E des)
+ &. #d2acEnv (d2ace (select SAccum des))
+ &. #d2acPro (d2ace provars)
+ &. #foldd2res (auto1 @(TPair (TPair (D2 elt) (TArr (S n) (D2 elt))) (Tup (D2E envPro))))
+ wOverPrimalBindings = autoWeak library (#x₀abinds :++: #d1env) ((#propr :++: #x₀abinds) :++: #d1env) in
+ subenvPlus SF SF (d2eM (select SMerge des)) subx₀ suba $ \subx₀a _ _ plus_x₀_a ->
+ subenvPlus SF SF (d2eM (select SMerge des)) subx₀a (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) efsub) $ \subx₀af _ _ plus_x₀a_f ->
+ Ret (bconcat bindsx₀a proPrimalBinds'
+ `bpush` weakenExpr wOverPrimalBindings ex₀1
+ `bpush` d2zeroInfo eltty (EVar ext (d1 eltty) IZ)
+ `bpush` weakenExpr (WSink .> WSink .> wOverPrimalBindings) ea1
+ `bpush` EFold1InnerD1 ext commut
+ (let layout = #xy :++: #parr :++: #pzi :++: #px₀ :++: (#propr :++: #x₀abinds) :++: #d1env in
+ weakenExpr (autoWeak library (#xy :++: #d1env) layout)
+ (letBinds ef0 $
+ EPair ext -- (out, ((in1, in2), tape)); the "additional stores" are ((in1, in2), tape)
+ ef1
+ (EPair ext
+ (EVar ext (STPair (d1 eltty) (d1 eltty)) (autoWeak library #xy (#fbinds :++: #xy :++: #d1env) @> IZ))
+ ef1tape)))
+ (EVar ext (d1 eltty) (IS (IS IZ)))
+ (EVar ext (STArr (SS ndim) (d1 eltty)) IZ))
+ (SEYesR (SEYesR (SEYesR (SENo (subenvConcat subtapex₀a (subenvAll (d1e provars)))))))
+ (EFst ext (EVar ext primalTy IZ))
+ subx₀af
+ (let layout1 = #darr :++: #primal :++: #parr :++: #pzi :++: (#propr :++: #x₀atapebinds) :++: #d2acEnv in
+ elet
+ (wrapAccum (autoWeak library #propr layout1) $
+ let layout2 = #d2acPro :++: layout1 in
+ EFold1InnerD2 ext commut
+ (elet (ESnd ext (EVar ext bogEltTy (IS IZ))) $
+ let layout3 = #ftape :++: #d :++: #bogelt :++: layout2 in
+ expandSparse (STPair eltty eltty) spEf (EFst ext (EVar ext bogEltTy (IS (IS IZ)))) $
+ weakenExpr (autoWeak library (#ftape :++: #d :++: #d2acPro :++: #d2acEnv) layout3) ef2)
+ (ESnd ext (EVar ext primalTy (autoWeak library #primal layout2 @> IZ)))
+ (ezipWith (expandSparse eltty sdElt (evar IZ) (evar (IS IZ)))
+ (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (autoWeak library #darr layout2 @> IZ))
+ (EFst ext (EVar ext primalTy (autoWeak library #primal layout2 @> IZ))))) $
+ plus_x₀a_f
+ (plus_x₀_a
+ (elet (EIdx0 ext
+ (EFold1Inner ext Commut
+ (let t = STPair (d2 eltty) (d2 eltty)
+ in EPlus ext (d2M eltty) (EFst ext (EVar ext t IZ)) (ESnd ext (EVar ext t IZ)))
+ (EZero ext (d2M eltty) (EVar ext (tZeroInfo (d2M eltty)) (WSink .> autoWeak library #pzi layout1 @> IZ)))
+ (eflatten (EFst ext (EFst ext (evar IZ)))))) $
+ weakenExpr (WCopy (WSink .> autoWeak library (#x₀atapebinds :++: #d2acEnv) layout1))
+ ex₀2)
+ (weakenExpr (WCopy (autoWeak library (#x₀atapebinds :++: #d2acEnv) layout1)) $
+ subst0 (ESnd ext (EFst ext (evar IZ))) ea2))
+ (ESnd ext (evar IZ)))
+
+ EUnit _ e
+ | SpArr sdElt <- sd
+ , Ret e0 subtape e1 sub e2 <- drev des accumMap sdElt e ->
+ Ret e0
+ subtape
+ (EUnit ext e1)
+ sub
+ (ELet ext (EIdx0 ext (EVar ext (STArr SZ (applySparse sdElt (d2 (typeOf e)))) IZ)) $
+ weakenExpr (WCopy WSink) e2)
+
+ EReplicate1Inner _ en e
+ -- We're allowed to differentiate 'en' as primal-only here because its output is discrete.
+ | SpArr sdElt <- sd
+ , let STArr ndim eltty = typeOf e ->
+ -- This pessimistic sparsity union is because the array might have been empty, in which case we need to generate a zero.
+ sparsePlusS ST ST (d2M eltty) sdElt SpAbsent $ \sdElt' (Inj inj1) (Inj inj2) _ ->
+ case drev des accumMap (SpArr sdElt') e of { Ret binds subtape e1 sub e2 ->
+ Ret binds
+ subtape
+ (EReplicate1Inner ext (weakenExpr (wSinks (bindingsBinds binds)) (drevPrimal des en)) e1)
+ sub
+ (ELet ext (EFold1Inner ext Commut
+ (let t = STPair (applySparse sdElt' (d2 eltty)) (applySparse sdElt' (d2 eltty))
+ in sparsePlus (d2M eltty) sdElt' (EFst ext (EVar ext t IZ)) (ESnd ext (EVar ext t IZ)))
+ (inj2 (ENil ext))
+ (emap (inj1 (evar IZ)) $ EVar ext (STArr (SS ndim) (applySparse sdElt (d2 eltty))) IZ)) $
+ weakenExpr (WCopy WSink) e2)
+ }
+
+ EIdx0 _ e
+ | Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd) e
+ , STArr _ t <- typeOf e ->
+ Ret e0
+ subtape
+ (EIdx0 ext e1)
+ sub
+ (ELet ext (EUnit ext (EVar ext (applySparse sd (d2 t)) IZ)) $
+ weakenExpr (WCopy WSink) e2)
+
+ EIdx1{} -> error "CHAD of EIdx1: Please use EIdx instead"
+ {-
+ EIdx1 _ e ei
+ -- We're allowed to ignore ei2 here because the output of 'ei' is discrete.
+ | Rets binds subtape (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil)
+ <- retConcat des $ drev des accumMap e `SCons` drev des accumMap ei `SCons` SNil
+ , STArr (SS n) eltty <- typeOf e ->
+ Ret (binds `bpush` e1
+ `bpush` EShape ext (EVar ext (STArr (SS n) (d1 eltty)) IZ))
+ (SEYesR (SENo subtape))
+ (EIdx1 ext (EVar ext (STArr (SS n) (d1 eltty)) (IS IZ))
+ (weakenExpr (WSink .> WSink) ei1))
+ sub
+ (ELet ext (ebuildUp1 n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ)))
+ (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ)))
+ (EVar ext (STArr n (d2 eltty)) (IS IZ))) $
+ weakenExpr (WCopy (WSink .> WSink)) e2)
+ -}
+
+ EIdx _ e ei
+ -- We're allowed to differentiate ei as primal because its output is discrete.
+ | STArr n eltty <- typeOf e
+ , Refl <- indexTupD1Id n
+ , let tIxN = tTup (sreplicate n tIx) ->
+ sparsePlusS ST ST (d2M eltty) sd SpAbsent $ \sd' (Inj inj1) (Inj inj2) _ ->
+ case drev des accumMap (SpArr sd') e of { Ret binds subtape e1 sub e2 ->
+ Ret (binds `bpush` e1
+ `bpush` EShape ext (EVar ext (typeOf e1) IZ)
+ `bpush` weakenExpr (WSink .> WSink .> wSinks (bindingsBinds binds)) (drevPrimal des ei))
+ (SEYesR (SEYesR (SENo subtape)))
+ (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ)))
+ (EVar ext (tTup (sreplicate n tIx)) IZ))
+ sub
+ (ELet ext
+ (EOneHot ext (SMTArr n (applySparse sd' (d2M eltty)))
+ (SAPArrIdx SAPHere)
+ (EPair ext
+ (EPair ext (EVar ext tIxN (IS IZ))
+ (EBuild ext n (EVar ext tIxN (IS (IS IZ))) $
+ makeZeroInfo (applySparse sd' (d2M eltty)) (inj2 (ENil ext))))
+ (ENil ext))
+ (inj1 $ EVar ext (applySparse sd (d2 eltty)) IZ)) $
+ weakenExpr (WCopy (WSink .> WSink .> WSink)) e2)
+ }
+
+ EShape _ e
+ -- Allowed to differentiate e as primal because the output of EShape is
+ -- discrete, hence we'd be passing a zero cotangent to e anyway.
+ | STArr n _ <- typeOf e
+ , Refl <- indexTupD1Id n ->
+ Ret BTop
+ SETop
+ (EShape ext (drevPrimal des e))
+ (subenvNone (d2eM (select SMerge des)))
+ (ENil ext)
+
+ ESum1Inner _ e
+ | SpArr sd' <- sd
+ , Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd') e
+ , STArr (SS n) t <- typeOf e ->
+ Ret (e0 `bpush` e1
+ `bpush` EShape ext (EVar ext (STArr (SS n) t) IZ))
+ (SEYesR (SENo subtape))
+ (ESum1Inner ext (EVar ext (STArr (SS n) t) (IS IZ)))
+ sub
+ (ELet ext (EReplicate1Inner ext
+ (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ)))
+ (EVar ext (STArr n (applySparse sd' (d2 t))) IZ)) $
+ weakenExpr (WCopy (WSink .> WSink)) e2)
+
+ EMaximum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMaximum1Inner ext) des accumMap sd' e
+ EMinimum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMinimum1Inner ext) des accumMap sd' e
+
+ EReshape _ n esh e
+ | SpArr sd' <- sd
+ , STArr orign t <- typeOf e
+ , Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd') e
+ , Refl <- indexTupD1Id n ->
+ Ret (e0 `bpush` e1
+ `bpush` EShape ext (EVar ext (STArr orign (d1 t)) IZ))
+ (SEYesR (SENo subtape))
+ (EReshape ext n (weakenExpr (WSink .> WSink .> wSinks (bindingsBinds e0)) (drevPrimal des esh))
+ (EVar ext (STArr orign (d1 t)) (IS IZ)))
+ sub
+ (elet (EReshape ext orign (EVar ext (tTup (sreplicate orign tIx)) (IS IZ))
+ (EVar ext (STArr n (applySparse sd' (d2 t))) IZ)) $
+ weakenExpr (WCopy (WSink .> WSink)) e2)
+
+ EZip _ a b
+ | SpArr sd' <- sd
+ , STArr n t1 <- typeOf a
+ , STArr _ t2 <- typeOf b ->
+ splitSparsePair (STPair (d2 t1) (d2 t2)) sd' $ \sd1 sd2 pairSplitE ->
+ case retConcat des (toSingleRet (drev des accumMap (SpArr sd1) a) `SCons`
+ toSingleRet (drev des accumMap (SpArr sd2) b) `SCons` SNil) of
+ { Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil) ->
+ subenvPlus SF SF (d2eM (select SMerge des)) subA subB $ \subBoth _ _ plus_A_B ->
+ Ret binds
+ subtape
+ (EZip ext a1 b1)
+ subBoth
+ (case pairSplitE of
+ Left Refl ->
+ let t' = STArr n (STPair (applySparse sd1 (d2 t1)) (applySparse sd2 (d2 t2))) in
+ plus_A_B
+ (elet (emap (EFst ext (evar IZ)) (EVar ext t' IZ)) $ weakenExpr (WCopy WSink) a2)
+ (elet (emap (ESnd ext (evar IZ)) (EVar ext t' IZ)) $ weakenExpr (WCopy WSink) b2)
+ Right f -> f IZ $ \wrapPair pick1 pick2 ->
+ elet (emap (wrapPair (EPair ext pick1 pick2))
+ (EVar ext (applySparse (SpArr sd') (STArr n (STPair (d2 t1) (d2 t2)))) IZ)) $
+ plus_A_B
+ (elet (emap (EFst ext (evar IZ)) (evar IZ)) $ weakenExpr (WCopy (WSink .> WSink)) a2)
+ (elet (emap (ESnd ext (evar IZ)) (evar IZ)) $ weakenExpr (WCopy (WSink .> WSink)) b2))
+ }
+
+ ENothing{} -> err_unsupported "ENothing"
+ EJust{} -> err_unsupported "EJust"
+ EMaybe{} -> err_unsupported "EMaybe"
+ ELNil{} -> err_unsupported "ELNil"
+ ELInl{} -> err_unsupported "ELInl"
+ ELInr{} -> err_unsupported "ELInr"
+ ELCase{} -> err_unsupported "ELCase"
+
+ EWith{} -> err_accum
+ EZero{} -> err_monoid
+ EDeepZero{} -> err_monoid
+ EPlus{} -> err_monoid
+ EOneHot{} -> err_monoid
+
+ EFold1InnerD1{} -> err_targetlang "EFold1InnerD1"
+ EFold1InnerD2{} -> err_targetlang "EFold1InnerD2"
+
+ where
+ err_accum = error "Accumulator operations unsupported in the source program"
+ err_monoid = error "Monoid operations unsupported in the source program"
+ err_unsupported s = error $ "CHAD: unsupported " ++ s
+ err_targetlang s = error $ "CHAD: Target language operation " ++ s ++ " not supported in source program"
+
+ contribTupTy :: Descr env sto -> SubenvS (D2E (Select env sto "merge")) contribs -> STy (Tup contribs)
+ contribTupTy des' sub = tTup (slistMap fromSMTy (subList (d2eM (select SMerge des')) sub))
+
+deriv_extremum :: (?config :: CHADConfig, ScalIsNumeric t ~ True)
+ => (forall env'. Ex env' (TArr (S n) (TScal t)) -> Ex env' (TArr n (TScal t)))
+ -> Descr env sto -> VarMap Int (D2AcE (Select env sto "accum"))
+ -> Sparse (D2s t) sd
+ -> Expr ValId env (TArr (S n) (TScal t)) -> Ret env sto (TArr n sd) (TArr n (TScal t))
+deriv_extremum extremum des accumMap sd e
+ | at@(STArr (SS n) t@(STScal st)) <- typeOf e
+ , let at' = STArr n t
+ , let tIxN = tTup (sreplicate (SS n) tIx) =
+ sparsePlusS ST ST (d2M t) sd SpAbsent $ \sd' (Inj inj1) (Inj inj2) _ ->
+ case drev des accumMap (SpArr sd') e of { Ret e0 subtape e1 sub e2 ->
+ Ret (e0 `bpush` e1
+ `bpush` extremum (EVar ext at IZ))
+ (SEYesR (SEYesR subtape))
+ (EVar ext at' IZ)
+ sub
+ (ELet ext
+ (EBuild ext (SS n) (EShape ext (EVar ext at (IS (IS IZ)))) $
+ eif (EOp ext (OEq st) (EPair ext
+ (EIdx ext (EVar ext at (IS (IS (IS IZ)))) (EVar ext tIxN IZ))
+ (EIdx ext (EVar ext at' (IS (IS IZ))) (EFst ext (EVar ext tIxN IZ)))))
+ (inj1 $ EIdx ext (EVar ext (STArr n (applySparse sd (d2 t))) (IS IZ)) (EFst ext (EVar ext tIxN IZ)))
+ (inj2 (ENil ext))) $
+ weakenExpr (WCopy (WSink .> WSink .> WSink)) e2)
+ }
+
+data ChosenStorage = forall s. ((s == "discr") ~ False) => ChosenStorage (Storage s)
+
+data RetScoped env0 sto a s sd t =
+ forall shbinds tapebinds contribs sa.
+ RetScoped
+ (Bindings Ex (D1E (a : env0)) shbinds) -- shared binds
+ (Subenv (Append shbinds '[D1 a]) tapebinds)
+ (Ex (Append shbinds (D1E (a : env0))) (D1 t))
+ (SubenvS (D2E (Select env0 sto "merge")) contribs)
+ -- ^ merge contributions to the _enclosing_ merge environment
+ (Sparse (D2 a) sa)
+ -- ^ contribution to the argument
+ (Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum")))
+ (If (s == "discr") (Tup contribs)
+ (TPair (Tup contribs) sa)))
+ -- ^ the merge contributions, plus the cotangent to the argument
+ -- (if there is any)
+deriving instance Show (RetScoped env0 sto a s sd t)
+
+drevScoped :: forall a s env sto sd t.
+ (?config :: CHADConfig)
+ => Descr env sto -> VarMap Int (D2AcE (Select env sto "accum"))
+ -> STy a -> Storage s -> Maybe (ValId a)
+ -> Sparse (D2 t) sd
+ -> Expr ValId (a : env) t
+ -> RetScoped env sto a s sd t
+drevScoped des accumMap argty argsto argids sd expr = case argsto of
+ SMerge
+ | Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap sd expr
+ , Refl <- lemAppendNil @tapebinds ->
+ case sub of
+ SEYes sp sub' -> RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub' sp e2
+ SENo sub' -> RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub' SpAbsent (EPair ext e2 (ENil ext))
+
+ SAccum
+ | chcSmartWith ?config
+ , Just (VIArr i _) <- argids
+ , Just (Some (VarMap.TypedIdx foundTy idx)) <- VarMap.lookup i accumMap
+ , Just Refl <- testEquality foundTy (STAccum (d2M argty))
+ , Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) (VarMap.sink1 accumMap) sd expr
+ , Refl <- lemAppendNil @tapebinds ->
+ -- Our contribution to the binding's cotangent _here_ is zero (absent),
+ -- because we're contributing to an earlier binding of the same value
+ -- instead.
+ RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub SpAbsent $
+ let wtapebinds = wSinks (subList (bindingsBinds e0) subtape) in
+ ELet ext (EVar ext (STAccum (d2M argty)) (WSink .> wtapebinds @> idx)) $
+ weakenExpr (autoWeak (#d (auto1 @sd)
+ &. #body (subList (bindingsBinds e0) subtape)
+ &. #ac (auto1 @(TAccum (D2 a)))
+ &. #tl (d2ace (select SAccum des)))
+ (#d :++: #body :++: #ac :++: #tl)
+ (#ac :++: #d :++: #body :++: #tl))
+ (EPair ext e2 (ENil ext))
+
+ | let accumMap' = case argids of
+ Just (VIArr i _) -> VarMap.insert i (STAccum (d2M argty)) IZ (VarMap.sink1 accumMap)
+ _ -> VarMap.sink1 accumMap
+ , Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap' sd expr ->
+ let library = #d (auto1 @sd)
+ &. #p (auto1 @(D1 a))
+ &. #body (subList (bindingsBinds e0) subtape)
+ &. #ac (auto1 @(TAccum (D2 a)))
+ &. #tl (d2ace (select SAccum des))
+ in
+ RetScoped e0 (subenvConcat (SEYesR @_ @_ @(D1 a) SETop) subtape) e1 sub (spDense (d2M argty)) $
+ let primalIdx = autoWeak library #p (#d :++: (#body :++: #p) :++: #tl) @> IZ in
+ EWith ext (d2M argty) (EDeepZero ext (d2M argty) (d2deepZeroInfo argty (EVar ext (d1 argty) primalIdx))) $
+ weakenExpr (autoWeak library
+ (#d :++: #body :++: #ac :++: #tl)
+ (#ac :++: #d :++: (#body :++: #p) :++: #tl))
+ e2
+
+ SDiscr
+ | Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap sd expr
+ , Refl <- lemAppendNil @tapebinds ->
+ RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub SpAbsent e2
+
+drevLambda :: (?config :: CHADConfig, (s == "accum") ~ False)
+ => Descr env sto
+ -> VarMap Int (D2AcE (Select env sto "accum"))
+ -> (STy a, Storage s)
+ -> Sparse (D2 t) dt
+ -> Expr ValId (a : env) t
+ -> (forall provars shbinds tape d2a'.
+ SList STy provars
+ -> Subenv (D2E (Select env sto "merge")) (D2E provars)
+ -> Bindings Ex (D1E env) (D1E provars) -- accum-promoted free variables of which we need a primal in the reverse pass (to initialise the accumulator)
+ -> Bindings Ex (D1 a : D1E env) shbinds
+ -> Ex (Append shbinds (D1 a : D1E env)) (D1 t)
+ -> Ex (Append shbinds (D1 a : D1E env)) tape
+ -> Sparse (D2 a) d2a'
+ -> (forall env' b.
+ D1E provars :> env'
+ -> Ex (Append (D2AcE provars) env') b
+ -> Ex ( env') (TPair b (Tup (D2E provars))))
+ -> Ex (tape : dt : Append (D2AcE provars) (D2AcE (Select env sto "accum"))) d2a'
+ -> r)
+ -> r
+drevLambda des accumMap (argty, argsto) sd origef k =
+ let t = typeOf origef in
+ deleteUnused (descrList des) (occEnvPopSome (occCountAll origef)) $ \(usedSub :: Subenv env env') ->
+ let ef = unsafeWeakenWithSubenv (SEYesR usedSub) origef in
+ subDescr des usedSub $ \(usedDes :: Descr env' _) subMergeUsed subAccumUsed subD1eUsed ->
+ accumPromote (applySparse sd (d2 t)) usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro ->
+ let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in
+ let mergePrimalSub = subenvD1E (selectSub SMerge des `subenvCompose` subMergeUsed `subenvCompose` proSub) in
+ let mergePrimalBindings = collectBindings (d1e (descrList des)) mergePrimalSub in
+ case prf1 prodes argty argsto of { Refl ->
+ case drev (prodes `DPush` (argty, Nothing, argsto)) accumMapPro sd ef of { Ret (ef0 :: Bindings _ _ e_binds) (subtapeEf :: Subenv _ e_tape) ef1 subEf ef2 ->
+ let (efRebinds, efPrerebinds) = reconstructBindings (subList (bindingsBinds ef0) subtapeEf) in
+ extractContrib prodes argty argsto subEf $ \argSp getSparseArg ->
+ let library = #fbinds (bindingsBinds ef0)
+ &. #ftapebinds (subList (bindingsBinds ef0) subtapeEf)
+ &. #ftape (auto1 @(Tape e_tape))
+ &. #arg (d1 argty `SCons` SNil)
+ &. #d (applySparse sd (d2 t) `SCons` SNil)
+ &. #d1env (desD1E des)
+ &. #d1env' (desD1E usedDes)
+ &. #propr (d1e envPro)
+ &. #d2acUsed (d2ace (select SAccum usedDes))
+ &. #d2acEnv (d2ace (select SAccum des))
+ &. #d2acPro (d2ace envPro)
+ &. #efPrerebinds efPrerebinds in
+ k envPro
+ (subenvD2E (subenvCompose subMergeUsed proSub))
+ mergePrimalBindings
+ (fst (weakenBindingsE (WCopy (wUndoSubenv subD1eUsed)) ef0))
+ (weakenExpr (autoWeak library (#fbinds :++: #arg :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))
+ (#fbinds :++: #arg :++: #d1env))
+ ef1)
+ (bindingsCollectTape (bindingsBinds ef0) subtapeEf (autoWeak library #fbinds (#fbinds :++: #arg :++: #d1env)))
+ argSp
+ (\wpro1 body ->
+ uninvertTup (d2e envPro) (typeOf body) $
+ makeAccumulators wpro1 envPro $
+ body)
+ (letBinds (efRebinds IZ) $
+ weakenExpr
+ (autoWeak library (#d2acPro :++: #d :++: #ftapebinds :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed))
+ ((#ftapebinds :++: #efPrerebinds) :++: #ftape :++: #d :++: #d2acPro :++: #d2acEnv)
+ .> wPro (subList (bindingsBinds ef0) subtapeEf))
+ (getSparseArg ef2))
+ }}
+ where
+ extractContrib :: (Select env sto "merge" ~ '[], (s == "accum") ~ False)
+ => proxy env sto -> proxy2 a -> Storage s
+ -- if s == "merge", this simplifies to SubenvS '[D2 a] t'
+ -- if s == "discr", this simplifies to SubenvS '[] t'
+ -> SubenvS (D2E (Select (a : env) (s : sto) "merge")) t'
+ -> (forall d'. Sparse (D2 a) d' -> (forall env'. Ex env' (Tup t') -> Ex env' d') -> r) -> r
+ extractContrib _ _ SMerge (SENo SETop) k' = k' SpAbsent id
+ extractContrib _ _ SMerge (SEYes s SETop) k' = k' s (ESnd ext)
+ extractContrib _ _ SDiscr SETop k' = k' SpAbsent id
+
+ prf1 :: (s == "accum") ~ False => proxy env sto -> proxy2 a -> Storage s
+ -> Select (a : env) (s : sto) "accum" :~: Select env sto "accum"
+ prf1 _ _ SMerge = Refl
+ prf1 _ _ SDiscr = Refl
+
+-- TODO: proper primal-only transform that doesn't depend on D1 = Id
+drevPrimal :: Descr env sto -> Expr x env t -> Ex (D1E env) (D1 t)
+drevPrimal des e
+ | Refl <- d1Identity (typeOf e)
+ , Refl <- d1eIdentity (descrList des)
+ = mapExt (const ext) e