diff options
Diffstat (limited to 'src/CHAD/Drev.hs')
| -rw-r--r-- | src/CHAD/Drev.hs | 1583 |
1 files changed, 1583 insertions, 0 deletions
diff --git a/src/CHAD/Drev.hs b/src/CHAD/Drev.hs new file mode 100644 index 0000000..595d3c7 --- /dev/null +++ b/src/CHAD/Drev.hs @@ -0,0 +1,1583 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE EmptyCase #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE ImplicitParams #-} +{-# LANGUAGE ImpredicativeTypes #-} +{-# LANGUAGE OverloadedLabels #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeData #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} + +-- I want to bring various type variables in scope using type annotations in +-- patterns, but I don't want to have to mention all the other type parameters +-- of the types in question as well then. Partial type signatures (with '_') are +-- useful here. +{-# LANGUAGE PartialTypeSignatures #-} +{-# OPTIONS -Wno-partial-type-signatures #-} +module CHAD.Drev ( + drev, + freezeRet, + CHADConfig(..), + defaultConfig, + Storage(..), + Descr(..), + Select, +) where + +import Data.Functor.Const +import Data.Some +import Data.Type.Equality (type (==), testEquality) + +import CHAD.Analysis.Identity (ValId(..), validSplitEither) +import CHAD.AST +import CHAD.AST.Bindings +import CHAD.AST.Count +import CHAD.AST.Env +import CHAD.AST.Sparse +import CHAD.AST.Weaken.Auto +import CHAD.Data +import qualified CHAD.Data.VarMap as VarMap +import CHAD.Data.VarMap (VarMap) +import CHAD.Drev.Accum +import CHAD.Drev.EnvDescr +import CHAD.Drev.Types +import CHAD.Lemmas + + +------------------------------ TAPES AND BINDINGS ------------------------------ + +type family Tape binds where + Tape '[] = TNil + Tape (t : ts) = TPair t (Tape ts) + +tapeTy :: SList STy binds -> STy (Tape binds) +tapeTy SNil = STNil +tapeTy (SCons t ts) = STPair t (tapeTy ts) + +bindingsCollectTape :: SList STy binds -> Subenv binds tapebinds + -> binds :> env2 -> Ex env2 (Tape tapebinds) +bindingsCollectTape SNil SETop _ = ENil ext +bindingsCollectTape (t `SCons` binds) (SEYesR sub) w = + EPair ext (EVar ext t (w @> IZ)) + (bindingsCollectTape binds sub (w .> WSink)) +bindingsCollectTape (_ `SCons` binds) (SENo sub) w = + bindingsCollectTape binds sub (w .> WSink) + +-- bindingsCollectTape' :: forall f env binds tapebinds env2. Bindings f env binds -> Subenv binds tapebinds +-- -> Append binds env :> env2 -> Ex env2 (Tape tapebinds) +-- bindingsCollectTape' binds sub w +-- | Refl <- lemAppendNil @binds +-- = bindingsCollectTape (bindingsBinds binds) sub (w .> wCopies @_ @_ @'[] (bindingsBinds binds) (WClosed @env)) + +-- In order from large to small: i.e. in reverse order from what we want, +-- because in a Bindings, the head of the list is the bottom-most entry. +type family TapeUnfoldings binds where + TapeUnfoldings '[] = '[] + TapeUnfoldings (t : ts) = Tape ts : TapeUnfoldings ts + +type family Reverse l where + Reverse '[] = '[] + Reverse (t : ts) = Append (Reverse ts) '[t] + +-- An expression that is always 'snd' +data UnfExpr env t where + UnfExSnd :: STy s -> STy t -> UnfExpr (TPair s t : env) t + +fromUnfExpr :: UnfExpr env t -> Ex env t +fromUnfExpr (UnfExSnd s t) = ESnd ext (EVar ext (STPair s t) IZ) + +-- - A bunch of 'snd' expressions taking us from knowing that there's a +-- 'Tape ts' in the environment (for simplicity assume it's at IZ, we'll fix +-- this in reconstructBindings), to having 'Reverse (TapeUnfoldings ts)' in +-- the environment. +-- - In the extended environment, another bunch of let bindings (these are +-- 'fst' expressions, but no need to know that statically) that project the +-- fsts out of what we introduced above, one for each type in 'ts'. +data Reconstructor env ts = + Reconstructor + (Bindings UnfExpr (Tape ts : env) (Reverse (TapeUnfoldings ts))) + (Bindings Ex (Append (Reverse (TapeUnfoldings ts)) (Tape ts : env)) ts) + +ssnoc :: SList f ts -> f t -> SList f (Append ts '[t]) +ssnoc SNil a = SCons a SNil +ssnoc (SCons t ts) a = SCons t (ssnoc ts a) + +sreverse :: SList f ts -> SList f (Reverse ts) +sreverse SNil = SNil +sreverse (SCons t ts) = ssnoc (sreverse ts) t + +stapeUnfoldings :: SList STy ts -> SList STy (TapeUnfoldings ts) +stapeUnfoldings SNil = SNil +stapeUnfoldings (SCons _ ts) = SCons (tapeTy ts) (stapeUnfoldings ts) + +-- Puts a 'snd' at the top of an unfolder stack and grows the context variable by one. +shiftUnfolder + :: STy t + -> SList STy ts + -> Bindings UnfExpr (Tape ts : env) list + -> Bindings UnfExpr (Tape (t : ts) : env) (Append list '[Tape ts]) +shiftUnfolder newTy ts BTop = BPush BTop (tapeTy ts, UnfExSnd newTy (tapeTy ts)) +shiftUnfolder newTy ts (BPush b (t, UnfExSnd itemTy _)) = + -- Recurse on 'b', and retype the 'snd'. We need to unfold 'b' once in order + -- to expand an 'Append' in the types so that things simplify just enough. + -- We have an equality 'Append binds x1 ~ a : x2', where 'binds' is the list + -- of bindings produced by 'b'. We want to conclude from this that + -- 'binds ~ a : x3' for some 'x3', but GHC will only do that once we know + -- that 'binds ~ y : ys' so that the 'Append' can expand one step, after + -- which 'y ~ a' as desired. The 'case' unfolds 'b' one step. + BPush (shiftUnfolder newTy ts b) (t, case b of BTop -> UnfExSnd itemTy t + BPush{} -> UnfExSnd itemTy t) + +growRecon :: forall env t ts. STy t -> SList STy ts -> Reconstructor env ts -> Reconstructor env (t : ts) +growRecon t ts (Reconstructor unfbs bs) + | Refl <- lemAppendNil @(Append (Reverse (TapeUnfoldings ts)) '[Tape ts]) + , Refl <- lemAppendAssoc @ts @(Append (Reverse (TapeUnfoldings ts)) '[Tape ts]) @(Tape (t : ts) : env) + , Refl <- lemAppendAssoc @(Reverse (TapeUnfoldings ts)) @'[Tape ts] @env + = Reconstructor + (shiftUnfolder t ts unfbs) + -- Add a 'fst' at the bottom of the builder stack. + -- First we have to weaken most of 'bs' to skip one more binding in the + -- unfolder stack above it. + (BPush (fst (weakenBindingsE + (wCopies (sappend (sreverse (stapeUnfoldings ts)) (SCons (tapeTy ts) SNil)) + (WSink :: env :> (Tape (t : ts) : env))) bs)) + (t + ,EFst ext $ EVar ext (tapeTy (SCons t ts)) $ + wSinks @(Tape (t : ts) : env) + (sappend ts + (sappend (sappend (sreverse (stapeUnfoldings ts)) + (SCons (tapeTy ts) SNil)) + SNil)) + @> IZ)) + +buildReconstructor :: SList STy ts -> Reconstructor env ts +buildReconstructor SNil = Reconstructor BTop BTop +buildReconstructor (SCons t ts) = growRecon t ts (buildReconstructor ts) + +-- STRATEGY FOR reconstructBindings +-- +-- binds = [] +-- e : () +-- +-- binds = [c] +-- e : (c, ()) +-- x0 = snd x1 : () +-- y1 = fst e : c +-- +-- binds = [b, c] +-- e : (b, (c, ())) +-- x1 = snd e : (c, ()) +-- x0 = snd x1 : () +-- y1 = fst x1 : c +-- y2 = fst x2 : b +-- +-- binds = [a, b, c] +-- e : (a, (b, (c, ()))) +-- x2 = snd e : (b, (c, ())) +-- x1 = snd x2 : (c, ()) +-- x0 = snd x1 : () +-- y1 = fst x1 : c +-- y2 = fst x2 : b +-- y3 = fst x3 : a + +-- Given that in 'env' we can find a 'Tape binds', i.e. a tuple containing all +-- the things in the list 'binds', we want to create a let stack that extracts +-- all values from that tuple and in effect "restores" the environment +-- described by 'binds'. The idea is that elsewhere, we took a slice of the +-- environment and saved it all in a tuple to be restored later. We +-- incidentally also add a bunch of additional bindings, namely 'Reverse +-- (TapeUnfoldings binds)', so the calling code just has to skip those in +-- whatever it wants to do. +reconstructBindings :: SList STy binds + -> (forall env. Idx env (Tape binds) -> Bindings Ex env (Append binds (Reverse (TapeUnfoldings binds))) + ,SList STy (Reverse (TapeUnfoldings binds))) +reconstructBindings binds = + (\tape -> let Reconstructor unf build = buildReconstructor binds + in fst $ weakenBindingsE (WIdx tape) + (bconcat (mapBindings fromUnfExpr unf) build) + ,sreverse (stapeUnfoldings binds)) + + +---------------------------------- DERIVATIVES --------------------------------- + +d1op :: SOp a t -> Ex env (D1 a) -> Ex env (D1 t) +d1op (OAdd t) e = EOp ext (OAdd t) e +d1op (OMul t) e = EOp ext (OMul t) e +d1op (ONeg t) e = EOp ext (ONeg t) e +d1op (OLt t) e = EOp ext (OLt t) e +d1op (OLe t) e = EOp ext (OLe t) e +d1op (OEq t) e = EOp ext (OEq t) e +d1op ONot e = EOp ext ONot e +d1op OAnd e = EOp ext OAnd e +d1op OOr e = EOp ext OOr e +d1op OIf e = EOp ext OIf e +d1op ORound64 e = EOp ext ORound64 e +d1op OToFl64 e = EOp ext OToFl64 e +d1op (ORecip t) e = EOp ext (ORecip t) e +d1op (OExp t) e = EOp ext (OExp t) e +d1op (OLog t) e = EOp ext (OLog t) e +d1op (OIDiv t) e = EOp ext (OIDiv t) e +d1op (OMod t) e = EOp ext (OMod t) e + +-- | Both primal and dual must be duplicable expressions +data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a)) + | Nonlinear (forall env. Ex env (D1 a) -> Ex env (D2 t) -> Ex env (D2 a)) + +d2op :: SOp a t -> D2Op a t +d2op op = case op of + OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EPair ext d d + OMul t -> d2opBinArrangeInt t $ Nonlinear $ \e d -> + EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d)) + (EOp ext (OMul t) (EPair ext (EFst ext e) d)) + ONeg t -> d2opUnArrangeInt t $ Linear $ \d -> EOp ext (ONeg t) d + OLt t -> Linear $ \_ -> pairZero t + OLe t -> Linear $ \_ -> pairZero t + OEq t -> Linear $ \_ -> pairZero t + ONot -> Linear $ \_ -> ENil ext + OAnd -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) + OOr -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) + OIf -> Linear $ \_ -> ENil ext + ORound64 -> Linear $ \_ -> EZero ext (SMTScal STF64) (ENil ext) + OToFl64 -> Linear $ \_ -> ENil ext + ORecip t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ONeg t) (EOp ext (ORecip t) (EOp ext (OMul t) (EPair ext e e)))) d) + OExp t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (OExp t) e) d) + OLog t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ORecip t) e) d) + OIDiv t -> integralD2 t $ Linear $ \_ -> EPair ext (ENil ext) (ENil ext) + OMod t -> integralD2 t $ Linear $ \_ -> EPair ext (ENil ext) (ENil ext) + where + pairZero :: SScalTy a -> Ex env (D2 (TPair (TScal a) (TScal a))) + pairZero t = ziNil t $ EPair ext (EZero ext (d2M (STScal t)) (ENil ext)) + (EZero ext (d2M (STScal t)) (ENil ext)) + where + ziNil :: SScalTy a -> (ZeroInfo (D2s a) ~ TNil => r) -> r + ziNil STI32 k = k + ziNil STI64 k = k + ziNil STF32 k = k + ziNil STF64 k = k + ziNil STBool k = k + + d2opUnArrangeInt :: SScalTy a + -> (D2s a ~ TScal a => D2Op (TScal a) t) + -> D2Op (TScal a) t + d2opUnArrangeInt ty float = case ty of + STI32 -> Linear $ \_ -> ENil ext + STI64 -> Linear $ \_ -> ENil ext + STF32 -> float + STF64 -> float + STBool -> Linear $ \_ -> ENil ext + + d2opBinArrangeInt :: SScalTy a + -> (D2s a ~ TScal a => D2Op (TPair (TScal a) (TScal a)) t) + -> D2Op (TPair (TScal a) (TScal a)) t + d2opBinArrangeInt ty float = case ty of + STI32 -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) + STI64 -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) + STF32 -> float + STF64 -> float + STBool -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) + + floatingD2 :: ScalIsFloating a ~ True + => SScalTy a -> ((D2s a ~ TScal a, ScalIsNumeric a ~ True) => r) -> r + floatingD2 STF32 k = k + floatingD2 STF64 k = k + + integralD2 :: ScalIsIntegral a ~ True + => SScalTy a -> ((D2s a ~ TNil, ScalIsNumeric a ~ True) => r) -> r + integralD2 STI32 k = k + integralD2 STI64 k = k + +desD1E :: Descr env sto -> SList STy (D1E env) +desD1E = d1e . descrList + +-- d1W :: env :> env' -> D1E env :> D1E env' +-- d1W WId = WId +-- d1W WSink = WSink +-- d1W (WCopy w) = WCopy (d1W w) +-- d1W (WPop w) = WPop (d1W w) +-- d1W (WThen u w) = WThen (d1W u) (d1W w) + +conv1Idx :: Idx env t -> Idx (D1E env) (D1 t) +conv1Idx IZ = IZ +conv1Idx (IS i) = IS (conv1Idx i) + +data Idx2 env sto t + = Idx2Ac (Idx (D2AcE (Select env sto "accum")) (TAccum (D2 t))) + | Idx2Me (Idx (D2E (Select env sto "merge")) (D2 t)) + | Idx2Di (Idx (Select env sto "discr") t) + +conv2Idx :: Descr env sto -> Idx env t -> Idx2 env sto t +conv2Idx (DPush _ (_, _, SAccum)) IZ = Idx2Ac IZ +conv2Idx (DPush _ (_, _, SMerge)) IZ = Idx2Me IZ +conv2Idx (DPush _ (_, _, SDiscr)) IZ = Idx2Di IZ +conv2Idx (DPush des (_, _, SAccum)) (IS i) = + case conv2Idx des i of Idx2Ac j -> Idx2Ac (IS j) + Idx2Me j -> Idx2Me j + Idx2Di j -> Idx2Di j +conv2Idx (DPush des (_, _, SMerge)) (IS i) = + case conv2Idx des i of Idx2Ac j -> Idx2Ac j + Idx2Me j -> Idx2Me (IS j) + Idx2Di j -> Idx2Di j +conv2Idx (DPush des (_, _, SDiscr)) (IS i) = + case conv2Idx des i of Idx2Ac j -> Idx2Ac j + Idx2Me j -> Idx2Me j + Idx2Di j -> Idx2Di (IS j) +conv2Idx DTop i = case i of {} + +opt2UnSparse :: SOp a b -> Sparse (D2 b) b' -> Ex env b' -> Ex env (D2 b) +opt2UnSparse = go . opt2 + where + go :: STy b -> Sparse (D2 b) b' -> Ex env b' -> Ex env (D2 b) + go (STScal STI32) SpAbsent = \_ -> ENil ext + go (STScal STI64) SpAbsent = \_ -> ENil ext + go (STScal STF32) SpAbsent = \_ -> EZero ext (SMTScal STF32) (ENil ext) + go (STScal STF64) SpAbsent = \_ -> EZero ext (SMTScal STF64) (ENil ext) + go (STScal STBool) SpAbsent = \_ -> ENil ext + go (STScal STF32) SpScal = id + go (STScal STF64) SpScal = id + go STNil _ = \_ -> ENil ext + go (STPair t1 t2) (SpPair s1 s2) = \e -> eunPair e $ \_ e1 e2 -> EPair ext (go t1 s1 e1) (go t2 s2 e2) + go t _ = error $ "Primitive operations that return " ++ show t ++ " are scary" + + +----------------------------------- SPARSITY ----------------------------------- + +expandSparse :: STy a -> Sparse (D2 a) b -> Ex env (D1 a) -> Ex env b -> Ex env (D2 a) +expandSparse t sp _ e | Just Refl <- isDense (d2M t) sp = e +expandSparse t (SpSparse sp) epr e = + EMaybe ext + (EZero ext (d2M t) (d2zeroInfo t epr)) + (expandSparse t sp (weakenExpr WSink epr) (EVar ext (applySparse sp (d2 t)) IZ)) + e +expandSparse t SpAbsent epr _ = EZero ext (d2M t) (d2zeroInfo t epr) +expandSparse (STPair t1 t2) (SpPair s1 s2) epr e = + eunPair epr $ \w1 epr1 epr2 -> + eunPair (weakenExpr w1 e) $ \w2 e1 e2 -> + EPair ext (expandSparse t1 s1 (weakenExpr w2 epr1) e1) + (expandSparse t2 s2 (weakenExpr w2 epr2) e2) +expandSparse (STEither t1 t2) (SpLEither s1 s2) epr e = + ELCase ext e + (EZero ext (d2M (STEither t1 t2)) (ENil ext)) + (ECase ext (weakenExpr WSink epr) + (ELInl ext (d2 t2) (expandSparse t1 s1 (EVar ext (d1 t1) IZ) (EVar ext (applySparse s1 (d2 t1)) (IS IZ)))) + (EError ext (d2 (STEither t1 t2)) "expspa r<-dl")) + (ECase ext (weakenExpr WSink epr) + (EError ext (d2 (STEither t1 t2)) "expspa l<-dr") + (ELInr ext (d2 t1) (expandSparse t2 s2 (EVar ext (d1 t2) IZ) (EVar ext (applySparse s2 (d2 t2)) (IS IZ))))) +expandSparse (STLEither t1 t2) (SpLEither s1 s2) epr e = + ELCase ext e + (EZero ext (d2M (STEither t1 t2)) (ENil ext)) + (ELCase ext (weakenExpr WSink epr) + (EError ext (d2 (STEither t1 t2)) "expspa ln<-dl") + (ELInl ext (d2 t2) (expandSparse t1 s1 (EVar ext (d1 t1) IZ) (EVar ext (applySparse s1 (d2 t1)) (IS IZ)))) + (EError ext (d2 (STEither t1 t2)) "expspa lr<-dl")) + (ELCase ext (weakenExpr WSink epr) + (EError ext (d2 (STEither t1 t2)) "expspa ln<-dr") + (EError ext (d2 (STEither t1 t2)) "expspa ll<-dr") + (ELInr ext (d2 t1) (expandSparse t2 s2 (EVar ext (d1 t2) IZ) (EVar ext (applySparse s2 (d2 t2)) (IS IZ))))) +expandSparse (STMaybe t) (SpMaybe s) epr e = + EMaybe ext + (ENothing ext (d2 t)) + (let epr' = EMaybe ext (EError ext (d1 t) "expspa n<-dj") (EVar ext (d1 t) IZ) epr + in EJust ext (expandSparse t s (weakenExpr WSink epr') (EVar ext (applySparse s (d2 t)) IZ))) + e +expandSparse (STArr _ t) (SpArr s) epr e = + ezipWith (expandSparse t s (EVar ext (d1 t) (IS IZ)) (EVar ext (applySparse s (d2 t)) IZ)) epr e +expandSparse (STScal STF32) SpScal _ e = e +expandSparse (STScal STF64) SpScal _ e = e +expandSparse (STAccum{}) _ _ _ = error "accumulators not allowed in source program" + +subenvPlus :: SBool req1 -> SBool req2 + -> SList SMTy env + -> SubenvS env env1 -> SubenvS env env2 + -> (forall env3. SubenvS env env3 + -> Injection req1 (Tup env1) (Tup env3) + -> Injection req2 (Tup env2) (Tup env3) + -> (forall e. Ex e (Tup env1) -> Ex e (Tup env2) -> Ex e (Tup env3)) + -> r) + -> r +-- don't destroy effects! +subenvPlus _ _ SNil SETop SETop k = k SETop (Inj id) (Inj id) (\a b -> use a $ use b $ ENil ext) + +subenvPlus req1 req2 (SCons _ env) (SENo sub1) (SENo sub2) k = + subenvPlus req1 req2 env sub1 sub2 $ \sub3 s31 s32 pl -> + k (SENo sub3) s31 s32 pl + +subenvPlus req1 SF (SCons _ env) (SEYes sp1 sub1) (SENo sub2) k = + subenvPlus req1 SF env sub1 sub2 $ \sub3 minj13 _ pl -> + k (SEYes sp1 sub3) + (withInj minj13 $ \inj13 -> + \e1 -> eunPair e1 $ \_ e1a e1b -> + EPair ext (inj13 e1a) e1b) + Noinj + (\e1 e2 -> + ELet ext e1 $ + EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ)) + (weakenExpr WSink e2)) + (ESnd ext (EVar ext (typeOf e1) IZ))) +subenvPlus req1 ST (SCons t env) (SEYes sp1 sub1) (SENo sub2) k + | Just zero1 <- cheapZero (applySparse sp1 t) = + subenvPlus req1 ST env sub1 sub2 $ \sub3 minj13 (Inj inj23) pl -> + k (SEYes sp1 sub3) + (withInj minj13 $ \inj13 -> + \e1 -> eunPair e1 $ \_ e1a e1b -> + EPair ext (inj13 e1a) e1b) + (Inj $ \e2 -> EPair ext (inj23 e2) zero1) + (\e1 e2 -> + ELet ext e1 $ + EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ)) + (weakenExpr WSink e2)) + (ESnd ext (EVar ext (typeOf e1) IZ))) + | otherwise = + subenvPlus req1 ST env sub1 sub2 $ \sub3 minj13 (Inj inj23) pl -> + k (SEYes (SpSparse sp1) sub3) + (withInj minj13 $ \inj13 -> + \e1 -> eunPair e1 $ \_ e1a e1b -> + EPair ext (inj13 e1a) (EJust ext e1b)) + (Inj $ \e2 -> EPair ext (inj23 e2) (ENothing ext (applySparse sp1 (fromSMTy t)))) + (\e1 e2 -> + ELet ext e1 $ + EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ)) + (weakenExpr WSink e2)) + (EJust ext (ESnd ext (EVar ext (typeOf e1) IZ)))) + +subenvPlus req1 req2 (SCons t env) sub1@SENo{} sub2@SEYes{} k = + subenvPlus req2 req1 (SCons t env) sub2 sub1 $ \sub3 minj23 minj13 pl -> + k sub3 minj13 minj23 (flip pl) + +subenvPlus req1 req2 (SCons t env) (SEYes sp1 sub1) (SEYes sp2 sub2) k = + subenvPlus req1 req2 env sub1 sub2 $ \sub3 minj13 minj23 pl -> + sparsePlusS req1 req2 t sp1 sp2 $ \sp3 mTinj13 mTinj23 plus -> + k (SEYes sp3 sub3) + (withInj2 minj13 mTinj13 $ \inj13 tinj13 -> + \e1 -> eunPair e1 $ \_ e1a e1b -> + EPair ext (inj13 e1a) (tinj13 e1b)) + (withInj2 minj23 mTinj23 $ \inj23 tinj23 -> + \e2 -> eunPair e2 $ \_ e2a e2b -> + EPair ext (inj23 e2a) (tinj23 e2b)) + (\e1 e2 -> + ELet ext e1 $ + ELet ext (weakenExpr WSink e2) $ + EPair ext (pl (EFst ext (EVar ext (typeOf e1) (IS IZ))) + (EFst ext (EVar ext (typeOf e2) IZ))) + (plus + (ESnd ext (EVar ext (typeOf e1) (IS IZ))) + (ESnd ext (EVar ext (typeOf e2) IZ)))) + +expandSubenvZeros :: D1E env0 :> env -> SList STy env0 -> SubenvS (D2E env0) contribs + -> Ex env (Tup contribs) -> Ex env (Tup (D2E env0)) +expandSubenvZeros _ SNil SETop _ = ENil ext +expandSubenvZeros w (SCons t ts) (SEYes sp sub) e = + eunPair e $ \w1 e1 e2 -> + EPair ext + (expandSubenvZeros (w1 .> WPop w) ts sub e1) + (expandSparse t sp (EVar ext (d1 t) (w1 .> w @> IZ)) e2) +expandSubenvZeros w (SCons t ts) (SENo sub) e = + EPair ext + (expandSubenvZeros (WPop w) ts sub e) + (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ)))) + + +--------------------------------- ACCUMULATORS --------------------------------- + +fromArrayValId :: Maybe (ValId t) -> Maybe Int +fromArrayValId (Just (VIArr i _)) = Just i +fromArrayValId _ = Nothing + +accumPromote :: forall dt env sto proxy r. + proxy dt + -> Descr env sto + -> (forall stoRepl envPro. + (Select env stoRepl "merge" ~ '[]) + => Descr env stoRepl + -- ^ A revised environment description that switches + -- arrays (used in the OccEnv) that are currently on + -- "merge" storage, to "accum" storage. + -> SList STy envPro + -- ^ New entries on top of the original dual environment, + -- that house the accumulators for the promoted arrays in + -- the original environment. + -> Subenv (Select env sto "merge") envPro + -- ^ The promoted entries were merge entries in the + -- original environment. + -> Subenv (D2AcE (Select env stoRepl "accum")) (D2AcE (Select env sto "accum")) + -- ^ All entries that were accumulators are still + -- accumulators. + -> VarMap Int (D2AcE (Select env stoRepl "accum")) + -- ^ Accumulator map for _only_ the the newly allocated + -- accumulators. + -> (forall shbinds. + SList STy shbinds + -> (dt : Append shbinds (D2AcE (Select env stoRepl "accum"))) + :> Append (D2AcE envPro) (dt : Append shbinds (D2AcE (Select env sto "accum")))) + -- ^ A weakening that converts a computation in the + -- revised environment to one in the original environment + -- extended with some accumulators. + -> r) + -> r +accumPromote _ DTop k = k DTop SNil SETop SETop VarMap.empty (\_ -> WId) +accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of + -- Accumulators are left as-is + SAccum -> + accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf -> + k (storepl `DPush` (t, vid, SAccum)) + envpro + prosub + (SEYesR accrevsub) + (VarMap.sink1 accumMap) + (\shbinds -> + autoWeak (#pro (d2ace envpro) &. #d (auto1 @dt) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum descr))) + (#acc :++: (#pro :++: #d :++: #shb :++: #tl)) + (#pro :++: #d :++: #shb :++: #acc :++: #tl) + .> WCopy (wf shbinds) + .> autoWeak (#d (auto1 @dt) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum storepl))) + (#d :++: #shb :++: #acc :++: #tl) + (#acc :++: (#d :++: #shb :++: #tl))) + + SMerge -> case t of + -- Discrete values are left as-is + _ | isDiscrete t -> + accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap' wf -> + k (storepl `DPush` (t, vid, SDiscr)) + envpro + (SENo prosub) + accrevsub + accumMap' + wf + + -- Values with "merge" storage are promoted to an accumulator in envPro + _ -> + accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf -> + k (storepl `DPush` (t, vid, SAccum)) + (t `SCons` envpro) + (SEYesR prosub) + (SENo accrevsub) + (let accumMap' = VarMap.sink1 accumMap + in case fromArrayValId vid of + Just i -> VarMap.insert i (STAccum (d2M t)) IZ accumMap' + Nothing -> accumMap') + (\(shbinds :: SList _ shbinds) -> + let shbindsC = slistMap (\_ -> Const ()) shbinds + in + -- wf: + -- D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) + -- WCopy wf: + -- TAccum n t3 : D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) + -- WPICK: ^ THESE TWO || + -- goal: | ARE EQUAL || + -- D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) + WCopy (wf shbinds) + .> WPick @(TAccum (D2 t)) @(dt : shbinds) (Const () `SCons` shbindsC) + (WId @(D2AcE (Select env1 stoRepl "accum")))) + + -- Discrete values are left as-is, nothing to do + SDiscr -> + accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf -> + k (storepl `DPush` (t, vid, SDiscr)) + envpro + prosub + accrevsub + accumMap + wf + where + isDiscrete :: STy t' -> Bool + isDiscrete = \case + STNil -> True + STPair a b -> isDiscrete a && isDiscrete b + STEither a b -> isDiscrete a && isDiscrete b + STLEither a b -> isDiscrete a && isDiscrete b + STMaybe a -> isDiscrete a + STArr _ a -> isDiscrete a + STScal st -> case st of + STI32 -> True + STI64 -> True + STF32 -> False + STF64 -> False + STBool -> True + STAccum{} -> False + + +---------------------------- RETURN TRIPLE FROM CHAD --------------------------- + +data Ret env0 sto sd t = + forall shbinds tapebinds contribs. + Ret (Bindings Ex (D1E env0) shbinds) -- shared binds + (Subenv shbinds tapebinds) + (Ex (Append shbinds (D1E env0)) (D1 t)) + (SubenvS (D2E (Select env0 sto "merge")) contribs) + (Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs)) +deriving instance Show (Ret env0 sto sd t) + +type data TyTyPair = MkTyTyPair Ty Ty + +data SingleRet env0 sto (pair :: TyTyPair) = + forall shbinds tapebinds. + SingleRet + (Bindings Ex (D1E env0) shbinds) -- shared binds + (Subenv shbinds tapebinds) + (RetPair env0 sto (D1E env0) shbinds tapebinds pair) + +-- pattern Ret1 :: forall env0 sto Bindings Ex (D1E env0) shbinds +-- -> Subenv shbinds tapebinds +-- -> Ex (Append shbinds (D1E env0)) (D1 t) +-- -> SubenvS (D2E (Select env0 sto "merge")) contribs +-- -> Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs) +-- -> SingleRet env0 sto (MkTyTyPair sd t) +-- pattern Ret1 e0 subtape e1 sub e2 = SingleRet e0 subtape (RetPair e1 sub e2) +-- {-# COMPLETE Ret1 #-} + +data RetPair env0 sto env shbinds tapebinds (pair :: TyTyPair) where + RetPair :: forall sd t contribs -- existentials + env0 sto env shbinds tapebinds. -- universals + Ex (Append shbinds env) (D1 t) + -> SubenvS (D2E (Select env0 sto "merge")) contribs + -> Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs) + -> RetPair env0 sto env shbinds tapebinds (MkTyTyPair sd t) +deriving instance Show (RetPair env0 sto env shbinds tapebinds pair) + +data Rets env0 sto env list = + forall shbinds tapebinds. + Rets (Bindings Ex env shbinds) + (Subenv shbinds tapebinds) + (SList (RetPair env0 sto env shbinds tapebinds) list) +deriving instance Show (Rets env0 sto env list) + +toSingleRet :: Ret env0 sto sd t -> SingleRet env0 sto (MkTyTyPair sd t) +toSingleRet (Ret e0 subtape e1 sub e2) = SingleRet e0 subtape (RetPair e1 sub e2) + +weakenRetPair :: SList STy shbinds -> env :> env' + -> RetPair env0 sto env shbinds tapebinds pair -> RetPair env0 sto env' shbinds tapebinds pair +weakenRetPair bindslist w (RetPair e1 sub e2) = RetPair (weakenExpr (weakenOver bindslist w) e1) sub e2 + +weakenRets :: env :> env' -> Rets env0 sto env list -> Rets env0 sto env' list +weakenRets w (Rets binds tapesub list) = + let (binds', _) = weakenBindingsE w binds + in Rets binds' tapesub (slistMap (weakenRetPair (bindingsBinds binds) w) list) + +rebaseRetPair :: forall env b1 b2 tapebinds1 tapebinds2 env0 sto pair f. + Descr env0 sto + -> SList f b1 -> SList f b2 + -> Subenv b1 tapebinds1 -> Subenv b2 tapebinds2 + -> RetPair env0 sto (Append b1 env) b2 tapebinds2 pair + -> RetPair env0 sto env (Append b2 b1) (Append tapebinds2 tapebinds1) pair +rebaseRetPair descr b1 b2 subtape1 subtape2 (RetPair @sd e1 sub e2) + | Refl <- lemAppendAssoc @b2 @b1 @env = + RetPair e1 sub + (weakenExpr (autoWeak + (#d (auto1 @sd) + &. #t2 (subList b2 subtape2) + &. #t1 (subList b1 subtape1) + &. #tl (d2ace (select SAccum descr))) + (#d :++: (#t2 :++: #tl)) + (#d :++: ((#t2 :++: #t1) :++: #tl))) + e2) + +retConcat :: forall env0 sto list. Descr env0 sto -> SList (SingleRet env0 sto) list -> Rets env0 sto (D1E env0) list +retConcat _ SNil = Rets BTop SETop SNil +retConcat descr (SCons (SingleRet (e0 :: Bindings _ _ shbinds1) (subtape :: Subenv _ tapebinds1) (RetPair e1 sub e2)) list) + | Rets (binds :: Bindings _ _ shbinds2) (subtape2 :: Subenv _ tapebinds2) pairs + <- weakenRets (sinkWithBindings e0) (retConcat descr list) + , Refl <- lemAppendAssoc @shbinds2 @shbinds1 @(D1E env0) + , Refl <- lemAppendAssoc @tapebinds2 @tapebinds1 @(D2AcE (Select env0 sto "accum")) + = Rets (bconcat e0 binds) + (subenvConcat subtape subtape2) + (SCons (RetPair (weakenExpr (sinkWithBindings binds) e1) + sub + (weakenExpr (WCopy (sinkWithSubenv subtape2)) e2)) + (slistMap (rebaseRetPair descr (bindingsBinds e0) (bindingsBinds binds) + subtape subtape2) + pairs)) + +freezeRet :: Descr env sto + -> Ret env sto (D2 t) t + -> Ex (D2 t : Append (D2AcE (Select env sto "accum")) (D1E env)) (TPair (D1 t) (Tup (D2E (Select env sto "merge")))) +freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ _ t) = + let (e0', wInsertD2Ac) = weakenBindingsE (WSink .> wSinks (d2ace (select SAccum descr))) e0 + e2' = weakenExpr (WCopy (wCopies (subList (bindingsBinds e0) subtape) (wRaiseAbove (d2ace (select SAccum descr)) (desD1E descr)))) e2 + tContribs = tTup (slistMap fromSMTy (subList (d2eM (select SMerge descr)) sub)) + library = #d (auto1 @(D2 t)) + &. #tape (subList (bindingsBinds e0) subtape) + &. #shbinds (bindingsBinds e0) + &. #d2ace (d2ace (select SAccum descr)) + &. #tl (desD1E descr) + &. #contribs (SCons tContribs SNil) + in letBinds e0' $ + EPair ext + (weakenExpr wInsertD2Ac e1) + (ELet ext (weakenExpr (autoWeak library + (#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: #d2ace :++: #tl) + (#shbinds :++: #d :++: #d2ace :++: #tl)) + e2') $ + expandSubenvZeros + (autoWeak library #tl (#contribs :++: #shbinds :++: #d :++: #d2ace :++: #tl) + .> wUndoSubenv (subenvD1E (selectSub SMerge descr))) + (select SMerge descr) sub (EVar ext tContribs IZ)) + + +---------------------------- THE CHAD TRANSFORMATION --------------------------- + +drev :: forall env sto sd t. + (?config :: CHADConfig) + => Descr env sto -> VarMap Int (D2AcE (Select env sto "accum")) + -> Sparse (D2 t) sd + -> Expr ValId env t -> Ret env sto sd t +drev des _ sd | isAbsent sd = + \e -> + Ret BTop + SETop + (drevPrimal des e) + (subenvNone (d2e (select SMerge des))) + (ENil ext) +drev _ _ SpAbsent = error "Absent should be isAbsent" + +drev des accumMap (SpSparse sd) = + \e -> + case drev des accumMap sd e of { Ret e0 subtape e1 sub e2 -> + subenvPlus ST ST (d2eM (select SMerge des)) sub (subenvNone (d2e (select SMerge des))) $ \sub' (Inj inj1) (Inj inj2) _ -> + Ret e0 + subtape + e1 + sub' + (emaybe (EVar ext (STMaybe (applySparse sd (d2 (typeOf e)))) IZ) + (inj2 (ENil ext)) + (inj1 (weakenExpr (WCopy WSink) e2))) + } + +drev des accumMap sd = \case + EVar _ t i -> + case conv2Idx des i of + Idx2Ac accI -> + Ret BTop + SETop + (EVar ext (d1 t) (conv1Idx i)) + (subenvNone (d2e (select SMerge des))) + (let ty = applySparse sd (d2M t) + in EAccum ext (d2M t) SAPHere (ENil ext) sd (EVar ext (fromSMTy ty) IZ) (EVar ext (STAccum (d2M t)) (IS accI))) + + Idx2Me tupI -> + Ret BTop + SETop + (EVar ext (d1 t) (conv1Idx i)) + (subenvOnehot (d2e (select SMerge des)) tupI sd) + (EPair ext (ENil ext) (EVar ext (applySparse sd (d2 t)) IZ)) + + Idx2Di _ -> + Ret BTop + SETop + (EVar ext (d1 t) (conv1Idx i)) + (subenvNone (d2e (select SMerge des))) + (ENil ext) + + ELet _ (rhs :: Expr _ _ a) body + | ChosenStorage (storage :: Storage s) <- if chcLetArrayAccum ?config && typeHasArrays (typeOf rhs) then ChosenStorage SAccum else ChosenStorage SMerge + , RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody sdBody body2 <- drevScoped des accumMap (typeOf rhs) storage (Just (extOf rhs)) sd body + , Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) rhs1 subRHS rhs2 <- drev des accumMap sdBody rhs + , let (body0', wbody0') = weakenBindingsE (WCopy (sinkWithBindings rhs0)) body0 + , Refl <- lemAppendAssoc @body_shbinds @'[D1 a] @rhs_shbinds + , Refl <- lemAppendAssoc @body_shbinds @(D1 a : rhs_shbinds) @(D1E env) + , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) + -> + subenvPlus SF SF (d2eM (select SMerge des)) subRHS subBody $ \subBoth _ _ plus_RHS_Body -> + let bodyResType = STPair (contribTupTy des subBody) (applySparse sdBody (d2 (typeOf rhs))) in + Ret (bconcat (rhs0 `bpush` rhs1) body0') + (subenvConcat subtapeRHS subtapeBody) + (weakenExpr wbody0' body1) + subBoth + (ELet ext (weakenExpr (autoWeak (#d (auto1 @sd) + &. #body (subList (bindingsBinds body0 `sappend` SCons (d1 (typeOf rhs)) SNil) subtapeBody) + &. #rhs (subList (bindingsBinds rhs0) subtapeRHS) + &. #tl (d2ace (select SAccum des))) + (#d :++: #body :++: #tl) + (#d :++: (#body :++: #rhs) :++: #tl)) + body2) $ + ELet ext + (ELet ext (ESnd ext (EVar ext bodyResType IZ)) $ + weakenExpr (WCopy (wSinks' @[_,_] .> sinkWithSubenv subtapeBody)) rhs2) $ + plus_RHS_Body + (EVar ext (contribTupTy des subRHS) IZ) + (EFst ext (EVar ext bodyResType (IS IZ)))) + + EPair _ a b + | SpPair sd1 sd2 <- sd + , Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil) + <- retConcat des $ toSingleRet (drev des accumMap sd1 a) `SCons` toSingleRet (drev des accumMap sd2 b) `SCons` SNil + , let dt = STPair (applySparse sd1 (d2 (typeOf a))) (applySparse sd2 (d2 (typeOf b))) -> + subenvPlus SF SF (d2eM (select SMerge des)) subA subB $ \subBoth _ _ plus_A_B -> + Ret binds + subtape + (EPair ext a1 b1) + subBoth + (ELet ext (ELet ext (EFst ext (EVar ext dt IZ)) + (weakenExpr (WCopy WSink) a2)) $ + ELet ext (ELet ext (ESnd ext (EVar ext dt (IS IZ))) + (weakenExpr (WCopy (WSink .> WSink)) b2)) $ + plus_A_B + (EVar ext (contribTupTy des subA) (IS IZ)) + (EVar ext (contribTupTy des subB) IZ)) + + EFst _ e + | Ret e0 subtape e1 sub e2 <- drev des accumMap (SpPair sd SpAbsent) e + , STPair t1 _ <- typeOf e -> + Ret e0 + subtape + (EFst ext e1) + sub + (ELet ext (EPair ext (EVar ext (applySparse sd (d2 t1)) IZ) (ENil ext)) $ + weakenExpr (WCopy WSink) e2) + + ESnd _ e + | Ret e0 subtape e1 sub e2 <- drev des accumMap (SpPair SpAbsent sd) e + , STPair _ t2 <- typeOf e -> + Ret e0 + subtape + (ESnd ext e1) + sub + (ELet ext (EPair ext (ENil ext) (EVar ext (applySparse sd (d2 t2)) IZ)) $ + weakenExpr (WCopy WSink) e2) + + -- Don't need to handle ENil, because its cotangent is always absent! + -- ENil _ -> Ret BTop SETop (ENil ext) (subenvNone (d2e (select SMerge des))) (ENil ext) + + EInl _ t2 e + | SpLEither sd1 sd2 <- sd + , Ret e0 subtape e1 sub e2 <- drev des accumMap sd1 e -> + subenvPlus ST ST (d2eM (select SMerge des)) sub (subenvNone (d2e (select SMerge des))) $ \sub' (Inj inj1) (Inj inj2) _ -> + Ret e0 + subtape + (EInl ext (d1 t2) e1) + sub' + (ELCase ext + (EVar ext (STLEither (applySparse sd1 (d2 (typeOf e))) (applySparse sd2 (d2 t2))) IZ) + (inj2 $ ENil ext) + (inj1 $ weakenExpr (WCopy WSink) e2) + (EError ext (contribTupTy des sub') "inl<-dinr")) + + EInr _ t1 e + | SpLEither sd1 sd2 <- sd + , Ret e0 subtape e1 sub e2 <- drev des accumMap sd2 e -> + subenvPlus ST ST (d2eM (select SMerge des)) sub (subenvNone (d2e (select SMerge des))) $ \sub' (Inj inj1) (Inj inj2) _ -> + Ret e0 + subtape + (EInr ext (d1 t1) e1) + sub' + (ELCase ext + (EVar ext (STLEither (applySparse sd1 (d2 t1)) (applySparse sd2 (d2 (typeOf e)))) IZ) + (inj2 $ ENil ext) + (EError ext (contribTupTy des sub') "inr<-dinl") + (inj1 $ weakenExpr (WCopy WSink) e2)) + + ECase _ e (a :: Expr _ _ t) b + | STEither (t1 :: STy a) (t2 :: STy b) <- typeOf e + , ChosenStorage storage1 <- if chcCaseArrayAccum ?config && typeHasArrays t1 then ChosenStorage SAccum else ChosenStorage SMerge + , ChosenStorage storage2 <- if chcCaseArrayAccum ?config && typeHasArrays t2 then ChosenStorage SAccum else ChosenStorage SMerge + , let (bindids1, bindids2) = validSplitEither (extOf e) + , RetScoped (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA sd1 a2 + <- drevScoped des accumMap t1 storage1 bindids1 sd a + , RetScoped (b0 :: Bindings _ _ rhs_b_binds) (subtapeB :: Subenv _ rhs_b_tape) b1 subB sd2 b2 + <- drevScoped des accumMap t2 storage2 bindids2 sd b + , Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 subE e2 <- drev des accumMap (SpLEither sd1 sd2) e + , Refl <- lemAppendAssoc @(Append rhs_a_binds (Reverse (TapeUnfoldings rhs_a_binds))) @(Tape rhs_a_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum")) + , Refl <- lemAppendAssoc @(Append rhs_b_binds (Reverse (TapeUnfoldings rhs_b_binds))) @(Tape rhs_b_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum")) + , let subtapeListA = subList (sappend (bindingsBinds a0) (d1 t1 `SCons` SNil)) subtapeA + , let subtapeListB = subList (sappend (bindingsBinds b0) (d1 t2 `SCons` SNil)) subtapeB + , let tapeA = tapeTy subtapeListA + , let tapeB = tapeTy subtapeListB + , let collectA = bindingsCollectTape @_ @_ @(Append rhs_a_binds (D1 a : Append e_binds (D1E env))) + (sappend (bindingsBinds a0) (d1 t1 `SCons` SNil)) subtapeA + , let collectB = bindingsCollectTape @_ @_ @(Append rhs_b_binds (D1 b : Append e_binds (D1E env))) + (sappend (bindingsBinds b0) (d1 t2 `SCons` SNil)) subtapeB + , (tPrimal :: STy t_primal_ty) <- STPair (d1 (typeOf a)) (STEither tapeA tapeB) + , let (a0', wa0') = weakenBindingsE (WCopy (sinkWithBindings e0)) a0 + , let (b0', wb0') = weakenBindingsE (WCopy (sinkWithBindings e0)) b0 + , Refl <- lemAppendNil @(Append rhs_a_binds '[D1 a]) + , Refl <- lemAppendNil @(Append rhs_b_binds '[D1 b]) + , Refl <- lemAppendAssoc @rhs_a_binds @'[D1 a] @(D1E env) + , Refl <- lemAppendAssoc @rhs_b_binds @'[D1 b] @(D1E env) + , let wa0'' = wa0' .> wCopies (sappend (bindingsBinds a0) (d1 t1 `SCons` SNil)) (WClosed @(D1E env)) + , let wb0'' = wb0' .> wCopies (sappend (bindingsBinds b0) (d1 t2 `SCons` SNil)) (WClosed @(D1E env)) + -> + subenvPlus ST ST (d2eM (select SMerge des)) subA subB $ \subAB (Inj sAB_A) (Inj sAB_B) _ -> + subenvPlus SF SF (d2eM (select SMerge des)) subAB subE $ \subOut _ _ plus_AB_E -> + Ret (e0 `bpush` ECase ext e1 + (letBinds a0' (EPair ext (weakenExpr wa0' a1) (EInl ext tapeB (collectA wa0'')))) + (letBinds b0' (EPair ext (weakenExpr wb0' b1) (EInr ext tapeA (collectB wb0''))))) + (SEYesR subtapeE) + (EFst ext (EVar ext tPrimal IZ)) + subOut + (elet + (ECase ext (ESnd ext (EVar ext tPrimal (IS IZ))) + (let (rebinds, prerebinds) = reconstructBindings subtapeListA + in letBinds (rebinds IZ) $ + ELet ext + (EVar ext (applySparse sd (d2 (typeOf a))) (wSinks @(Tape rhs_a_tape : sd : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend subtapeListA prerebinds) @> IS IZ)) $ + elet + (weakenExpr (autoWeak (#d (auto1 @sd) + &. #ta0 subtapeListA + &. #prea0 prerebinds + &. #recon (tapeA `SCons` applySparse sd (d2 (typeOf a)) `SCons` SNil) + &. #binds (tPrimal `SCons` subList (bindingsBinds e0) subtapeE) + &. #tl (d2ace (select SAccum des))) + (#d :++: #ta0 :++: #tl) + (#d :++: (#ta0 :++: #prea0) :++: #recon :++: #binds :++: #tl)) + a2) $ + EPair ext (sAB_A $ EFst ext (evar IZ)) + (ELInl ext (applySparse sd2 (d2 t2)) (ESnd ext (evar IZ)))) + (let (rebinds, prerebinds) = reconstructBindings subtapeListB + in letBinds (rebinds IZ) $ + ELet ext + (EVar ext (applySparse sd (d2 (typeOf a))) (wSinks @(Tape rhs_b_tape : sd : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend subtapeListB prerebinds) @> IS IZ)) $ + elet + (weakenExpr (autoWeak (#d (auto1 @sd) + &. #tb0 subtapeListB + &. #preb0 prerebinds + &. #recon (tapeB `SCons` applySparse sd (d2 (typeOf a)) `SCons` SNil) + &. #binds (tPrimal `SCons` subList (bindingsBinds e0) subtapeE) + &. #tl (d2ace (select SAccum des))) + (#d :++: #tb0 :++: #tl) + (#d :++: (#tb0 :++: #preb0) :++: #recon :++: #binds :++: #tl)) + b2) $ + EPair ext (sAB_B $ EFst ext (evar IZ)) + (ELInr ext (applySparse sd1 (d2 t1)) (ESnd ext (evar IZ))))) $ + plus_AB_E + (EFst ext (evar IZ)) + (ELet ext (ESnd ext (evar IZ)) $ + weakenExpr (WCopy (wSinks' @[_,_,_])) e2)) + + EConst _ t val -> + Ret BTop + SETop + (EConst ext t val) + (subenvNone (d2e (select SMerge des))) + (ENil ext) + + EOp _ op e + | Ret e0 subtape e1 sub e2 <- drev des accumMap (spDense (d2M (opt1 op))) e -> + case d2op op of + Linear d2opfun -> + Ret e0 + subtape + (d1op op e1) + sub + (ELet ext (d2opfun (opt2UnSparse op sd (EVar ext (applySparse sd (d2 (opt2 op))) IZ))) + (weakenExpr (WCopy WSink) e2)) + Nonlinear d2opfun -> + Ret (e0 `bpush` e1) + (SEYesR subtape) + (d1op op $ EVar ext (d1 (typeOf e)) IZ) + sub + (ELet ext (d2opfun (EVar ext (d1 (typeOf e)) (IS IZ)) + (opt2UnSparse op sd (EVar ext (applySparse sd (d2 (opt2 op))) IZ))) + (weakenExpr (WCopy (wSinks' @[_,_])) e2)) + + ECustom _ _ tb _ srce pr du a b + -- allowed to ignore a2 because 'a' is the part of the input that is inactive + | Ret b0 bsubtape b1 bsub b2 <- drev des accumMap (spDense (d2M tb)) b -> + case isDense (d2M (typeOf srce)) sd of + Just Refl -> + Ret (b0 `bpush` weakenExpr (sinkWithBindings b0) (drevPrimal des a) + `bpush` weakenExpr WSink b1 + `bpush` weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr) + `bpush` ESnd ext (EVar ext (typeOf pr) IZ)) + (SEYesR (SENo (SENo (SENo bsubtape)))) + (EFst ext (EVar ext (typeOf pr) (IS IZ))) + bsub + (ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $ + weakenExpr (WCopy (WSink .> WSink)) b2) + + Nothing -> + Ret (b0 `bpush` weakenExpr (sinkWithBindings b0) (drevPrimal des a) + `bpush` weakenExpr WSink b1 + `bpush` weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr)) + (SEYesR (SENo (SENo bsubtape))) + (EFst ext (EVar ext (typeOf pr) IZ)) + bsub + (ELet ext (ESnd ext (EVar ext (typeOf pr) (IS IZ))) $ -- tape + ELet ext (expandSparse (typeOf srce) sd -- expanded incoming cotangent + (EFst ext (EVar ext (typeOf pr) (IS (IS IZ)))) + (EVar ext (applySparse sd (d2 (typeOf srce))) (IS IZ))) $ + ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $ + weakenExpr (WCopy (WSink .> WSink .> WSink .> WSink)) b2) + + ERecompute _ e -> + deleteUnused (descrList des) (occCountAll e) $ \usedSub -> + let smallE = unsafeWeakenWithSubenv usedSub e in + subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed -> + case drev usedDes (VarMap.subMap subAccumUsed accumMap) sd smallE of { Ret e0 subtape _ sub e2 -> + let subMergeUsed' = subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E subMergeUsed) in + Ret (collectBindings (desD1E des) subD1eUsed) + (subenvAll (desD1E usedDes)) + (weakenExpr (wSinks (desD1E usedDes)) $ drevPrimal des e) + (subenvCompose subMergeUsed' sub) + (letBinds (fst (weakenBindingsE (WSink .> wRaiseAbove (desD1E usedDes) (d2ace (select SAccum des))) e0)) $ + weakenExpr + (autoWeak (#d (auto1 @sd) + &. #shbinds (bindingsBinds e0) + &. #tape (subList (bindingsBinds e0) subtape) + &. #d1env (desD1E usedDes) + &. #tl' (d2ace (select SAccum usedDes)) + &. #tl (d2ace (select SAccum des))) + (#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: LPreW #tl' #tl (wUndoSubenv subAccumUsed)) + (#shbinds :++: #d :++: #d1env :++: #tl)) + e2) + } + + EError _ t s -> + Ret BTop + SETop + (EError ext (d1 t) s) + (subenvNone (d2e (select SMerge des))) + (ENil ext) + + EConstArr _ n t val -> + Ret BTop + SETop + (EConstArr ext n t val) + (subenvNone (d2e (select SMerge des))) + (ENil ext) + + EBuild _ (ndim :: SNat ndim) she (ef :: Expr _ _ eltty) + | SpArr @_ @sdElt sdElt <- sd + , let eltty = typeOf ef + , shty :: STy shty <- tTup (sreplicate ndim tIx) + , Refl <- indexTupD1Id ndim -> + drevLambda des accumMap (shty, SDiscr) sdElt ef $ \(provars :: SList _ envPro) esub proPrimalBinds e0 e1 (e1tape :: Ex _ e_tape) _ wrapAccum e2 -> + let library = #ix (shty `SCons` SNil) + &. #e0 (bindingsBinds e0) + &. #propr (d1e provars) + &. #d1env (desD1E des) + &. #d (auto1 @sdElt) + &. #tape (auto1 @e_tape) + &. #pro (d2ace provars) + &. #d2acEnv (d2ace (select SAccum des)) + &. #darr (auto1 @(TArr ndim sdElt)) + &. #tapearr (auto1 @(TArr ndim e_tape)) in + Ret (proPrimalBinds + `bpush` weakenExpr (wSinks (d1e provars)) + (EBuild ext ndim + (drevPrimal des she) + (letBinds e0 $ + EPair ext e1 e1tape)) + `bpush` emap (ESnd ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) (typeOf e1tape))) IZ)) + (SEYesR (SENo (subenvAll (d1e provars)))) + (emap (EFst ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) (typeOf e1tape))) (IS IZ))) + (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) esub) + (let sinkOverEnvPro = wSinks @(sd : TArr ndim e_tape : Append (D1E envPro) (D2AcE (Select env sto "accum"))) (d2ace provars) in + ESnd ext $ + wrapAccum (WSink .> WSink .> wRaiseAbove (d1e provars) (d2ace (select SAccum des))) $ + EBuild ext ndim (EShape ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (sinkOverEnvPro @> IZ))) $ + -- the cotangent for this element + ELet ext (EIdx ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (WSink .> sinkOverEnvPro @> IZ)) + (EVar ext shty IZ)) $ + -- the tape for this element + ELet ext (EIdx ext (EVar ext (STArr ndim (typeOf e1tape)) (WSink .> WSink .> sinkOverEnvPro @> IS IZ)) + (EVar ext shty (IS IZ))) $ + weakenExpr (autoWeak library (#tape :++: #d :++: #pro :++: #d2acEnv) + (#tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #propr :++: #d2acEnv)) + e2) + + EMap _ ef (earr :: Expr _ _ (TArr n a)) + | SpArr sdElt <- sd + , let STArr ndim t1 = typeOf earr + t2 = typeOf ef -> + drevLambda des accumMap (t1, SMerge) sdElt ef $ \provars efsub proPrimalBinds ef0 ef1 ef1tape spEf wrapAccum ef2 -> + case drev des accumMap (SpArr spEf) earr of { Ret ea0 easubtape ea1 easub ea2 -> + let (proPrimalBinds', _) = weakenBindingsE (sinkWithBindings ea0) proPrimalBinds + ttape = typeOf ef1tape + library = #d1env (desD1E des) + &. #a0 (bindingsBinds ea0) + &. #atapebinds (subList (bindingsBinds ea0) easubtape) + &. #propr (d1e provars) + &. #x (d1 t1 `SCons` SNil) + &. #parr (STArr ndim (d1 t1) `SCons` SNil) + &. #tapearr (STArr ndim ttape `SCons` SNil) + &. #darr (STArr ndim (applySparse sdElt (d2 t2)) `SCons` SNil) + &. #dy (applySparse sdElt (d2 t2) `SCons` SNil) + &. #tape (ttape `SCons` SNil) + &. #dytape (STPair (applySparse sdElt (d2 t2)) ttape `SCons` SNil) + &. #d2acEnv (d2ace (select SAccum des)) + &. #pro (d2ace provars) + in + subenvPlus SF SF (d2eM (select SMerge des)) (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) efsub) easub $ \subfa _ _ plus_f_a -> + Ret (bconcat ea0 proPrimalBinds' + `bpush` weakenExpr (autoWeak library (#a0 :++: #d1env) ((#propr :++: #a0) :++: #d1env)) ea1 + `bpush` emap (weakenExpr (autoWeak library (#x :++: #d1env) (#x :++: #parr :++: (#propr :++: #a0) :++: #d1env)) + (letBinds ef0 $ + EPair ext ef1 ef1tape)) + (EVar ext (STArr ndim (d1 t1)) IZ) + `bpush` emap (ESnd ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 t2) ttape)) IZ)) + (SEYesR (SENo (SENo (subenvConcat easubtape (subenvAll (d1e provars)))))) + (emap (EFst ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 t2) ttape)) (IS IZ))) + subfa + (let layout = #darr :++: #tapearr :++: (#propr :++: #atapebinds) :++: #d2acEnv in + elet + (wrapAccum (autoWeak library #propr layout) $ + emap (elet (EFst ext (EVar ext (STPair (applySparse sdElt (d2 t2)) ttape) IZ)) $ + elet (ESnd ext (EVar ext (STPair (applySparse sdElt (d2 t2)) ttape) (IS IZ))) $ + weakenExpr (autoWeak library (#tape :++: #dy :++: #pro :++: #d2acEnv) + (#tape :++: #dy :++: #dytape :++: #pro :++: layout)) + ef2) + (ezip (EVar ext (STArr ndim (applySparse sdElt (d2 t2))) (autoWeak library #darr (#pro :++: layout) @> IZ)) + (EVar ext (STArr ndim ttape) (autoWeak library #tapearr (#pro :++: layout) @> IZ)))) $ + plus_f_a + (ESnd ext (evar IZ)) + (weakenExpr (WCopy (autoWeak library (#atapebinds :++: #d2acEnv) layout)) + (subst0 (EFst ext (EVar ext (STPair (STArr ndim (typeOf ef2)) (tTup (d2e provars))) IZ)) + ea2))) + } + + EFold1Inner _ commut origef ex₀ earr + | SpArr @_ @sdElt sdElt <- sd + , STArr (SS ndim) eltty :: STy (TArr (S n) elt) <- typeOf earr + , Rets bindsx₀a subtapex₀a (RetPair ex₀1 subx₀ ex₀2 `SCons` RetPair ea1 suba ea2 `SCons` SNil) + <- retConcat des $ toSingleRet (drev des accumMap (spDense (d2M eltty)) ex₀) `SCons` toSingleRet (drev des accumMap (spDense (SMTArr (SS ndim) (d2M eltty))) earr) `SCons` SNil -> + drevLambda des accumMap (STPair eltty eltty, SMerge) (spDense (d2M eltty)) origef $ \(provars :: SList _ envPro) efsub proPrimalBinds ef0 ef1 (ef1tape :: Ex _ ef_tape) spEf wrapAccum ef2 -> + let (proPrimalBinds', _) = weakenBindingsE (sinkWithBindings bindsx₀a) proPrimalBinds in + let bogEltTy = STPair (STPair (d1 eltty) (d1 eltty)) (typeOf ef1tape) + bogTy = STArr (SS ndim) bogEltTy + primalTy = STPair (STArr ndim (d1 eltty)) bogTy + library = #xy (STPair (d1 eltty) (d1 eltty) `SCons` SNil) + &. #parr (auto1 @(TArr (S n) (D1 elt))) + &. #px₀ (auto1 @(D1 elt)) + &. #px (auto1 @(D1 elt)) + &. #pzi (auto1 @(ZeroInfo (D2 elt))) + &. #primal (primalTy `SCons` SNil) + &. #darr (auto1 @(TArr n sdElt)) + &. #d (auto1 @(D2 elt)) + &. #x₀abinds (bindingsBinds bindsx₀a) + &. #fbinds (bindingsBinds ef0) + &. #x₀atapebinds (subList (bindingsBinds bindsx₀a) subtapex₀a) + &. #ftape (auto1 @ef_tape) + &. #bogelt (bogEltTy `SCons` SNil) + &. #propr (d1e provars) + &. #d1env (desD1E des) + &. #d2acEnv (d2ace (select SAccum des)) + &. #d2acPro (d2ace provars) + &. #foldd2res (auto1 @(TPair (TPair (D2 elt) (TArr (S n) (D2 elt))) (Tup (D2E envPro)))) + wOverPrimalBindings = autoWeak library (#x₀abinds :++: #d1env) ((#propr :++: #x₀abinds) :++: #d1env) in + subenvPlus SF SF (d2eM (select SMerge des)) subx₀ suba $ \subx₀a _ _ plus_x₀_a -> + subenvPlus SF SF (d2eM (select SMerge des)) subx₀a (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) efsub) $ \subx₀af _ _ plus_x₀a_f -> + Ret (bconcat bindsx₀a proPrimalBinds' + `bpush` weakenExpr wOverPrimalBindings ex₀1 + `bpush` d2zeroInfo eltty (EVar ext (d1 eltty) IZ) + `bpush` weakenExpr (WSink .> WSink .> wOverPrimalBindings) ea1 + `bpush` EFold1InnerD1 ext commut + (let layout = #xy :++: #parr :++: #pzi :++: #px₀ :++: (#propr :++: #x₀abinds) :++: #d1env in + weakenExpr (autoWeak library (#xy :++: #d1env) layout) + (letBinds ef0 $ + EPair ext -- (out, ((in1, in2), tape)); the "additional stores" are ((in1, in2), tape) + ef1 + (EPair ext + (EVar ext (STPair (d1 eltty) (d1 eltty)) (autoWeak library #xy (#fbinds :++: #xy :++: #d1env) @> IZ)) + ef1tape))) + (EVar ext (d1 eltty) (IS (IS IZ))) + (EVar ext (STArr (SS ndim) (d1 eltty)) IZ)) + (SEYesR (SEYesR (SEYesR (SENo (subenvConcat subtapex₀a (subenvAll (d1e provars))))))) + (EFst ext (EVar ext primalTy IZ)) + subx₀af + (let layout1 = #darr :++: #primal :++: #parr :++: #pzi :++: (#propr :++: #x₀atapebinds) :++: #d2acEnv in + elet + (wrapAccum (autoWeak library #propr layout1) $ + let layout2 = #d2acPro :++: layout1 in + EFold1InnerD2 ext commut + (elet (ESnd ext (EVar ext bogEltTy (IS IZ))) $ + let layout3 = #ftape :++: #d :++: #bogelt :++: layout2 in + expandSparse (STPair eltty eltty) spEf (EFst ext (EVar ext bogEltTy (IS (IS IZ)))) $ + weakenExpr (autoWeak library (#ftape :++: #d :++: #d2acPro :++: #d2acEnv) layout3) ef2) + (ESnd ext (EVar ext primalTy (autoWeak library #primal layout2 @> IZ))) + (ezipWith (expandSparse eltty sdElt (evar IZ) (evar (IS IZ))) + (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (autoWeak library #darr layout2 @> IZ)) + (EFst ext (EVar ext primalTy (autoWeak library #primal layout2 @> IZ))))) $ + plus_x₀a_f + (plus_x₀_a + (elet (EIdx0 ext + (EFold1Inner ext Commut + (let t = STPair (d2 eltty) (d2 eltty) + in EPlus ext (d2M eltty) (EFst ext (EVar ext t IZ)) (ESnd ext (EVar ext t IZ))) + (EZero ext (d2M eltty) (EVar ext (tZeroInfo (d2M eltty)) (WSink .> autoWeak library #pzi layout1 @> IZ))) + (eflatten (EFst ext (EFst ext (evar IZ)))))) $ + weakenExpr (WCopy (WSink .> autoWeak library (#x₀atapebinds :++: #d2acEnv) layout1)) + ex₀2) + (weakenExpr (WCopy (autoWeak library (#x₀atapebinds :++: #d2acEnv) layout1)) $ + subst0 (ESnd ext (EFst ext (evar IZ))) ea2)) + (ESnd ext (evar IZ))) + + EUnit _ e + | SpArr sdElt <- sd + , Ret e0 subtape e1 sub e2 <- drev des accumMap sdElt e -> + Ret e0 + subtape + (EUnit ext e1) + sub + (ELet ext (EIdx0 ext (EVar ext (STArr SZ (applySparse sdElt (d2 (typeOf e)))) IZ)) $ + weakenExpr (WCopy WSink) e2) + + EReplicate1Inner _ en e + -- We're allowed to differentiate 'en' as primal-only here because its output is discrete. + | SpArr sdElt <- sd + , let STArr ndim eltty = typeOf e -> + -- This pessimistic sparsity union is because the array might have been empty, in which case we need to generate a zero. + sparsePlusS ST ST (d2M eltty) sdElt SpAbsent $ \sdElt' (Inj inj1) (Inj inj2) _ -> + case drev des accumMap (SpArr sdElt') e of { Ret binds subtape e1 sub e2 -> + Ret binds + subtape + (EReplicate1Inner ext (weakenExpr (wSinks (bindingsBinds binds)) (drevPrimal des en)) e1) + sub + (ELet ext (EFold1Inner ext Commut + (let t = STPair (applySparse sdElt' (d2 eltty)) (applySparse sdElt' (d2 eltty)) + in sparsePlus (d2M eltty) sdElt' (EFst ext (EVar ext t IZ)) (ESnd ext (EVar ext t IZ))) + (inj2 (ENil ext)) + (emap (inj1 (evar IZ)) $ EVar ext (STArr (SS ndim) (applySparse sdElt (d2 eltty))) IZ)) $ + weakenExpr (WCopy WSink) e2) + } + + EIdx0 _ e + | Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd) e + , STArr _ t <- typeOf e -> + Ret e0 + subtape + (EIdx0 ext e1) + sub + (ELet ext (EUnit ext (EVar ext (applySparse sd (d2 t)) IZ)) $ + weakenExpr (WCopy WSink) e2) + + EIdx1{} -> error "CHAD of EIdx1: Please use EIdx instead" + {- + EIdx1 _ e ei + -- We're allowed to ignore ei2 here because the output of 'ei' is discrete. + | Rets binds subtape (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil) + <- retConcat des $ drev des accumMap e `SCons` drev des accumMap ei `SCons` SNil + , STArr (SS n) eltty <- typeOf e -> + Ret (binds `bpush` e1 + `bpush` EShape ext (EVar ext (STArr (SS n) (d1 eltty)) IZ)) + (SEYesR (SENo subtape)) + (EIdx1 ext (EVar ext (STArr (SS n) (d1 eltty)) (IS IZ)) + (weakenExpr (WSink .> WSink) ei1)) + sub + (ELet ext (ebuildUp1 n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ))) + (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ))) + (EVar ext (STArr n (d2 eltty)) (IS IZ))) $ + weakenExpr (WCopy (WSink .> WSink)) e2) + -} + + EIdx _ e ei + -- We're allowed to differentiate ei as primal because its output is discrete. + | STArr n eltty <- typeOf e + , Refl <- indexTupD1Id n + , let tIxN = tTup (sreplicate n tIx) -> + sparsePlusS ST ST (d2M eltty) sd SpAbsent $ \sd' (Inj inj1) (Inj inj2) _ -> + case drev des accumMap (SpArr sd') e of { Ret binds subtape e1 sub e2 -> + Ret (binds `bpush` e1 + `bpush` EShape ext (EVar ext (typeOf e1) IZ) + `bpush` weakenExpr (WSink .> WSink .> wSinks (bindingsBinds binds)) (drevPrimal des ei)) + (SEYesR (SEYesR (SENo subtape))) + (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ))) + (EVar ext (tTup (sreplicate n tIx)) IZ)) + sub + (ELet ext + (EOneHot ext (SMTArr n (applySparse sd' (d2M eltty))) + (SAPArrIdx SAPHere) + (EPair ext + (EPair ext (EVar ext tIxN (IS IZ)) + (EBuild ext n (EVar ext tIxN (IS (IS IZ))) $ + makeZeroInfo (applySparse sd' (d2M eltty)) (inj2 (ENil ext)))) + (ENil ext)) + (inj1 $ EVar ext (applySparse sd (d2 eltty)) IZ)) $ + weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) + } + + EShape _ e + -- Allowed to differentiate e as primal because the output of EShape is + -- discrete, hence we'd be passing a zero cotangent to e anyway. + | STArr n _ <- typeOf e + , Refl <- indexTupD1Id n -> + Ret BTop + SETop + (EShape ext (drevPrimal des e)) + (subenvNone (d2eM (select SMerge des))) + (ENil ext) + + ESum1Inner _ e + | SpArr sd' <- sd + , Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd') e + , STArr (SS n) t <- typeOf e -> + Ret (e0 `bpush` e1 + `bpush` EShape ext (EVar ext (STArr (SS n) t) IZ)) + (SEYesR (SENo subtape)) + (ESum1Inner ext (EVar ext (STArr (SS n) t) (IS IZ))) + sub + (ELet ext (EReplicate1Inner ext + (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ))) + (EVar ext (STArr n (applySparse sd' (d2 t))) IZ)) $ + weakenExpr (WCopy (WSink .> WSink)) e2) + + EMaximum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMaximum1Inner ext) des accumMap sd' e + EMinimum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMinimum1Inner ext) des accumMap sd' e + + EReshape _ n esh e + | SpArr sd' <- sd + , STArr orign t <- typeOf e + , Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd') e + , Refl <- indexTupD1Id n -> + Ret (e0 `bpush` e1 + `bpush` EShape ext (EVar ext (STArr orign (d1 t)) IZ)) + (SEYesR (SENo subtape)) + (EReshape ext n (weakenExpr (WSink .> WSink .> wSinks (bindingsBinds e0)) (drevPrimal des esh)) + (EVar ext (STArr orign (d1 t)) (IS IZ))) + sub + (elet (EReshape ext orign (EVar ext (tTup (sreplicate orign tIx)) (IS IZ)) + (EVar ext (STArr n (applySparse sd' (d2 t))) IZ)) $ + weakenExpr (WCopy (WSink .> WSink)) e2) + + EZip _ a b + | SpArr sd' <- sd + , STArr n t1 <- typeOf a + , STArr _ t2 <- typeOf b -> + splitSparsePair (STPair (d2 t1) (d2 t2)) sd' $ \sd1 sd2 pairSplitE -> + case retConcat des (toSingleRet (drev des accumMap (SpArr sd1) a) `SCons` + toSingleRet (drev des accumMap (SpArr sd2) b) `SCons` SNil) of + { Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil) -> + subenvPlus SF SF (d2eM (select SMerge des)) subA subB $ \subBoth _ _ plus_A_B -> + Ret binds + subtape + (EZip ext a1 b1) + subBoth + (case pairSplitE of + Left Refl -> + let t' = STArr n (STPair (applySparse sd1 (d2 t1)) (applySparse sd2 (d2 t2))) in + plus_A_B + (elet (emap (EFst ext (evar IZ)) (EVar ext t' IZ)) $ weakenExpr (WCopy WSink) a2) + (elet (emap (ESnd ext (evar IZ)) (EVar ext t' IZ)) $ weakenExpr (WCopy WSink) b2) + Right f -> f IZ $ \wrapPair pick1 pick2 -> + elet (emap (wrapPair (EPair ext pick1 pick2)) + (EVar ext (applySparse (SpArr sd') (STArr n (STPair (d2 t1) (d2 t2)))) IZ)) $ + plus_A_B + (elet (emap (EFst ext (evar IZ)) (evar IZ)) $ weakenExpr (WCopy (WSink .> WSink)) a2) + (elet (emap (ESnd ext (evar IZ)) (evar IZ)) $ weakenExpr (WCopy (WSink .> WSink)) b2)) + } + + ENothing{} -> err_unsupported "ENothing" + EJust{} -> err_unsupported "EJust" + EMaybe{} -> err_unsupported "EMaybe" + ELNil{} -> err_unsupported "ELNil" + ELInl{} -> err_unsupported "ELInl" + ELInr{} -> err_unsupported "ELInr" + ELCase{} -> err_unsupported "ELCase" + + EWith{} -> err_accum + EZero{} -> err_monoid + EDeepZero{} -> err_monoid + EPlus{} -> err_monoid + EOneHot{} -> err_monoid + + EFold1InnerD1{} -> err_targetlang "EFold1InnerD1" + EFold1InnerD2{} -> err_targetlang "EFold1InnerD2" + + where + err_accum = error "Accumulator operations unsupported in the source program" + err_monoid = error "Monoid operations unsupported in the source program" + err_unsupported s = error $ "CHAD: unsupported " ++ s + err_targetlang s = error $ "CHAD: Target language operation " ++ s ++ " not supported in source program" + + contribTupTy :: Descr env sto -> SubenvS (D2E (Select env sto "merge")) contribs -> STy (Tup contribs) + contribTupTy des' sub = tTup (slistMap fromSMTy (subList (d2eM (select SMerge des')) sub)) + +deriv_extremum :: (?config :: CHADConfig, ScalIsNumeric t ~ True) + => (forall env'. Ex env' (TArr (S n) (TScal t)) -> Ex env' (TArr n (TScal t))) + -> Descr env sto -> VarMap Int (D2AcE (Select env sto "accum")) + -> Sparse (D2s t) sd + -> Expr ValId env (TArr (S n) (TScal t)) -> Ret env sto (TArr n sd) (TArr n (TScal t)) +deriv_extremum extremum des accumMap sd e + | at@(STArr (SS n) t@(STScal st)) <- typeOf e + , let at' = STArr n t + , let tIxN = tTup (sreplicate (SS n) tIx) = + sparsePlusS ST ST (d2M t) sd SpAbsent $ \sd' (Inj inj1) (Inj inj2) _ -> + case drev des accumMap (SpArr sd') e of { Ret e0 subtape e1 sub e2 -> + Ret (e0 `bpush` e1 + `bpush` extremum (EVar ext at IZ)) + (SEYesR (SEYesR subtape)) + (EVar ext at' IZ) + sub + (ELet ext + (EBuild ext (SS n) (EShape ext (EVar ext at (IS (IS IZ)))) $ + eif (EOp ext (OEq st) (EPair ext + (EIdx ext (EVar ext at (IS (IS (IS IZ)))) (EVar ext tIxN IZ)) + (EIdx ext (EVar ext at' (IS (IS IZ))) (EFst ext (EVar ext tIxN IZ))))) + (inj1 $ EIdx ext (EVar ext (STArr n (applySparse sd (d2 t))) (IS IZ)) (EFst ext (EVar ext tIxN IZ))) + (inj2 (ENil ext))) $ + weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) + } + +data ChosenStorage = forall s. ((s == "discr") ~ False) => ChosenStorage (Storage s) + +data RetScoped env0 sto a s sd t = + forall shbinds tapebinds contribs sa. + RetScoped + (Bindings Ex (D1E (a : env0)) shbinds) -- shared binds + (Subenv (Append shbinds '[D1 a]) tapebinds) + (Ex (Append shbinds (D1E (a : env0))) (D1 t)) + (SubenvS (D2E (Select env0 sto "merge")) contribs) + -- ^ merge contributions to the _enclosing_ merge environment + (Sparse (D2 a) sa) + -- ^ contribution to the argument + (Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) + (If (s == "discr") (Tup contribs) + (TPair (Tup contribs) sa))) + -- ^ the merge contributions, plus the cotangent to the argument + -- (if there is any) +deriving instance Show (RetScoped env0 sto a s sd t) + +drevScoped :: forall a s env sto sd t. + (?config :: CHADConfig) + => Descr env sto -> VarMap Int (D2AcE (Select env sto "accum")) + -> STy a -> Storage s -> Maybe (ValId a) + -> Sparse (D2 t) sd + -> Expr ValId (a : env) t + -> RetScoped env sto a s sd t +drevScoped des accumMap argty argsto argids sd expr = case argsto of + SMerge + | Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap sd expr + , Refl <- lemAppendNil @tapebinds -> + case sub of + SEYes sp sub' -> RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub' sp e2 + SENo sub' -> RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub' SpAbsent (EPair ext e2 (ENil ext)) + + SAccum + | chcSmartWith ?config + , Just (VIArr i _) <- argids + , Just (Some (VarMap.TypedIdx foundTy idx)) <- VarMap.lookup i accumMap + , Just Refl <- testEquality foundTy (STAccum (d2M argty)) + , Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) (VarMap.sink1 accumMap) sd expr + , Refl <- lemAppendNil @tapebinds -> + -- Our contribution to the binding's cotangent _here_ is zero (absent), + -- because we're contributing to an earlier binding of the same value + -- instead. + RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub SpAbsent $ + let wtapebinds = wSinks (subList (bindingsBinds e0) subtape) in + ELet ext (EVar ext (STAccum (d2M argty)) (WSink .> wtapebinds @> idx)) $ + weakenExpr (autoWeak (#d (auto1 @sd) + &. #body (subList (bindingsBinds e0) subtape) + &. #ac (auto1 @(TAccum (D2 a))) + &. #tl (d2ace (select SAccum des))) + (#d :++: #body :++: #ac :++: #tl) + (#ac :++: #d :++: #body :++: #tl)) + (EPair ext e2 (ENil ext)) + + | let accumMap' = case argids of + Just (VIArr i _) -> VarMap.insert i (STAccum (d2M argty)) IZ (VarMap.sink1 accumMap) + _ -> VarMap.sink1 accumMap + , Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap' sd expr -> + let library = #d (auto1 @sd) + &. #p (auto1 @(D1 a)) + &. #body (subList (bindingsBinds e0) subtape) + &. #ac (auto1 @(TAccum (D2 a))) + &. #tl (d2ace (select SAccum des)) + in + RetScoped e0 (subenvConcat (SEYesR @_ @_ @(D1 a) SETop) subtape) e1 sub (spDense (d2M argty)) $ + let primalIdx = autoWeak library #p (#d :++: (#body :++: #p) :++: #tl) @> IZ in + EWith ext (d2M argty) (EDeepZero ext (d2M argty) (d2deepZeroInfo argty (EVar ext (d1 argty) primalIdx))) $ + weakenExpr (autoWeak library + (#d :++: #body :++: #ac :++: #tl) + (#ac :++: #d :++: (#body :++: #p) :++: #tl)) + e2 + + SDiscr + | Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap sd expr + , Refl <- lemAppendNil @tapebinds -> + RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub SpAbsent e2 + +drevLambda :: (?config :: CHADConfig, (s == "accum") ~ False) + => Descr env sto + -> VarMap Int (D2AcE (Select env sto "accum")) + -> (STy a, Storage s) + -> Sparse (D2 t) dt + -> Expr ValId (a : env) t + -> (forall provars shbinds tape d2a'. + SList STy provars + -> Subenv (D2E (Select env sto "merge")) (D2E provars) + -> Bindings Ex (D1E env) (D1E provars) -- accum-promoted free variables of which we need a primal in the reverse pass (to initialise the accumulator) + -> Bindings Ex (D1 a : D1E env) shbinds + -> Ex (Append shbinds (D1 a : D1E env)) (D1 t) + -> Ex (Append shbinds (D1 a : D1E env)) tape + -> Sparse (D2 a) d2a' + -> (forall env' b. + D1E provars :> env' + -> Ex (Append (D2AcE provars) env') b + -> Ex ( env') (TPair b (Tup (D2E provars)))) + -> Ex (tape : dt : Append (D2AcE provars) (D2AcE (Select env sto "accum"))) d2a' + -> r) + -> r +drevLambda des accumMap (argty, argsto) sd origef k = + let t = typeOf origef in + deleteUnused (descrList des) (occEnvPopSome (occCountAll origef)) $ \(usedSub :: Subenv env env') -> + let ef = unsafeWeakenWithSubenv (SEYesR usedSub) origef in + subDescr des usedSub $ \(usedDes :: Descr env' _) subMergeUsed subAccumUsed subD1eUsed -> + accumPromote (applySparse sd (d2 t)) usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro -> + let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in + let mergePrimalSub = subenvD1E (selectSub SMerge des `subenvCompose` subMergeUsed `subenvCompose` proSub) in + let mergePrimalBindings = collectBindings (d1e (descrList des)) mergePrimalSub in + case prf1 prodes argty argsto of { Refl -> + case drev (prodes `DPush` (argty, Nothing, argsto)) accumMapPro sd ef of { Ret (ef0 :: Bindings _ _ e_binds) (subtapeEf :: Subenv _ e_tape) ef1 subEf ef2 -> + let (efRebinds, efPrerebinds) = reconstructBindings (subList (bindingsBinds ef0) subtapeEf) in + extractContrib prodes argty argsto subEf $ \argSp getSparseArg -> + let library = #fbinds (bindingsBinds ef0) + &. #ftapebinds (subList (bindingsBinds ef0) subtapeEf) + &. #ftape (auto1 @(Tape e_tape)) + &. #arg (d1 argty `SCons` SNil) + &. #d (applySparse sd (d2 t) `SCons` SNil) + &. #d1env (desD1E des) + &. #d1env' (desD1E usedDes) + &. #propr (d1e envPro) + &. #d2acUsed (d2ace (select SAccum usedDes)) + &. #d2acEnv (d2ace (select SAccum des)) + &. #d2acPro (d2ace envPro) + &. #efPrerebinds efPrerebinds in + k envPro + (subenvD2E (subenvCompose subMergeUsed proSub)) + mergePrimalBindings + (fst (weakenBindingsE (WCopy (wUndoSubenv subD1eUsed)) ef0)) + (weakenExpr (autoWeak library (#fbinds :++: #arg :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) + (#fbinds :++: #arg :++: #d1env)) + ef1) + (bindingsCollectTape (bindingsBinds ef0) subtapeEf (autoWeak library #fbinds (#fbinds :++: #arg :++: #d1env))) + argSp + (\wpro1 body -> + uninvertTup (d2e envPro) (typeOf body) $ + makeAccumulators wpro1 envPro $ + body) + (letBinds (efRebinds IZ) $ + weakenExpr + (autoWeak library (#d2acPro :++: #d :++: #ftapebinds :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) + ((#ftapebinds :++: #efPrerebinds) :++: #ftape :++: #d :++: #d2acPro :++: #d2acEnv) + .> wPro (subList (bindingsBinds ef0) subtapeEf)) + (getSparseArg ef2)) + }} + where + extractContrib :: (Select env sto "merge" ~ '[], (s == "accum") ~ False) + => proxy env sto -> proxy2 a -> Storage s + -- if s == "merge", this simplifies to SubenvS '[D2 a] t' + -- if s == "discr", this simplifies to SubenvS '[] t' + -> SubenvS (D2E (Select (a : env) (s : sto) "merge")) t' + -> (forall d'. Sparse (D2 a) d' -> (forall env'. Ex env' (Tup t') -> Ex env' d') -> r) -> r + extractContrib _ _ SMerge (SENo SETop) k' = k' SpAbsent id + extractContrib _ _ SMerge (SEYes s SETop) k' = k' s (ESnd ext) + extractContrib _ _ SDiscr SETop k' = k' SpAbsent id + + prf1 :: (s == "accum") ~ False => proxy env sto -> proxy2 a -> Storage s + -> Select (a : env) (s : sto) "accum" :~: Select env sto "accum" + prf1 _ _ SMerge = Refl + prf1 _ _ SDiscr = Refl + +-- TODO: proper primal-only transform that doesn't depend on D1 = Id +drevPrimal :: Descr env sto -> Expr x env t -> Ex (D1E env) (D1 t) +drevPrimal des e + | Refl <- d1Identity (typeOf e) + , Refl <- d1eIdentity (descrList des) + = mapExt (const ext) e |
