{-# LANGUAGE DataKinds #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImplicitParams #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeData #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

-- Many patterns in this module carry type annotations only to bring type
-- variables into scope. Partial type signatures (holes written '_') let those
-- annotations leave out the other type parameters of the types involved.
{-# LANGUAGE PartialTypeSignatures #-}
{-# OPTIONS -Wno-partial-type-signatures #-}
module CHAD (
  drev,
  freezeRet,
  CHADConfig(..),
  defaultConfig,
  Storage(..),
  Descr(..),
  Select,
) where

import Data.Functor.Const
import Data.Some
import Data.Type.Bool (If)
import Data.Type.Equality (type (==), testEquality)
import GHC.Stack (HasCallStack)

import Analysis.Identity (ValId(..), validSplitEither)
import AST
import AST.Bindings
import AST.Count
import AST.Env
import AST.Sparse
import AST.Weaken.Auto
import CHAD.EnvDescr
import CHAD.Types
import Data
import qualified Data.VarMap as VarMap
import Data.VarMap (VarMap)
import Lemmas


------------------------------ TAPES AND BINDINGS ------------------------------

-- | The tape type for a list of types: a right-nested chain of 'TPair's
-- ending in 'TNil'.
type family Tape binds where
  Tape '[] = TNil
  Tape (t : ts) = TPair t (Tape ts)

-- | Singleton counterpart of 'Tape': computes the 'STy' of the tape for the
-- given list of types.
tapeTy :: SList STy binds -> STy (Tape binds)
tapeTy = \case
  SNil -> STNil
  SCons t ts -> STPair t (tapeTy ts)

-- | Builds a 'Tape' expression holding the entries of @binds@ that the
-- 'Subenv' keeps, in list order.
--
-- The weakening @w@ locates @binds@ inside @env2@. Each recursive step
-- composes it with 'WSink' to step past the entry just handled.
bindingsCollectTape :: SList STy binds -> Subenv binds tapebinds -> binds :> env2 -> Ex env2 (Tape tapebinds)
bindingsCollectTape SNil SETop _ = ENil ext
bindingsCollectTape (SCons _ rest) (SENo keep) w =
  -- Entry dropped by the subenvironment: skip it.
  bindingsCollectTape rest keep (w .> WSink)
bindingsCollectTape (SCons ty rest) (SEYesR keep) w =
  -- Entry kept: read it through the weakening and cons it onto the tape.
  EPair ext (EVar ext ty (w @> IZ)) (bindingsCollectTape rest keep (w .> WSink))

-- bindingsCollectTape' :: forall f env binds tapebinds env2. Bindings f env binds -> Subenv binds tapebinds
--                      -> Append binds env :> env2 -> Ex env2 (Tape tapebinds)
-- bindingsCollectTape' binds sub w
--   | Refl <- lemAppendNil @binds
--   = bindingsCollectTape (bindingsBinds binds) sub (w .> wCopies @_ @_ @'[] (bindingsBinds binds) (WClosed @env))

-- | The tapes of all proper tails of the list, listed from the longest tail
-- to the shortest. That is the reverse of the order we want, because in a
-- 'Bindings' the head of the list is the bottom-most entry.
type family TapeUnfoldings binds where
  TapeUnfoldings '[] = '[]
  TapeUnfoldings (t : ts) = Tape ts : TapeUnfoldings ts

-- | Type-level list reversal.
type family Reverse l where
  Reverse '[] = '[]
  Reverse (t : ts) = Append (Reverse ts) '[t]

-- | An expression that can only be 'snd' applied to the innermost variable.
data UnfExpr env t where
  UnfExSnd :: STy s -> STy t -> UnfExpr (TPair s t : env) t

-- | Embeds an 'UnfExpr' into the general expression type.
fromUnfExpr :: UnfExpr env t -> Ex env t
fromUnfExpr = \case
  UnfExSnd s t -> ESnd ext (EVar ext (STPair s t) IZ)

-- A reconstructor for 'ts' consists of two parts:
--   - A bunch of 'snd' expressions taking us from knowing that there's a
--     'Tape ts' in the environment (for simplicity assume it's at IZ, we'll fix
--     this in reconstructBindings), to having 'Reverse (TapeUnfoldings ts)' in
--     the environment.
--   - In the extended environment, another bunch of let bindings (these are
--     'fst' expressions, but no need to know that statically) that project the
--     fsts out of what we introduced above, one for each type in 'ts'.
-- | Two let stacks that together unpack a @'Tape' ts@ (assumed to sit at
-- 'IZ' of the environment) back into the individual values of @ts@.
--
-- * The first field is the unfolder stack. It contains only 'snd'
--   expressions ('UnfExpr') and binds @'Reverse' ('TapeUnfoldings' ts)@,
--   i.e. the tapes of all proper tails of @ts@.
--
-- * The second field is the builder stack. It has one binding per type in
--   @ts@ and lives in the environment extended by the unfolder stack.
data Reconstructor env ts
  = Reconstructor
      (Bindings UnfExpr (Tape ts : env) (Reverse (TapeUnfoldings ts)))
      (Bindings Ex (Append (Reverse (TapeUnfoldings ts)) (Tape ts : env)) ts)

-- | Appends a single element at the end of an 'SList'.
ssnoc :: SList f ts -> f t -> SList f (Append ts '[t])
ssnoc SNil a = SCons a SNil
ssnoc (SCons t ts) a = SCons t (ssnoc ts a)

-- | Reverses an 'SList'. Every step appends through 'ssnoc', so this is
-- quadratic in the length of the list.
sreverse :: SList f ts -> SList f (Reverse ts)
sreverse SNil = SNil
sreverse (SCons t ts) = ssnoc (sreverse ts) t

-- | Singleton counterpart of 'TapeUnfoldings': for each element of the list,
-- the tape type of the elements that follow it.
stapeUnfoldings :: SList STy ts -> SList STy (TapeUnfoldings ts)
stapeUnfoldings SNil = SNil
stapeUnfoldings (SCons _ ts) = SCons (tapeTy ts) (stapeUnfoldings ts)

-- | Puts a 'snd' at the top of an unfolder stack and grows the context
-- variable by one. The stack now starts from a @'Tape' (t : ts)@ instead of
-- a @'Tape' ts@. Its first binding (appended at the end of the type-level
-- list, i.e. the top of the stack) recovers that @'Tape' ts@ with one 'snd'.
shiftUnfolder :: STy t -> SList STy ts -> Bindings UnfExpr (Tape ts : env) list -> Bindings UnfExpr (Tape (t : ts) : env) (Append list '[Tape ts])
shiftUnfolder newTy ts BTop = BPush BTop (tapeTy ts, UnfExSnd newTy (tapeTy ts))
shiftUnfolder newTy ts (BPush b (t, UnfExSnd itemTy _)) =
  -- Recurse on 'b', and retype the 'snd'. We need to unfold 'b' once in order
  -- to expand an 'Append' in the types so that things simplify just enough.
  -- We have an equality 'Append binds x1 ~ a : x2', where 'binds' is the list
  -- of bindings produced by 'b'. We want to conclude from this that
  -- 'binds ~ a : x3' for some 'x3', but GHC will only do that once we know
  -- that 'binds ~ y : ys' so that the 'Append' can expand one step, after
  -- which 'y ~ a' as desired. The 'case' unfolds 'b' one step.
  BPush (shiftUnfolder newTy ts b)
        (t, case b of
              BTop -> UnfExSnd itemTy t
              BPush{} -> UnfExSnd itemTy t)

-- | Extends a reconstructor for @ts@ to one for @t : ts@.
--
-- * The unfolder stack gets one more 'snd' on top, via 'shiftUnfolder'.
-- * The existing builders are weakened past that extra unfolder binding.
-- * A new builder is pushed at the bottom. It takes 'fst' of the full
--   @'Tape' (t : ts)@, reached by sinking 'IZ' past every other binding,
--   which yields the value of type @t@.
--
-- The pattern guards bring 'Append' equalities into scope that the result
-- type needs. Going by the lemma names, these are presumably right identity
-- and associativity of 'Append'.
growRecon :: forall env t ts.
             STy t -> SList STy ts -> Reconstructor env ts -> Reconstructor env (t : ts)
growRecon t ts (Reconstructor unfbs bs)
  | Refl <- lemAppendNil @(Append (Reverse (TapeUnfoldings ts)) '[Tape ts])
  , Refl <- lemAppendAssoc @ts @(Append (Reverse (TapeUnfoldings ts)) '[Tape ts]) @(Tape (t : ts) : env)
  , Refl <- lemAppendAssoc @(Reverse (TapeUnfoldings ts)) @'[Tape ts] @env
  = Reconstructor
      (shiftUnfolder t ts unfbs)
      -- Add a 'fst' at the bottom of the builder stack.
      -- First we have to weaken most of 'bs' to skip one more binding in the
      -- unfolder stack above it.
      (BPush (fst (weakenBindings weakenExpr
                                  (wCopies (sappend (sreverse (stapeUnfoldings ts)) (SCons (tapeTy ts) SNil))
                                           (WSink :: env :> (Tape (t : ts) : env)))
                                  bs))
             (t
             ,EFst ext $ EVar ext (tapeTy (SCons t ts)) $
                wSinks @(Tape (t : ts) : env)
                       (sappend ts (sappend (sappend (sreverse (stapeUnfoldings ts)) (SCons (tapeTy ts) SNil)) SNil))
                  @> IZ))

-- | Builds the reconstructor for a list of types by folding 'growRecon',
-- starting from the empty reconstructor (two empty stacks).
buildReconstructor :: SList STy ts -> Reconstructor env ts
buildReconstructor SNil = Reconstructor BTop BTop
buildReconstructor (SCons t ts) = growRecon t ts (buildReconstructor ts)

-- STRATEGY FOR reconstructBindings
--
-- Naming used below. 'e' is the variable holding the whole tape; it is also
-- called x_n, where n is the length of 'binds'. Each x_i holds the tape of
-- the last i elements of 'binds'. Each y_i is the i-th element of 'binds',
-- counting from the end. The x's (all 'snd's) are bound first, from x_(n-1)
-- down to x0; then the y's (all 'fst's), from y1 up to y_n.
--
-- binds = []
-- e (= x0) : ()
--
-- binds = [c]
-- e (= x1) : (c, ())
-- x0 = snd x1 : ()
-- y1 = fst x1 : c
--
-- binds = [b, c]
-- e (= x2) : (b, (c, ()))
-- x1 = snd x2 : (c, ())
-- x0 = snd x1 : ()
-- y1 = fst x1 : c
-- y2 = fst x2 : b
--
-- binds = [a, b, c]
-- e (= x3) : (a, (b, (c, ())))
-- x2 = snd x3 : (b, (c, ()))
-- x1 = snd x2 : (c, ())
-- x0 = snd x1 : ()
-- y1 = fst x1 : c
-- y2 = fst x2 : b
-- y3 = fst x3 : a

-- Given that in 'env' we can find a 'Tape binds', i.e. a tuple containing all
-- the things in the list 'binds', we want to create a let stack that extracts
-- all values from that tuple and in effect "restores" the environment
-- described by 'binds'. The idea is that elsewhere, we took a slice of the
-- environment and saved it all in a tuple to be restored later. We
We -- incidentally also add a bunch of additional bindings, namely 'Reverse -- (TapeUnfoldings binds)', so the calling code just has to skip those in -- whatever it wants to do. reconstructBindings :: SList STy binds -> Idx env (Tape binds) -> (Bindings Ex env (Append binds (Reverse (TapeUnfoldings binds))) ,SList STy (Reverse (TapeUnfoldings binds))) reconstructBindings binds tape = let Reconstructor unf build = buildReconstructor binds in (fst $ weakenBindings weakenExpr (WIdx tape) (bconcat (mapBindings fromUnfExpr unf) build) ,sreverse (stapeUnfoldings binds)) ---------------------------------- DERIVATIVES --------------------------------- d1op :: SOp a t -> Ex env (D1 a) -> Ex env (D1 t) d1op (OAdd t) e = EOp ext (OAdd t) e d1op (OMul t) e = EOp ext (OMul t) e d1op (ONeg t) e = EOp ext (ONeg t) e d1op (OLt t) e = EOp ext (OLt t) e d1op (OLe t) e = EOp ext (OLe t) e d1op (OEq t) e = EOp ext (OEq t) e d1op ONot e = EOp ext ONot e d1op OAnd e = EOp ext OAnd e d1op OOr e = EOp ext OOr e d1op OIf e = EOp ext OIf e d1op ORound64 e = EOp ext ORound64 e d1op OToFl64 e = EOp ext OToFl64 e d1op (ORecip t) e = EOp ext (ORecip t) e d1op (OExp t) e = EOp ext (OExp t) e d1op (OLog t) e = EOp ext (OLog t) e d1op (OIDiv t) e = EOp ext (OIDiv t) e d1op (OMod t) e = EOp ext (OMod t) e -- | Both primal and dual must be duplicable expressions data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a)) | Nonlinear (forall env. 
Ex env (D1 a) -> Ex env (D2 t) -> Ex env (D2 a)) d2op :: SOp a t -> D2Op a t d2op op = case op of OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EPair ext d d OMul t -> d2opBinArrangeInt t $ Nonlinear $ \e d -> EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d)) (EOp ext (OMul t) (EPair ext (EFst ext e) d)) ONeg t -> d2opUnArrangeInt t $ Linear $ \d -> EOp ext (ONeg t) d OLt t -> Linear $ \_ -> pairZero t OLe t -> Linear $ \_ -> pairZero t OEq t -> Linear $ \_ -> pairZero t ONot -> Linear $ \_ -> ENil ext OAnd -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) OOr -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) OIf -> Linear $ \_ -> ENil ext ORound64 -> Linear $ \_ -> EZero ext (SMTScal STF64) (ENil ext) OToFl64 -> Linear $ \_ -> ENil ext ORecip t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ONeg t) (EOp ext (ORecip t) (EOp ext (OMul t) (EPair ext e e)))) d) OExp t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (OExp t) e) d) OLog t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ORecip t) e) d) OIDiv t -> integralD2 t $ Linear $ \_ -> EPair ext (ENil ext) (ENil ext) OMod t -> integralD2 t $ Linear $ \_ -> EPair ext (ENil ext) (ENil ext) where pairZero :: SScalTy a -> Ex env (D2 (TPair (TScal a) (TScal a))) pairZero t = ziNil t $ EPair ext (EZero ext (d2M (STScal t)) (ENil ext)) (EZero ext (d2M (STScal t)) (ENil ext)) where ziNil :: SScalTy a -> (ZeroInfo (D2s a) ~ TNil => r) -> r ziNil STI32 k = k ziNil STI64 k = k ziNil STF32 k = k ziNil STF64 k = k ziNil STBool k = k d2opUnArrangeInt :: SScalTy a -> (D2s a ~ TScal a => D2Op (TScal a) t) -> D2Op (TScal a) t d2opUnArrangeInt ty float = case ty of STI32 -> Linear $ \_ -> ENil ext STI64 -> Linear $ \_ -> ENil ext STF32 -> float STF64 -> float STBool -> Linear $ \_ -> ENil ext d2opBinArrangeInt :: SScalTy a -> (D2s a ~ TScal a => D2Op (TPair (TScal a) (TScal a)) t) -> D2Op (TPair (TScal a) (TScal a)) t d2opBinArrangeInt ty 
float = case ty of STI32 -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) STI64 -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) STF32 -> float STF64 -> float STBool -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) floatingD2 :: ScalIsFloating a ~ True => SScalTy a -> ((D2s a ~ TScal a, ScalIsNumeric a ~ True) => r) -> r floatingD2 STF32 k = k floatingD2 STF64 k = k integralD2 :: ScalIsIntegral a ~ True => SScalTy a -> ((D2s a ~ TNil, ScalIsNumeric a ~ True) => r) -> r integralD2 STI32 k = k integralD2 STI64 k = k desD1E :: Descr env sto -> SList STy (D1E env) desD1E = d1e . descrList -- d1W :: env :> env' -> D1E env :> D1E env' -- d1W WId = WId -- d1W WSink = WSink -- d1W (WCopy w) = WCopy (d1W w) -- d1W (WPop w) = WPop (d1W w) -- d1W (WThen u w) = WThen (d1W u) (d1W w) conv1Idx :: Idx env t -> Idx (D1E env) (D1 t) conv1Idx IZ = IZ conv1Idx (IS i) = IS (conv1Idx i) data Idx2 env sto t = Idx2Ac (Idx (D2AcE (Select env sto "accum")) (TAccum (D2 t))) | Idx2Me (Idx (D2E (Select env sto "merge")) (D2 t)) | Idx2Di (Idx (Select env sto "discr") t) conv2Idx :: Descr env sto -> Idx env t -> Idx2 env sto t conv2Idx (DPush _ (_, _, SAccum)) IZ = Idx2Ac IZ conv2Idx (DPush _ (_, _, SMerge)) IZ = Idx2Me IZ conv2Idx (DPush _ (_, _, SDiscr)) IZ = Idx2Di IZ conv2Idx (DPush des (_, _, SAccum)) (IS i) = case conv2Idx des i of Idx2Ac j -> Idx2Ac (IS j) Idx2Me j -> Idx2Me j Idx2Di j -> Idx2Di j conv2Idx (DPush des (_, _, SMerge)) (IS i) = case conv2Idx des i of Idx2Ac j -> Idx2Ac j Idx2Me j -> Idx2Me (IS j) Idx2Di j -> Idx2Di j conv2Idx (DPush des (_, _, SDiscr)) (IS i) = case conv2Idx des i of Idx2Ac j -> Idx2Ac j Idx2Me j -> Idx2Me j Idx2Di j -> Idx2Di (IS j) conv2Idx DTop i = case i of {} opt2UnSparse :: SOp a b -> Sparse (D2 b) b' -> Ex env b' -> Ex env (D2 b) opt2UnSparse = go . 
opt2 where go :: STy b -> Sparse (D2 b) b' -> Ex env b' -> Ex env (D2 b) go (STScal STI32) SpAbsent = \_ -> ENil ext go (STScal STI64) SpAbsent = \_ -> ENil ext go (STScal STF32) SpAbsent = \_ -> EZero ext (SMTScal STF32) (ENil ext) go (STScal STF64) SpAbsent = \_ -> EZero ext (SMTScal STF64) (ENil ext) go (STScal STBool) SpAbsent = \_ -> ENil ext go (STScal STF32) SpScal = id go (STScal STF64) SpScal = id go STNil _ = \_ -> ENil ext go (STPair t1 t2) (SpPair s1 s2) = \e -> eunPair e $ \_ e1 e2 -> EPair ext (go t1 s1 e1) (go t2 s2 e2) go t _ = error $ "Primitive operations that return " ++ show t ++ " are scary" ------------------------------------ MONOIDS ----------------------------------- d2zeroInfo :: STy t -> Ex env (D1 t) -> Ex env (ZeroInfo (D2 t)) d2zeroInfo STNil _ = ENil ext d2zeroInfo (STPair a b) e = eunPair e $ \_ e1 e2 -> EPair ext (d2zeroInfo a e1) (d2zeroInfo b e2) d2zeroInfo STEither{} _ = ENil ext d2zeroInfo STLEither{} _ = ENil ext d2zeroInfo STMaybe{} _ = ENil ext d2zeroInfo (STArr _ t) e = emap (d2zeroInfo t (EVar ext (d1 t) IZ)) e d2zeroInfo (STScal t) _ | Refl <- lemZeroInfoScal t = ENil ext d2zeroInfo STAccum{} _ = error "accumulators not allowed in source program" zeroTup :: SList STy env0 -> D1E env0 :> env -> Ex env (Tup (D2E env0)) zeroTup SNil _ = ENil ext zeroTup (t `SCons` env) w = EPair ext (zeroTup env (WPop w)) (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ)))) ----------------------------------- SPARSITY ----------------------------------- subenvD1E :: Subenv env env' -> Subenv (D1E env) (D1E env') subenvD1E SETop = SETop subenvD1E (SEYesR sub) = SEYesR (subenvD1E sub) subenvD1E (SENo sub) = SENo (subenvD1E sub) expandSparse :: STy a -> Sparse (D2 a) b -> Ex env (D1 a) -> Ex env b -> Ex env (D2 a) expandSparse t sp _ e | Just Refl <- isDense (d2M t) sp = e expandSparse t (SpSparse sp) epr e = EMaybe ext (EZero ext (d2M t) (d2zeroInfo t epr)) (expandSparse t sp (weakenExpr WSink epr) (EVar ext (applySparse sp (d2 t)) 
IZ)) e expandSparse t SpAbsent epr _ = EZero ext (d2M t) (d2zeroInfo t epr) expandSparse (STPair t1 t2) (SpPair s1 s2) epr e = eunPair epr $ \w1 epr1 epr2 -> eunPair (weakenExpr w1 e) $ \w2 e1 e2 -> EPair ext (expandSparse t1 s1 (weakenExpr w2 epr1) e1) (expandSparse t2 s2 (weakenExpr w2 epr2) e2) expandSparse (STEither t1 t2) (SpLEither s1 s2) epr e = ELCase ext e (EZero ext (d2M (STEither t1 t2)) (ENil ext)) (ECase ext (weakenExpr WSink epr) (ELInl ext (d2 t2) (expandSparse t1 s1 (EVar ext (d1 t1) IZ) (EVar ext (applySparse s1 (d2 t1)) (IS IZ)))) (EError ext (d2 (STEither t1 t2)) "expspa r<-dl")) (ECase ext (weakenExpr WSink epr) (EError ext (d2 (STEither t1 t2)) "expspa l<-dr") (ELInr ext (d2 t1) (expandSparse t2 s2 (EVar ext (d1 t2) IZ) (EVar ext (applySparse s2 (d2 t2)) (IS IZ))))) expandSparse (STLEither t1 t2) (SpLEither s1 s2) epr e = ELCase ext e (EZero ext (d2M (STEither t1 t2)) (ENil ext)) (ELCase ext (weakenExpr WSink epr) (EError ext (d2 (STEither t1 t2)) "expspa ln<-dl") (ELInl ext (d2 t2) (expandSparse t1 s1 (EVar ext (d1 t1) IZ) (EVar ext (applySparse s1 (d2 t1)) (IS IZ)))) (EError ext (d2 (STEither t1 t2)) "expspa lr<-dl")) (ELCase ext (weakenExpr WSink epr) (EError ext (d2 (STEither t1 t2)) "expspa ln<-dr") (EError ext (d2 (STEither t1 t2)) "expspa ll<-dr") (ELInr ext (d2 t1) (expandSparse t2 s2 (EVar ext (d1 t2) IZ) (EVar ext (applySparse s2 (d2 t2)) (IS IZ))))) expandSparse (STMaybe t) (SpMaybe s) epr e = EMaybe ext (ENothing ext (d2 t)) (let epr' = EMaybe ext (EError ext (d1 t) "expspa n<-dj") (EVar ext (d1 t) IZ) epr in EJust ext (expandSparse t s (weakenExpr WSink epr') (EVar ext (applySparse s (d2 t)) IZ))) e expandSparse (STArr _ t) (SpArr s) epr e = ezipWith (expandSparse t s (EVar ext (d1 t) (IS IZ)) (EVar ext (applySparse s (d2 t)) IZ)) epr e expandSparse (STScal STF32) SpScal _ e = e expandSparse (STScal STF64) SpScal _ e = e expandSparse (STAccum{}) _ _ _ = error "accumulators not allowed in source program" subenvPlus :: SBool req1 -> 
SBool req2 -> SList SMTy env -> SubenvS env env1 -> SubenvS env env2 -> (forall env3. SubenvS env env3 -> Injection req1 (Tup env1) (Tup env3) -> Injection req2 (Tup env2) (Tup env3) -> (forall e. Ex e (Tup env1) -> Ex e (Tup env2) -> Ex e (Tup env3)) -> r) -> r subenvPlus _ _ SNil SETop SETop k = k SETop (Inj id) (Inj id) (\_ _ -> ENil ext) subenvPlus req1 req2 (SCons _ env) (SENo sub1) (SENo sub2) k = subenvPlus req1 req2 env sub1 sub2 $ \sub3 s31 s32 pl -> k (SENo sub3) s31 s32 pl subenvPlus req1 SF (SCons _ env) (SEYes sp1 sub1) (SENo sub2) k = subenvPlus req1 SF env sub1 sub2 $ \sub3 minj13 _ pl -> k (SEYes sp1 sub3) (withInj minj13 $ \inj13 -> \e1 -> eunPair e1 $ \_ e1a e1b -> EPair ext (inj13 e1a) e1b) Noinj (\e1 e2 -> ELet ext e1 $ EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ)) (weakenExpr WSink e2)) (ESnd ext (EVar ext (typeOf e1) IZ))) subenvPlus req1 ST (SCons t env) (SEYes sp1 sub1) (SENo sub2) k = subenvPlus req1 ST env sub1 sub2 $ \sub3 minj13 (Inj inj23) pl -> k (SEYes (SpSparse sp1) sub3) (withInj minj13 $ \inj13 -> \e1 -> eunPair e1 $ \_ e1a e1b -> EPair ext (inj13 e1a) (EJust ext e1b)) (Inj $ \e2 -> EPair ext (inj23 e2) (ENothing ext (applySparse sp1 (fromSMTy t)))) (\e1 e2 -> ELet ext e1 $ EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ)) (weakenExpr WSink e2)) (EJust ext (ESnd ext (EVar ext (typeOf e1) IZ)))) subenvPlus req1 req2 (SCons t env) sub1@SENo{} sub2@SEYes{} k = subenvPlus req2 req1 (SCons t env) sub2 sub1 $ \sub3 minj23 minj13 pl -> k sub3 minj13 minj23 (flip pl) subenvPlus req1 req2 (SCons t env) (SEYes sp1 sub1) (SEYes sp2 sub2) k = subenvPlus req1 req2 env sub1 sub2 $ \sub3 minj13 minj23 pl -> sparsePlusS req1 req2 t sp1 sp2 $ \sp3 mTinj13 mTinj23 plus -> k (SEYes sp3 sub3) (withInj2 minj13 mTinj13 $ \inj13 tinj13 -> \e1 -> eunPair e1 $ \_ e1a e1b -> EPair ext (inj13 e1a) (tinj13 e1b)) (withInj2 minj23 mTinj23 $ \inj23 tinj23 -> \e2 -> eunPair e2 $ \_ e2a e2b -> EPair ext (inj23 e2a) (tinj23 e2b)) (\e1 e2 -> ELet ext e1 $ ELet 
ext (weakenExpr WSink e2) $ EPair ext (pl (EFst ext (EVar ext (typeOf e1) (IS IZ))) (EFst ext (EVar ext (typeOf e2) IZ))) (plus (ESnd ext (EVar ext (typeOf e1) (IS IZ))) (ESnd ext (EVar ext (typeOf e2) IZ)))) expandSubenvZeros :: D1E env0 :> env -> SList STy env0 -> SubenvS (D2E env0) contribs -> Ex env (Tup contribs) -> Ex env (Tup (D2E env0)) expandSubenvZeros _ SNil SETop _ = ENil ext expandSubenvZeros w (SCons t ts) (SEYes sp sub) e = eunPair e $ \w1 e1 e2 -> EPair ext (expandSubenvZeros (w1 .> WPop w) ts sub e1) (expandSparse t sp (EVar ext (d1 t) (w1 .> w @> IZ)) e2) expandSubenvZeros w (SCons t ts) (SENo sub) e = EPair ext (expandSubenvZeros (WPop w) ts sub e) (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ)))) assertSubenvEmpty :: HasCallStack => Subenv' s env env' -> env' :~: '[] assertSubenvEmpty (SENo sub) | Refl <- assertSubenvEmpty sub = Refl assertSubenvEmpty SETop = Refl assertSubenvEmpty SEYes{} = error "assertSubenvEmpty: not empty" --------------------------------- ACCUMULATORS --------------------------------- makeAccumulators :: D1E envPro :> env -> SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro)) makeAccumulators _ SNil e = e makeAccumulators w (t `SCons` envpro) e = makeAccumulators (WPop w) envpro $ EWith ext (d2M t) (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (wSinks (d2ace envpro) .> w @> IZ)))) e uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list)) uninvertTup SNil _ e = EPair ext e (ENil ext) uninvertTup (t `SCons` list) tcore e = ELet ext (uninvertTup list (STPair tcore t) e) $ let recT = STPair (STPair tcore t) (tTup list) -- type of the RHS of that let binding in EPair ext (EFst ext (EFst ext (EVar ext recT IZ))) (EPair ext (ESnd ext (EVar ext recT IZ)) (ESnd ext (EFst ext (EVar ext recT IZ)))) fromArrayValId :: Maybe (ValId t) -> Maybe Int fromArrayValId (Just (VIArr i _)) = Just i fromArrayValId _ = Nothing accumPromote :: 
forall dt env sto proxy r. proxy dt -> Descr env sto -> (forall stoRepl envPro. (Select env stoRepl "merge" ~ '[]) => Descr env stoRepl -- ^ A revised environment description that switches -- arrays (used in the OccEnv) that are currently on -- "merge" storage, to "accum" storage. -> SList STy envPro -- ^ New entries on top of the original dual environment, -- that house the accumulators for the promoted arrays in -- the original environment. -> Subenv (Select env sto "merge") envPro -- ^ The promoted entries were merge entries in the -- original environment. -> Subenv (D2AcE (Select env stoRepl "accum")) (D2AcE (Select env sto "accum")) -- ^ All entries that were accumulators are still -- accumulators. -> VarMap Int (D2AcE (Select env stoRepl "accum")) -- ^ Accumulator map for _only_ the the newly allocated -- accumulators. -> (forall shbinds. SList STy shbinds -> (dt : Append shbinds (D2AcE (Select env stoRepl "accum"))) :> Append (D2AcE envPro) (dt : Append shbinds (D2AcE (Select env sto "accum")))) -- ^ A weakening that converts a computation in the -- revised environment to one in the original environment -- extended with some accumulators. -> r) -> r accumPromote _ DTop k = k DTop SNil SETop SETop VarMap.empty (\_ -> WId) accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of -- Accumulators are left as-is SAccum -> accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf -> k (storepl `DPush` (t, vid, SAccum)) envpro prosub (SEYesR accrevsub) (VarMap.sink1 accumMap) (\shbinds -> autoWeak (#pro (d2ace envpro) &. #d (auto1 @dt) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum descr))) (#acc :++: (#pro :++: #d :++: #shb :++: #tl)) (#pro :++: #d :++: #shb :++: #acc :++: #tl) .> WCopy (wf shbinds) .> autoWeak (#d (auto1 @dt) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. 
#tl (d2ace (select SAccum storepl))) (#d :++: #shb :++: #acc :++: #tl) (#acc :++: (#d :++: #shb :++: #tl))) SMerge -> case t of -- Discrete values are left as-is _ | isDiscrete t -> accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap' wf -> k (storepl `DPush` (t, vid, SDiscr)) envpro (SENo prosub) accrevsub accumMap' wf -- Values with "merge" storage are promoted to an accumulator in envPro _ -> accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf -> k (storepl `DPush` (t, vid, SAccum)) (t `SCons` envpro) (SEYesR prosub) (SENo accrevsub) (let accumMap' = VarMap.sink1 accumMap in case fromArrayValId vid of Just i -> VarMap.insert i (STAccum (d2M t)) IZ accumMap' Nothing -> accumMap') (\(shbinds :: SList _ shbinds) -> let shbindsC = slistMap (\_ -> Const ()) shbinds in -- wf: -- D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) -- WCopy wf: -- TAccum n t3 : D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) -- WPICK: ^ THESE TWO || -- goal: | ARE EQUAL || -- D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) WCopy (wf shbinds) .> WPick @(TAccum (D2 t)) @(dt : shbinds) (Const () `SCons` shbindsC) (WId @(D2AcE (Select env1 stoRepl "accum")))) -- Discrete values are left as-is, nothing to do SDiscr -> accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf -> k (storepl `DPush` (t, vid, SDiscr)) envpro prosub accrevsub accumMap wf where isDiscrete :: STy t' -> Bool isDiscrete = \case STNil -> True STPair a b -> isDiscrete a && isDiscrete b STEither a b -> isDiscrete a && isDiscrete b 
STLEither a b -> isDiscrete a && isDiscrete b STMaybe a -> isDiscrete a STArr _ a -> isDiscrete a STScal st -> case st of STI32 -> True STI64 -> True STF32 -> False STF64 -> False STBool -> True STAccum{} -> False ---------------------------- RETURN TRIPLE FROM CHAD --------------------------- data Ret env0 sto sd t = forall shbinds tapebinds contribs. Ret (Bindings Ex (D1E env0) shbinds) -- shared binds (Subenv shbinds tapebinds) (Ex (Append shbinds (D1E env0)) (D1 t)) (SubenvS (D2E (Select env0 sto "merge")) contribs) (Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs)) deriving instance Show (Ret env0 sto sd t) type data TyTyPair = MkTyTyPair Ty Ty data SingleRet env0 sto (pair :: TyTyPair) = forall shbinds tapebinds. SingleRet (Bindings Ex (D1E env0) shbinds) -- shared binds (Subenv shbinds tapebinds) (RetPair env0 sto (D1E env0) shbinds tapebinds pair) -- pattern Ret1 :: forall env0 sto Bindings Ex (D1E env0) shbinds -- -> Subenv shbinds tapebinds -- -> Ex (Append shbinds (D1E env0)) (D1 t) -- -> SubenvS (D2E (Select env0 sto "merge")) contribs -- -> Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs) -- -> SingleRet env0 sto (MkTyTyPair sd t) -- pattern Ret1 e0 subtape e1 sub e2 = SingleRet e0 subtape (RetPair e1 sub e2) -- {-# COMPLETE Ret1 #-} data RetPair env0 sto env shbinds tapebinds (pair :: TyTyPair) where RetPair :: forall sd t contribs -- existentials env0 sto env shbinds tapebinds. -- universals Ex (Append shbinds env) (D1 t) -> SubenvS (D2E (Select env0 sto "merge")) contribs -> Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs) -> RetPair env0 sto env shbinds tapebinds (MkTyTyPair sd t) deriving instance Show (RetPair env0 sto env shbinds tapebinds pair) data Rets env0 sto env list = forall shbinds tapebinds. 
Rets (Bindings Ex env shbinds) (Subenv shbinds tapebinds) (SList (RetPair env0 sto env shbinds tapebinds) list) deriving instance Show (Rets env0 sto env list) toSingleRet :: Ret env0 sto sd t -> SingleRet env0 sto (MkTyTyPair sd t) toSingleRet (Ret e0 subtape e1 sub e2) = SingleRet e0 subtape (RetPair e1 sub e2) weakenRetPair :: SList STy shbinds -> env :> env' -> RetPair env0 sto env shbinds tapebinds pair -> RetPair env0 sto env' shbinds tapebinds pair weakenRetPair bindslist w (RetPair e1 sub e2) = RetPair (weakenExpr (weakenOver bindslist w) e1) sub e2 weakenRets :: env :> env' -> Rets env0 sto env list -> Rets env0 sto env' list weakenRets w (Rets binds tapesub list) = let (binds', _) = weakenBindings weakenExpr w binds in Rets binds' tapesub (slistMap (weakenRetPair (bindingsBinds binds) w) list) rebaseRetPair :: forall env b1 b2 tapebinds1 tapebinds2 env0 sto pair f. Descr env0 sto -> SList f b1 -> SList f b2 -> Subenv b1 tapebinds1 -> Subenv b2 tapebinds2 -> RetPair env0 sto (Append b1 env) b2 tapebinds2 pair -> RetPair env0 sto env (Append b2 b1) (Append tapebinds2 tapebinds1) pair rebaseRetPair descr b1 b2 subtape1 subtape2 (RetPair @sd e1 sub e2) | Refl <- lemAppendAssoc @b2 @b1 @env = RetPair e1 sub (weakenExpr (autoWeak (#d (auto1 @sd) &. #t2 (subList b2 subtape2) &. #t1 (subList b1 subtape1) &. #tl (d2ace (select SAccum descr))) (#d :++: (#t2 :++: #tl)) (#d :++: ((#t2 :++: #t1) :++: #tl))) e2) retConcat :: forall env0 sto list. 
Descr env0 sto -> SList (SingleRet env0 sto) list -> Rets env0 sto (D1E env0) list retConcat _ SNil = Rets BTop SETop SNil retConcat descr (SCons (SingleRet (e0 :: Bindings _ _ shbinds1) (subtape :: Subenv _ tapebinds1) (RetPair e1 sub e2)) list) | Rets (binds :: Bindings _ _ shbinds2) (subtape2 :: Subenv _ tapebinds2) pairs <- weakenRets (sinkWithBindings e0) (retConcat descr list) , Refl <- lemAppendAssoc @shbinds2 @shbinds1 @(D1E env0) , Refl <- lemAppendAssoc @tapebinds2 @tapebinds1 @(D2AcE (Select env0 sto "accum")) = Rets (bconcat e0 binds) (subenvConcat subtape subtape2) (SCons (RetPair (weakenExpr (sinkWithBindings binds) e1) sub (weakenExpr (WCopy (sinkWithSubenv subtape2)) e2)) (slistMap (rebaseRetPair descr (bindingsBinds e0) (bindingsBinds binds) subtape subtape2) pairs)) freezeRet :: Descr env sto -> Ret env sto (D2 t) t -> Ex (D2 t : Append (D2AcE (Select env sto "accum")) (D1E env)) (TPair (D1 t) (Tup (D2E (Select env sto "merge")))) freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ _ t) = let (e0', wInsertD2Ac) = weakenBindings weakenExpr (WSink .> wSinks (d2ace (select SAccum descr))) e0 e2' = weakenExpr (WCopy (wCopies (subList (bindingsBinds e0) subtape) (wRaiseAbove (d2ace (select SAccum descr)) (desD1E descr)))) e2 tContribs = tTup (slistMap fromSMTy (subList (d2eM (select SMerge descr)) sub)) library = #d (auto1 @(D2 t)) &. #tape (subList (bindingsBinds e0) subtape) &. #shbinds (bindingsBinds e0) &. #d2ace (d2ace (select SAccum descr)) &. #tl (desD1E descr) &. 
-- NOTE(review): as presented, the line breaks in this region appear to have
-- been collapsed. Several line comments that presumably stood on their own
-- lines now sit in the middle of a line, so as written they comment out the
-- code that follows them on that line. Examples: the "THE CHAD
-- TRANSFORMATION" banner on the next line, the ENil note, the ECustom /
-- EBuild / EIdx / EShape notes, the RetScoped field docs, and the drevScoped
-- and drevPrimal notes. Restore the original newlines and indentation before
-- compiling this module.
--
-- The next line first finishes a definition whose start lies before this
-- chunk, and then begins the CHAD transformation proper with 'drev'.
--
-- drev des accumMap sd e: CHAD reverse-mode transform of one expression.
--   des       Per-variable description of the environment: type, optional
--             ValId and storage. The storages used here are SMerge, SAccum
--             and SDiscr.
--   accumMap  Maps the Int of a 'VIArr' value id to an accumulator variable
--             in the D2AcE (accum) part of the environment. 'drevScoped'
--             uses it to reuse an existing accumulator for the same value.
--   sd        Sparsity structure of the incoming cotangent of 'e'.
--   result    'Ret e0 subtape e1 sub e2', consisting of:
--               e0       the primal bindings;
--               subtape  which of those bindings are kept as tape;
--               e1       the primal result;
--               sub      which merge variables receive contributions;
--               e2       the reverse pass.
--             As for RetScoped, the reverse pass sees the incoming cotangent
--             as variable IZ, then the tape, then the accumulators.
--
-- Clause structure:
--   * An absent cotangent ('isAbsent sd') short-circuits: primal only (via
--     'drevPrimal'), no contributions, and ENil as the reverse pass.
--   * SpAbsent must already have been caught by 'isAbsent'; reaching the
--     second clause with it is an error.
--   * SpSparse sd: the incoming cotangent may be missing. 'emaybe' yields
--     zero contributions when it is; otherwise it runs the reverse pass
--     obtained for the inner sparsity 'sd'.
--   * Otherwise, dispatch on the expression constructor.
--
-- EVar dispatches on the storage of the variable:
--   Idx2Ac  accumulate into the variable's accumulator;
--   Idx2Me  give a one-hot merge contribution;
--   Idx2Di  do nothing (discrete variable).
-- NOTE(review): the Idx2Ac alternative passes a typed hole ('_ sd') to
-- EAccum, so this clause does not compile until the hole is filled.
#contribs (SCons tContribs SNil) in letBinds e0' $ EPair ext (weakenExpr wInsertD2Ac e1) (ELet ext (weakenExpr (autoWeak library (#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: #d2ace :++: #tl) (#shbinds :++: #d :++: #d2ace :++: #tl)) e2') $ expandSubenvZeros (autoWeak library #tl (#contribs :++: #shbinds :++: #d :++: #d2ace :++: #tl) .> wUndoSubenv (subenvD1E (selectSub SMerge descr))) (select SMerge descr) sub (EVar ext tContribs IZ)) ---------------------------- THE CHAD TRANSFORMATION --------------------------- drev :: forall env sto sd t. (?config :: CHADConfig) => Descr env sto -> VarMap Int (D2AcE (Select env sto "accum")) -> Sparse (D2 t) sd -> Expr ValId env t -> Ret env sto sd t drev des _ sd | isAbsent sd = \e -> Ret BTop SETop (drevPrimal des e) (subenvNone (d2e (select SMerge des))) (ENil ext) drev _ _ SpAbsent = error "Absent should be isAbsent" drev des accumMap (SpSparse sd) = \e -> case drev des accumMap sd e of { Ret e0 subtape e1 sub e2 -> subenvPlus ST ST (d2eM (select SMerge des)) sub (subenvNone (d2e (select SMerge des))) $ \sub' (Inj inj1) (Inj inj2) _ -> Ret e0 subtape e1 sub' (emaybe (evar IZ) (inj2 (ENil ext)) (inj1 (weakenExpr (WCopy WSink) e2))) } drev des accumMap sd = \case EVar _ t i -> case conv2Idx des i of Idx2Ac accI -> Ret BTop SETop (EVar ext (d1 t) (conv1Idx i)) (subenvNone (d2e (select SMerge des))) (let ty = applySparse sd (d2M t) in EAccum ext (d2M t) (_ sd) (ENil ext) (EVar ext (fromSMTy ty) IZ) (EVar ext (STAccum (d2M t)) (IS accI))) Idx2Me tupI -> Ret BTop SETop (EVar ext (d1 t) (conv1Idx i)) (subenvOnehot (d2e (select SMerge des)) tupI sd) (EPair ext (ENil ext) (EVar ext (applySparse sd (d2 t)) IZ)) Idx2Di _ -> Ret BTop SETop (EVar ext (d1 t) (conv1Idx i)) (subenvNone (d2e (select SMerge des))) (ENil ext) ELet _ (rhs :: Expr _ _ a) body | ChosenStorage (storage :: Storage s) <- if chcLetArrayAccum ?config && hasArrays (typeOf rhs) then ChosenStorage SAccum else ChosenStorage SMerge , RetScoped (body0 ::
-- ELet (continued on the next line):
--   * Storage: the bound variable gets SAccum when 'chcLetArrayAccum' is
--     enabled and the rhs type contains arrays; otherwise it gets SMerge.
--   * The body is differentiated under the binder with 'drevScoped'. The rhs
--     is differentiated with the sparsity 'sdBody' that the body reports for
--     the bound variable.
--   * Primal: the rhs bindings, the rhs value, then the body bindings. The
--     tape is the concatenation of both sub-tapes.
--   * Reverse: run body2, feed its cotangent for the variable to rhs2, then
--     add the rhs and body merge contributions with 'plus_RHS_Body'.
Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody sdBody body2 <- drevScoped des accumMap (typeOf rhs) storage (Just (extOf rhs)) sd body , Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) rhs1 subRHS rhs2 <- drev des accumMap sdBody rhs , let (body0', wbody0') = weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0 , Refl <- lemAppendAssoc @body_shbinds @'[D1 a] @rhs_shbinds , Refl <- lemAppendAssoc @body_shbinds @(D1 a : rhs_shbinds) @(D1E env) , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) -> subenvPlus SF SF (d2eM (select SMerge des)) subRHS subBody $ \subBoth _ _ plus_RHS_Body -> let bodyResType = STPair (contribTupTy des subBody) (applySparse sdBody (d2 (typeOf rhs))) in Ret (bconcat (rhs0 `BPush` (d1 (typeOf rhs), rhs1)) body0') (subenvConcat subtapeRHS subtapeBody) (weakenExpr wbody0' body1) subBoth (ELet ext (weakenExpr (autoWeak (#d (auto1 @sd) &. #body (subList (bindingsBinds body0 `sappend` SCons (d1 (typeOf rhs)) SNil) subtapeBody) &. #rhs (subList (bindingsBinds rhs0) subtapeRHS) &.
-- After the end of ELet on the next line:
-- EPair: differentiates both components, with sparsities from
--   'SpPair sd1 sd2', and merges their primal bindings with 'retConcat'.
--   The reverse pass splits the cotangent pair, runs both reverse passes and
--   adds their contributions.
-- EFst / ESnd: differentiate the pair with the other component marked
--   SpAbsent. The reverse pass rebuilds a pair cotangent, with ENil on the
--   absent side.
-- EInl / EInr (on the line after): the reverse pass cases on the LEither
--   cotangent:
--     * the nil case gives zero contributions;
--     * the matching side runs e2;
--     * the opposite side is an EError ("inl<-dinr" / "inr<-dinl").
#tl (d2ace (select SAccum des))) (#d :++: #body :++: #tl) (#d :++: (#body :++: #rhs) :++: #tl)) body2) $ ELet ext (ELet ext (ESnd ext (EVar ext bodyResType IZ)) $ weakenExpr (WCopy (wSinks' @[_,_] .> sinkWithSubenv subtapeBody)) rhs2) $ plus_RHS_Body (EVar ext (contribTupTy des subRHS) IZ) (EFst ext (EVar ext bodyResType (IS IZ)))) EPair _ a b | SpPair sd1 sd2 <- sd , Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil) <- retConcat des $ toSingleRet (drev des accumMap sd1 a) `SCons` toSingleRet (drev des accumMap sd2 b) `SCons` SNil , let dt = STPair (applySparse sd1 (d2 (typeOf a))) (applySparse sd2 (d2 (typeOf b))) -> subenvPlus SF SF (d2eM (select SMerge des)) subA subB $ \subBoth _ _ plus_A_B -> Ret binds subtape (EPair ext a1 b1) subBoth (ELet ext (ELet ext (EFst ext (EVar ext dt IZ)) (weakenExpr (WCopy WSink) a2)) $ ELet ext (ELet ext (ESnd ext (EVar ext dt (IS IZ))) (weakenExpr (WCopy (WSink .> WSink)) b2)) $ plus_A_B (EVar ext (contribTupTy des subA) (IS IZ)) (EVar ext (contribTupTy des subB) IZ)) EFst _ e | Ret e0 subtape e1 sub e2 <- drev des accumMap (SpPair sd SpAbsent) e , STPair t1 _ <- typeOf e -> Ret e0 subtape (EFst ext e1) sub (ELet ext (EPair ext (EVar ext (applySparse sd (d2 t1)) IZ) (ENil ext)) $ weakenExpr (WCopy WSink) e2) ESnd _ e | Ret e0 subtape e1 sub e2 <- drev des accumMap (SpPair SpAbsent sd) e , STPair _ t2 <- typeOf e -> Ret e0 subtape (ESnd ext e1) sub (ELet ext (EPair ext (ENil ext) (EVar ext (applySparse sd (d2 t2)) IZ)) $ weakenExpr (WCopy WSink) e2) -- Don't need to handle ENil, because its cotangent is always absent!
-- ENil _ -> Ret BTop SETop (ENil ext) (subenvNone (d2e (select SMerge des))) (ENil ext) EInl _ t2 e | SpLEither sd1 sd2 <- sd , Ret e0 subtape e1 sub e2 <- drev des accumMap sd1 e -> subenvPlus ST ST (d2eM (select SMerge des)) sub (subenvNone (d2e (select SMerge des))) $ \sub' (Inj inj1) (Inj inj2) _ -> Ret e0 subtape (EInl ext (d1 t2) e1) sub' (ELCase ext (EVar ext (STLEither (applySparse sd1 (d2 (typeOf e))) (applySparse sd2 (d2 t2))) IZ) (inj2 $ ENil ext) (inj1 $ weakenExpr (WCopy WSink) e2) (EError ext (contribTupTy des sub') "inl<-dinr")) EInr _ t1 e | SpLEither sd1 sd2 <- sd , Ret e0 subtape e1 sub e2 <- drev des accumMap sd2 e -> subenvPlus ST ST (d2eM (select SMerge des)) sub (subenvNone (d2e (select SMerge des))) $ \sub' (Inj inj1) (Inj inj2) _ -> Ret e0 subtape (EInr ext (d1 t1) e1) sub' (ELCase ext (EVar ext (STLEither (applySparse sd1 (d2 t1)) (applySparse sd2 (d2 (typeOf e)))) IZ) (inj2 $ ENil ext) (EError ext (contribTupTy des sub') "inr<-dinl") (inj1 $ weakenExpr (WCopy WSink) e2)) ECase _ e (a :: Expr _ _ t) b | STEither (t1 :: STy a) (t2 :: STy b) <- typeOf e , ChosenStorage storage1 <- if chcCaseArrayAccum ?config && hasArrays t1 then ChosenStorage SAccum else ChosenStorage SMerge , ChosenStorage storage2 <- if chcCaseArrayAccum ?config && hasArrays t2 then ChosenStorage SAccum else ChosenStorage SMerge , let (bindids1, bindids2) = validSplitEither (extOf e) , RetScoped (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA sd1 a2 <- drevScoped des accumMap t1 storage1 bindids1 sd a , RetScoped (b0 :: Bindings _ _ rhs_b_binds) (subtapeB :: Subenv _ rhs_b_tape) b1 subB sd2 b2 <- drevScoped des accumMap t2 storage2 bindids2 sd b , Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 subE e2 <- drev des accumMap (SpLEither sd1 sd2) e , Refl <- lemAppendAssoc @(Append rhs_a_binds (Reverse (TapeUnfoldings rhs_a_binds))) @(Tape rhs_a_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds)
@(D2AcE (Select env sto "accum")) , Refl <- lemAppendAssoc @(Append rhs_b_binds (Reverse (TapeUnfoldings rhs_b_binds))) @(Tape rhs_b_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum")) , let subtapeListA = subList (sappend (bindingsBinds a0) (d1 t1 `SCons` SNil)) subtapeA , let subtapeListB = subList (sappend (bindingsBinds b0) (d1 t2 `SCons` SNil)) subtapeB , let tapeA = tapeTy subtapeListA , let tapeB = tapeTy subtapeListB , let collectA = bindingsCollectTape @_ @_ @(Append rhs_a_binds (D1 a : Append e_binds (D1E env))) (sappend (bindingsBinds a0) (d1 t1 `SCons` SNil)) subtapeA , let collectB = bindingsCollectTape @_ @_ @(Append rhs_b_binds (D1 b : Append e_binds (D1E env))) (sappend (bindingsBinds b0) (d1 t2 `SCons` SNil)) subtapeB , (tPrimal :: STy t_primal_ty) <- STPair (d1 (typeOf a)) (STEither tapeA tapeB) , let (a0', wa0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) a0 , let (b0', wb0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) b0 , Refl <- lemAppendNil @(Append rhs_a_binds '[D1 a]) , Refl <- lemAppendNil @(Append rhs_b_binds '[D1 b]) , Refl <- lemAppendAssoc @rhs_a_binds @'[D1 a] @(D1E env) , Refl <- lemAppendAssoc @rhs_b_binds @'[D1 b] @(D1E env) , let wa0'' = wa0' .> wCopies (sappend (bindingsBinds a0) (d1 t1 `SCons` SNil)) (WClosed @(D1E env)) , let wb0'' = wb0' .> wCopies (sappend (bindingsBinds b0) (d1 t2 `SCons` SNil)) (WClosed @(D1E env)) -> subenvPlus ST ST (d2eM (select SMerge des)) subA subB $ \subAB (Inj sAB_A) (Inj sAB_B) _ -> subenvPlus SF SF (d2eM (select SMerge des)) subAB subE $ \subOut _ _ plus_AB_E -> Ret (e0 `BPush` (tPrimal, ECase ext e1 (letBinds a0' (EPair ext (weakenExpr wa0' a1) (EInl ext tapeB (collectA wa0'')))) (letBinds b0' (EPair ext (weakenExpr wb0' b1) (EInr ext tapeA (collectB wb0'')))))) (SEYesR subtapeE) (EFst ext (EVar ext tPrimal IZ)) subOut (elet (ECase ext (ESnd ext (EVar ext tPrimal (IS IZ))) (let (rebinds, prerebinds) =
-- ECase (begun on the previous lines):
--   * Storage: each branch binder gets SAccum when 'chcCaseArrayAccum' is
--     set and its type has arrays; otherwise it gets SMerge.
--   * The scrutinee's ValIds are split between the branches with
--     'validSplitEither'.
--   * Each branch is differentiated with 'drevScoped'. The scrutinee is
--     differentiated with 'SpLEither sd1 sd2', built from the branches'
--     argument sparsities.
--   * Primal: the scrutinee bindings plus one binding that holds the branch
--     result paired with an Either of the taken branch's tape. Only that
--     binding and the scrutinee's sub-tape are kept.
--   * Reverse:
--       1. Case on the saved Either.
--       2. Rebuild the branch bindings from its tape ('reconstructBindings').
--       3. Run that branch's reverse pass.
--       4. Wrap the argument cotangent in ELInl / ELInr.
--       5. Run the scrutinee's reverse pass.
--       6. Add all merge contributions ('sAB_A' / 'sAB_B', then
--          'plus_AB_E').
reconstructBindings subtapeListA IZ in letBinds rebinds $ ELet ext (EVar ext (applySparse sd (d2 (typeOf a))) (wSinks @(Tape rhs_a_tape : sd : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend subtapeListA prerebinds) @> IS IZ)) $ elet (weakenExpr (autoWeak (#d (auto1 @sd) &. #ta0 subtapeListA &. #prea0 prerebinds &. #recon (tapeA `SCons` applySparse sd (d2 (typeOf a)) `SCons` SNil) &. #binds (tPrimal `SCons` subList (bindingsBinds e0) subtapeE) &. #tl (d2ace (select SAccum des))) (#d :++: #ta0 :++: #tl) (#d :++: (#ta0 :++: #prea0) :++: #recon :++: #binds :++: #tl)) a2) $ EPair ext (sAB_A $ EFst ext (evar IZ)) (ELInl ext (applySparse sd2 (d2 t2)) (ESnd ext (evar IZ)))) (let (rebinds, prerebinds) = reconstructBindings subtapeListB IZ in letBinds rebinds $ ELet ext (EVar ext (applySparse sd (d2 (typeOf a))) (wSinks @(Tape rhs_b_tape : sd : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend subtapeListB prerebinds) @> IS IZ)) $ elet (weakenExpr (autoWeak (#d (auto1 @sd) &. #tb0 subtapeListB &. #preb0 prerebinds &. #recon (tapeB `SCons` applySparse sd (d2 (typeOf a)) `SCons` SNil) &. #binds (tPrimal `SCons` subList (bindingsBinds e0) subtapeE) &.
-- After the end of ECase on the next line:
-- EConst: no contributions.
-- EOp: the operand is differentiated with a dense cotangent.
--   * Linear ops apply the derivative to the cotangent after
--     'opt2UnSparse'.
--   * Nonlinear ops additionally save the operand's primal value on the tape
--     and pass it to the derivative.
-- ECustom: only 'b' is differentiated ('a' is inactive).
--   * The primal binds the primal of 'a', the primal of 'b' and the result
--     of 'pr'.
--   * With a dense incoming cotangent ('isDense'), only the tape component
--     of 'pr' (its ESnd) is kept.
--   * Otherwise the whole 'pr' result is kept, because 'expandSparse' needs
--     its primal output to densify the cotangent before 'du' runs.
#tl (d2ace (select SAccum des))) (#d :++: #tb0 :++: #tl) (#d :++: (#tb0 :++: #preb0) :++: #recon :++: #binds :++: #tl)) b2) $ EPair ext (sAB_B $ EFst ext (evar IZ)) (ELInr ext (applySparse sd1 (d2 t1)) (ESnd ext (evar IZ))))) $ plus_AB_E (EFst ext (evar IZ)) (ELet ext (ESnd ext (evar IZ)) $ weakenExpr (WCopy (wSinks' @[_,_,_])) e2)) EConst _ t val -> Ret BTop SETop (EConst ext t val) (subenvNone (d2e (select SMerge des))) (ENil ext) EOp _ op e | Ret e0 subtape e1 sub e2 <- drev des accumMap (spDense (d2M (opt1 op))) e -> case d2op op of Linear d2opfun -> Ret e0 subtape (d1op op e1) sub (ELet ext (d2opfun (opt2UnSparse op sd (EVar ext (applySparse sd (d2 (opt2 op))) IZ))) (weakenExpr (WCopy WSink) e2)) Nonlinear d2opfun -> Ret (e0 `BPush` (d1 (typeOf e), e1)) (SEYesR subtape) (d1op op $ EVar ext (d1 (typeOf e)) IZ) sub (ELet ext (d2opfun (EVar ext (d1 (typeOf e)) (IS IZ)) (opt2UnSparse op sd (EVar ext (applySparse sd (d2 (opt2 op))) IZ))) (weakenExpr (WCopy (wSinks' @[_,_])) e2)) ECustom _ _ tb storety srce pr du a b -- allowed to ignore a2 because 'a' is the part of the input that is inactive | Ret b0 bsubtape b1 bsub b2 <- drev des accumMap (spDense (d2M tb)) b -> case isDense (d2M (typeOf srce)) sd of Just Refl -> Ret (b0 `BPush` (d1 (typeOf a), weakenExpr (sinkWithBindings b0) (drevPrimal des a)) `BPush` (typeOf b1, weakenExpr WSink b1) `BPush` (typeOf pr, weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr)) `BPush` (storety, ESnd ext (EVar ext (typeOf pr) IZ))) (SEYesR (SENo (SENo (SENo bsubtape)))) (EFst ext (EVar ext (typeOf pr) (IS IZ))) bsub (ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $ weakenExpr (WCopy (WSink .> WSink)) b2) Nothing -> Ret (b0 `BPush` (d1 (typeOf a), weakenExpr (sinkWithBindings b0) (drevPrimal des a)) `BPush` (typeOf b1, weakenExpr WSink b1) `BPush` (typeOf pr, weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr))) (SEYesR (SENo (SENo bsubtape))) (EFst ext (EVar ext (typeOf pr) IZ)) bsub (ELet ext
(ESnd ext (EVar ext (typeOf pr) (IS IZ))) $ -- tape ELet ext (expandSparse (typeOf srce) sd -- expanded incoming cotangent (EFst ext (EVar ext (typeOf pr) (IS (IS IZ)))) (EVar ext (applySparse sd (d2 (typeOf srce))) (IS IZ))) $ ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $ weakenExpr (WCopy (WSink .> WSink .> WSink .> WSink)) b2) ERecompute _ e -> deleteUnused (descrList des) (occCountAll e) $ \usedSub -> let smallE = unsafeWeakenWithSubenv usedSub e in subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed -> case drev usedDes (VarMap.subMap subAccumUsed accumMap) sd smallE of { Ret e0 subtape _ sub e2 -> let subMergeUsed' = subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E subMergeUsed) in Ret (collectBindings (desD1E des) subD1eUsed) (subenvAll (desD1E usedDes)) (weakenExpr (wSinks (desD1E usedDes)) $ drevPrimal des e) (subenvCompose subMergeUsed' sub) (letBinds (fst (weakenBindings weakenExpr (WSink .> wRaiseAbove (desD1E usedDes) (d2ace (select SAccum des))) e0)) $ weakenExpr (autoWeak (#d (auto1 @sd) &. #shbinds (bindingsBinds e0) &. #tape (subList (bindingsBinds e0) subtape) &. #d1env (desD1E usedDes) &. #tl' (d2ace (select SAccum usedDes)) &.
-- ERecompute (finished on the next line):
--   * The environment is restricted to the variables 'e' uses, and 'e' is
--     differentiated in that restricted environment.
--   * Only the used primal variables are saved, all of them as tape. The
--     primal result is recomputed with 'drevPrimal'.
--   * The reverse pass re-runs the restricted primal bindings before the
--     reverse pass proper.
-- EError / EConstArr (next line): no contributions.
#tl (d2ace (select SAccum des))) (#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: LPreW #tl' #tl (wUndoSubenv subAccumUsed)) (#shbinds :++: #d :++: #d1env :++: #tl)) e2) } EError _ t s -> Ret BTop SETop (EError ext (d1 t) s) (subenvNone (d2e (select SMerge des))) (ENil ext) EConstArr _ n t val -> Ret BTop SETop (EConstArr ext n t val) (subenvNone (d2e (select SMerge des))) (ENil ext) EBuild _ (ndim :: SNat ndim) she (orige :: Expr _ _ eltty) | SpArr @_ @sdElt sdElt <- sd , let eltty = typeOf orige , shty :: STy shty <- tTup (sreplicate ndim tIx) , Refl <- indexTupD1Id ndim -> deleteUnused (descrList des) (occEnvPop (occCountAll orige)) $ \(usedSub :: Subenv env env') -> let e = unsafeWeakenWithSubenv (SEYesR usedSub) orige in subDescr des usedSub $ \(usedDes :: Descr env' _) subMergeUsed subAccumUsed subD1eUsed -> accumPromote sdElt usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro -> let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in case drev (prodes `DPush` (shty, Nothing, SDiscr)) accumMapPro sdElt e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 sub e2 -> case assertSubenvEmpty sub of { Refl -> case lemAppendNil @e_binds of { Refl -> let tapety = tapeTy (subList (bindingsBinds e0) subtapeE) in let collectexpr = bindingsCollectTape (bindingsBinds e0) subtapeE in Ret (BTop `BPush` (shty, drevPrimal des she) `BPush` (STArr ndim (STPair (d1 eltty) tapety) ,EBuild ext ndim (EVar ext shty IZ) (letBinds (fst (weakenBindings weakenExpr (autoWeak (#ix (shty `SCons` SNil) &. #sh (shty `SCons` SNil) &. #d1env (desD1E des) &. #d1env' (desD1E usedDes)) (#ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) (#ix :++: #sh :++: #d1env)) e0)) $ let w = autoWeak (#ix (shty `SCons` SNil) &. #sh (shty `SCons` SNil) &. #e0 (bindingsBinds e0) &. #d1env (desD1E des) &.
-- EBuild (begun on the previous line) requires an array sparsity
-- ('SpArr sdElt').
--   * The environment is restricted to used variables. 'accumPromote'
--     (presumably) turns merge variables into accumulators for the element
--     body.
--   * The element is differentiated with the index as an extra SDiscr
--     variable, and must produce no direct merge contributions
--     ('assertSubenvEmpty').
--   * Primal: the shape, an array of (element, element tape) pairs, and the
--     array of tapes. The shape and the tape array are kept.
--   * Reverse: set up accumulators ('makeAccumulators'). Then, for each
--     index, read the element cotangent and tape, rebuild the element
--     bindings and run the element's reverse pass.
-- NOTE(review): the argument to 'makeAccumulators' on the next line is a
-- typed hole ('_ (subenvCompose ...)'); see the TODO there.
#d1env' (desD1E usedDes)) (#e0 :++: #ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) (#e0 :++: #ix :++: #sh :++: #d1env) w' = w .> wCopies (bindingsBinds e0) (WClosed @(shty : D1E env')) in EPair ext (weakenExpr w e1) (collectexpr w'))) `BPush` (STArr ndim tapety, emap (ESnd ext (EVar ext (STPair (d1 eltty) tapety) IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) IZ))) (SEYesR (SENo (SEYesR SETop))) (emap (EFst ext (EVar ext (STPair (d1 eltty) tapety) IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) (IS IZ))) (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E (subenvCompose subMergeUsed proSub))) (let sinkOverEnvPro = wSinks @(sd : TArr ndim (Tape e_tape) : Tup (Replicate ndim TIx) : D2AcE (Select env sto "accum")) (d2ace envPro) in ESnd ext $ uninvertTup (d2e envPro) (STArr ndim STNil) $ -- TODO: what's happening here is that because of the sparsity -- rewrite, makeAccumulators needs primals where it previously -- didn't. The build derivative is currently not saving those -- primals, so the hole below cannot currently be filled. The -- appropriate primals (waves hands) need to be stored, so that a -- weakening can be provided here. makeAccumulators @_ @_ @(TArr ndim TNil) (_ (subenvCompose subMergeUsed proSub)) envPro $ EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS IZ))) $ -- the cotangent for this element ELet ext (EIdx ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (WSink .> sinkOverEnvPro @> IZ)) (EVar ext shty IZ)) $ -- the tape for this element ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS IZ)) (EVar ext shty (IS IZ))) $ let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds e0) subtapeE) IZ in letBinds rebinds $ weakenExpr (autoWeak (#d (auto1 @sdElt) &. #pro (d2ace envPro) &. #etape (subList (bindingsBinds e0) subtapeE) &. #prerebinds prerebinds &. #tape (auto1 @(Tape e_tape)) &. #ix (auto1 @shty) &. #darr (auto1 @(TArr ndim sdElt)) &.
-- After the end of EBuild on the next line:
-- EUnit: requires 'SpArr sdElt'. The reverse pass reads the single element
--   of the cotangent array (EIdx0).
-- EReplicate1Inner: the reverse pass sums the cotangent over the replicated
--   inner dimension (EFold1Inner with EPlus). A missing cotangent gives
--   zero contributions.
-- EIdx0: the reverse pass wraps the cotangent as EJust of a rank-0 array.
-- EIdx1: unsupported (use EIdx). The old implementation is kept commented
--   out.
-- EIdx (on the line after): the primal binds the array, its shape and the
--   index; only the shape and the index are kept. The reverse pass builds a
--   one-hot cotangent (EOneHot) at the saved index.
-- NOTE(review): EReplicate1Inner and EIdx0 here, and EIdx on the line after,
-- look as if they have not yet been ported to the sparse cotangent API:
--   * they call 'drev des accumMap e' without the Sparse argument that
--     'drev' now takes;
--   * they pass the 'drev' results to 'retConcat' without 'toSingleRet'
--     (unlike EPair);
--   * they use 'd2' where other cases use 'applySparse sd (d2 ...)';
--   * EReplicate1Inner uses 'select SMerge des' for 'zeroTup' where
--     'contribTupTy' uses 'd2eM (select SMerge des)'.
#tapearr (auto1 @(TArr ndim (Tape e_tape))) &. #sh (auto1 @shty) &. #d2acUsed (d2ace (select SAccum usedDes)) &. #d2acEnv (d2ace (select SAccum des))) (#pro :++: #d :++: #etape :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #sh :++: #d2acEnv) .> wPro (subList (bindingsBinds e0) subtapeE)) e2) }}} EUnit _ e | SpArr sdElt <- sd , Ret e0 subtape e1 sub e2 <- drev des accumMap sdElt e -> Ret e0 subtape (EUnit ext e1) sub (ELet ext (EIdx0 ext (EVar ext (STArr SZ (applySparse sdElt (d2 (typeOf e)))) IZ)) $ weakenExpr (WCopy WSink) e2) EReplicate1Inner _ en e -- We're allowed to ignore en2 here because the output of 'en' is discrete. | Rets binds subtape (RetPair en1 _ _ `SCons` RetPair e1 sub e2 `SCons` SNil) <- retConcat des $ drev des accumMap en `SCons` drev des accumMap e `SCons` SNil , let STArr ndim eltty = typeOf e -> Ret binds subtape (EReplicate1Inner ext en1 e1) sub (EMaybe ext (zeroTup (subList (select SMerge des) sub)) (ELet ext (EJust ext (EFold1Inner ext Commut (EPlus ext (d2M eltty) (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ)) (ezeroD2 eltty) (EVar ext (STArr (SS ndim) (d2 eltty)) IZ))) $ weakenExpr (WCopy (WSink .> WSink)) e2) (EVar ext (d2 (STArr (SS ndim) eltty)) IZ)) EIdx0 _ e | Ret e0 subtape e1 sub e2 <- drev des accumMap e , STArr _ t <- typeOf e -> Ret e0 subtape (EIdx0 ext e1) sub (ELet ext (EJust ext (EUnit ext (EVar ext (d2 t) IZ))) $ weakenExpr (WCopy WSink) e2) EIdx1{} -> error "CHAD of EIdx1: Please use EIdx instead" {- EIdx1 _ e ei -- We're allowed to ignore ei2 here because the output of 'ei' is discrete.
| Rets binds subtape (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil) <- retConcat des $ drev des accumMap e `SCons` drev des accumMap ei `SCons` SNil , STArr (SS n) eltty <- typeOf e -> Ret (binds `BPush` (STArr (SS n) (d1 eltty), e1) `BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) (d1 eltty)) IZ))) (SEYesR (SENo subtape)) (EIdx1 ext (EVar ext (STArr (SS n) (d1 eltty)) (IS IZ)) (weakenExpr (WSink .> WSink) ei1)) sub (ELet ext (ebuildUp1 n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ))) (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ))) (EVar ext (STArr n (d2 eltty)) (IS IZ))) $ weakenExpr (WCopy (WSink .> WSink)) e2) -} EIdx _ e ei -- We're allowed to ignore ei2 here because the output of 'ei' is discrete. | Rets binds subtape (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil) <- retConcat des $ drev des accumMap e `SCons` drev des accumMap ei `SCons` SNil , STArr n eltty <- typeOf e , Refl <- indexTupD1Id n , Refl <- lemZeroInfoD2 eltty , let tIxN = tTup (sreplicate n tIx) -> Ret (binds `BPush` (STArr n (d1 eltty), e1) `BPush` (tIxN, EShape ext (EVar ext (typeOf e1) IZ)) `BPush` (tIxN, weakenExpr (WSink .> WSink) ei1)) (SEYesR (SEYesR (SENo subtape))) (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ))) (EVar ext (tTup (sreplicate n tIx)) IZ)) sub (ELet ext (EOneHot ext (d2M (STArr n eltty)) (SAPJust (SAPArrIdx SAPHere)) (EPair ext (EPair ext (EVar ext tIxN (IS IZ)) (EBuild ext n (EVar ext tIxN (IS (IS IZ))) (ENil ext))) (ENil ext)) (EVar ext (d2 eltty) IZ)) $ weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) EShape _ e -- Allowed to ignore e2 here because the output of EShape is discrete, -- hence we'd be passing a zero cotangent to e2 anyway.
-- EShape (guard on the next line): no contributions.
-- ESum1Inner: the primal binds the input array and its shape; only the
--   shape is kept. The reverse pass replicates the cotangent along the
--   summed inner dimension.
-- EMaximum1Inner / EMinimum1Inner: see 'deriv_extremum'.
-- The remaining constructors listed afterwards raise an error.
-- NOTE(review):
--   * EShape builds 'subenvNone (select SMerge des)', whereas every other
--     no-contribution case uses 'subenvNone (d2e (select SMerge des))'.
--   * EShape, ESum1Inner and 'deriv_extremum' call 'drev' without the
--     Sparse argument.
--   * EAccum is used above (EVar / Idx2Ac) as an expression constructor, but
--     it has no alternative in this '\case'. If it can occur in a source
--     program, it presumably should report 'err_accum' like EWith. Confirm:
--     as written, it is an incomplete pattern.
| Ret e0 subtape e1 _ _ <- drev des accumMap e , STArr n _ <- typeOf e , Refl <- indexTupD1Id n -> Ret e0 subtape (EShape ext e1) (subenvNone (select SMerge des)) (ENil ext) ESum1Inner _ e | Ret e0 subtape e1 sub e2 <- drev des accumMap e , STArr (SS n) t <- typeOf e -> Ret (e0 `BPush` (STArr (SS n) t, e1) `BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) t) IZ))) (SEYesR (SENo subtape)) (ESum1Inner ext (EVar ext (STArr (SS n) t) (IS IZ))) sub (EMaybe ext (zeroTup (subList (select SMerge des) sub)) (ELet ext (EJust ext (EReplicate1Inner ext (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS (IS IZ)))) (EVar ext (STArr n (d2 t)) IZ))) $ weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) (EVar ext (d2 (STArr n t)) IZ)) EMaximum1Inner _ e -> deriv_extremum (EMaximum1Inner ext) e EMinimum1Inner _ e -> deriv_extremum (EMinimum1Inner ext) e -- These should be the next to be implemented, I think EFold1Inner{} -> err_unsupported "EFold1Inner" ENothing{} -> err_unsupported "ENothing" EJust{} -> err_unsupported "EJust" EMaybe{} -> err_unsupported "EMaybe" ELNil{} -> err_unsupported "ELNil" ELInl{} -> err_unsupported "ELInl" ELInr{} -> err_unsupported "ELInr" ELCase{} -> err_unsupported "ELCase" EWith{} -> err_accum EZero{} -> err_monoid EPlus{} -> err_monoid EOneHot{} -> err_monoid where err_accum = error "Accumulator operations unsupported in the source program" err_monoid = error "Monoid operations unsupported in the source program" err_unsupported s = error $ "CHAD: unsupported " ++ s deriv_extremum :: ScalIsNumeric t' ~ True => (forall env'.
-- deriv_extremum (signature continues on the next line): shared rule for
-- EMaximum1Inner / EMinimum1Inner.
--   * The primal saves both the input array and the reduced result.
--   * The reverse pass builds an array of the input's shape. Each input
--     position whose value equals (OEq) the output element it was reduced
--     into receives that element's cotangent; every other position gets
--     zero. When several positions tie, each of them receives the full
--     cotangent.
-- NOTE(review): the signature takes a Sparse argument, but the definition
-- ('deriv_extremum extremum e') and both call sites pass only two arguments.
--
-- contribTupTy: the tuple type of the merge contributions selected by 'sub'.
-- ChosenStorage: a Storage that is statically known not to be "discr". ELet
--   and ECase use it to choose between SAccum and SMerge.
-- RetScoped: the result of 'drevScoped'. It is like 'Ret' for an expression
--   under one extra binder of type 'a'; the additional Sparse field gives
--   the sparsity of the cotangent returned for that binder.
Ex env' (TArr (S n) (TScal t')) -> Ex env' (TArr n (TScal t'))) -> Sparse (TArr n (D2s t')) sd' -> Expr ValId env (TArr (S n) (TScal t')) -> Ret env sto sd' (TArr n (TScal t')) deriv_extremum extremum e | Ret e0 subtape e1 sub e2 <- drev des accumMap e , at@(STArr (SS n) t@(STScal st)) <- typeOf e , let at' = STArr n t , let tIxN = tTup (sreplicate (SS n) tIx) = Ret (e0 `BPush` (at, e1) `BPush` (at', extremum (EVar ext at IZ))) (SEYesR (SEYesR subtape)) (EVar ext at' IZ) sub (EMaybe ext (zeroTup (subList (select SMerge des) sub)) (ELet ext (EJust ext (EBuild ext (SS n) (EShape ext (EVar ext at (IS (IS (IS IZ))))) $ eif (EOp ext (OEq st) (EPair ext (EIdx ext (EVar ext at (IS (IS (IS (IS IZ))))) (EVar ext tIxN IZ)) (EIdx ext (EVar ext at' (IS (IS (IS IZ)))) (EFst ext (EVar ext tIxN IZ))))) (EIdx ext (EVar ext (STArr n (d2 t)) (IS IZ)) (EFst ext (EVar ext tIxN IZ))) (ezeroD2 t))) $ weakenExpr (WCopy (WSink .> WSink .> WSink .> WSink)) e2) (EVar ext (d2 at') IZ)) contribTupTy :: Descr env sto -> SubenvS (D2E (Select env sto "merge")) contribs -> STy (Tup contribs) contribTupTy des' sub = tTup (slistMap fromSMTy (subList (d2eM (select SMerge des')) sub)) data ChosenStorage = forall s. ((s == "discr") ~ False) => ChosenStorage (Storage s) data RetScoped env0 sto a s sd t = forall shbinds tapebinds contribs sa. RetScoped (Bindings Ex (D1E (a : env0)) shbinds) -- shared binds (Subenv (Append shbinds '[D1 a]) tapebinds) (Ex (Append shbinds (D1E (a : env0))) (D1 t)) (SubenvS (D2E (Select env0 sto "merge")) contribs) -- ^ merge contributions to the _enclosing_ merge environment (Sparse (D2 a) sa) -- ^ contribution to the argument (Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (If (s == "discr") (Tup contribs) (TPair (Tup contribs) sa))) -- ^ the merge contributions, plus the cotangent to the argument -- (if there is any) deriving instance Show (RetScoped env0 sto a s sd t) drevScoped :: forall a s env sto sd t.
-- drevScoped: differentiates an expression under one extra binder of type
-- 'a' (a let body or a case branch). What happens depends on the binder's
-- storage:
--
--   SMerge  The binder is an ordinary merge variable. Its contribution is
--           split off from the enclosing contributions and returned together
--           with its sparsity. If it received none, the sparsity is SpAbsent
--           and the second component is ENil.
--
--   SAccum  If the binder's ValId is a 'VIArr' id already present in
--           accumMap with a matching accumulator type, that accumulator is
--           reused and the binder's own cotangent is SpAbsent.
--           Otherwise:
--             * a fresh accumulator is created with EWith, initialised with
--               an EZero shaped by the binder's primal (which is therefore
--               kept on the tape);
--             * the binder's sparsity is SpDense;
--             * the accumulator is registered in accumMap when the ValId is
--               a 'VIArr'.
--           NOTE(review): presumably EWith yields the body's result paired
--           with the accumulated value; confirm against EWith's definition.
--
--   SDiscr  Discrete binder: no cotangent for it (SpAbsent).
(?config :: CHADConfig) => Descr env sto -> VarMap Int (D2AcE (Select env sto "accum")) -> STy a -> Storage s -> Maybe (ValId a) -> Sparse (D2 t) sd -> Expr ValId (a : env) t -> RetScoped env sto a s sd t drevScoped des accumMap argty argsto argids sd expr = case argsto of SMerge | Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap sd expr , Refl <- lemAppendNil @tapebinds -> case sub of SEYes sp sub' -> RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub' sp e2 SENo sub' -> RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub' SpAbsent (EPair ext e2 (ENil ext)) SAccum | Just (VIArr i _) <- argids , Just (Some (VarMap.TypedIdx foundTy idx)) <- VarMap.lookup i accumMap , Just Refl <- testEquality foundTy (STAccum (d2M argty)) , Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) (VarMap.sink1 accumMap) sd expr , Refl <- lemAppendNil @tapebinds -> -- Our contribution to the binding's cotangent _here_ is zero (absent), -- because we're contributing to an earlier binding of the same value -- instead. RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub SpAbsent $ let wtapebinds = wSinks (subList (bindingsBinds e0) subtape) in ELet ext (EVar ext (STAccum (d2M argty)) (WSink .> wtapebinds @> idx)) $ weakenExpr (autoWeak (#d (auto1 @sd) &. #body (subList (bindingsBinds e0) subtape) &. #ac (auto1 @(TAccum (D2 a))) &. #tl (d2ace (select SAccum des))) (#d :++: #body :++: #ac :++: #tl) (#ac :++: #d :++: #body :++: #tl)) (EPair ext e2 (ENil ext)) | let accumMap' = case argids of Just (VIArr i _) -> VarMap.insert i (STAccum (d2M argty)) IZ (VarMap.sink1 accumMap) _ -> VarMap.sink1 accumMap , Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap' sd expr -> let library = #d (auto1 @sd) &. #p (auto1 @(D1 a)) &. #body (subList (bindingsBinds e0) subtape) &. #ac (auto1 @(TAccum (D2 a))) &.
-- drevPrimal (at the end of the next line): the primal-only transform.
--   * It relies on D1 being the identity on source types, as witnessed by
--     chadD1Id and chadD1EId.
--   * It only re-annotates the expression with 'ext'.
--   * Accumulator types are rejected with an error.
#tl (d2ace (select SAccum des)) in RetScoped e0 (subenvConcat (SEYesR @_ @_ @(D1 a) SETop) subtape) e1 sub SpDense $ let primalIdx = autoWeak library #p (#d :++: (#body :++: #p) :++: #tl) @> IZ in EWith ext (d2M argty) (EZero ext (d2M argty) (d2zeroInfo argty (EVar ext (d1 argty) primalIdx))) $ weakenExpr (autoWeak library (#d :++: #body :++: #ac :++: #tl) (#ac :++: #d :++: (#body :++: #p) :++: #tl)) e2 SDiscr | Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap sd expr , Refl <- lemAppendNil @tapebinds -> RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub SpAbsent e2 -- TODO: proper primal-only transform that doesn't depend on D1 = Id drevPrimal :: Descr env sto -> Expr x env t -> Ex (D1E env) (D1 t) drevPrimal des e | Refl <- chadD1Id (typeOf e) , Refl <- chadD1EId (descrList des) = mapExt (const ext) e where chadD1Id :: STy a -> D1 a :~: a chadD1Id STNil = Refl chadD1Id (STPair a b) | Refl <- chadD1Id a, Refl <- chadD1Id b = Refl chadD1Id (STEither a b) | Refl <- chadD1Id a, Refl <- chadD1Id b = Refl chadD1Id (STLEither a b) | Refl <- chadD1Id a, Refl <- chadD1Id b = Refl chadD1Id (STMaybe a) | Refl <- chadD1Id a = Refl chadD1Id (STArr _ a) | Refl <- chadD1Id a = Refl chadD1Id (STScal _) = Refl chadD1Id STAccum{} = error "accumulators not allowed in source program" chadD1EId :: SList STy l -> D1E l :~: l chadD1EId SNil = Refl chadD1EId (SCons t l) | Refl <- chadD1Id t, Refl <- chadD1EId l = Refl