diff options
Diffstat (limited to 'src/CHAD.hs')
| -rw-r--r-- | src/CHAD.hs | 1131 |
1 files changed, 0 insertions, 1131 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs deleted file mode 100644 index 1126fde..0000000 --- a/src/CHAD.hs +++ /dev/null @@ -1,1131 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE EmptyCase #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE ImplicitParams #-} -{-# LANGUAGE OverloadedLabels #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE QuantifiedConstraints #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} - --- I want to bring various type variables in scope using type annotations in --- patterns, but I don't want to have to mention all the other type parameters --- of the types in question as well then. Partial type signatures (with '_') are --- useful here. -{-# LANGUAGE PartialTypeSignatures #-} -{-# OPTIONS -Wno-partial-type-signatures #-} -module CHAD ( - drev, - freezeRet, - CHADConfig(..), - defaultConfig, - Storage(..), - Descr(..), - Select, -) where - -import Data.Functor.Const -import Data.Some -import Data.Type.Bool (If) -import Data.Type.Equality (type (==), testEquality) -import GHC.Stack (HasCallStack) - -import Analysis.Identity (ValId(..), validSplitEither) -import AST -import AST.Bindings -import AST.Count -import AST.Env -import AST.Weaken.Auto -import CHAD.Accum -import CHAD.EnvDescr -import CHAD.Types -import Data -import qualified Data.VarMap as VarMap -import Data.VarMap (VarMap) -import Lemmas - - ------------------------------- TAPES AND BINDINGS ------------------------------ - -type family Tape binds where - Tape '[] = TNil - Tape (t : ts) = TPair t (Tape ts) - -tapeTy :: SList STy binds -> STy (Tape binds) -tapeTy SNil = STNil -tapeTy (SCons t ts) = STPair t (tapeTy ts) - -bindingsCollect :: Bindings f env binds -> Subenv binds tapebinds - -> Append binds env :> env2 -> Ex env2 (Tape tapebinds) -bindingsCollect BTop SETop _ = ENil ext -bindingsCollect (BPush binds (t, _)) (SEYes sub) w = - EPair ext (EVar ext t (w @> IZ)) - (bindingsCollect binds sub (w .> WSink)) -bindingsCollect (BPush binds _) (SENo sub) w = - bindingsCollect binds sub (w .> WSink) - --- In order from large to small: i.e. in reverse order from what we want, --- because in a Bindings, the head of the list is the bottom-most entry. -type family TapeUnfoldings binds where - TapeUnfoldings '[] = '[] - TapeUnfoldings (t : ts) = Tape ts : TapeUnfoldings ts - -type family Reverse l where - Reverse '[] = '[] - Reverse (t : ts) = Append (Reverse ts) '[t] - --- An expression that is always 'snd' -data UnfExpr env t where - UnfExSnd :: STy s -> STy t -> UnfExpr (TPair s t : env) t - -fromUnfExpr :: UnfExpr env t -> Ex env t -fromUnfExpr (UnfExSnd s t) = ESnd ext (EVar ext (STPair s t) IZ) - --- - A bunch of 'snd' expressions taking us from knowing that there's a --- 'Tape ts' in the environment (for simplicity assume it's at IZ, we'll fix --- this in reconstructBindings), to having 'Reverse (TapeUnfoldings ts)' in --- the environment. --- - In the extended environment, another bunch of let bindings (these are --- 'fst' expressions, but no need to know that statically) that project the --- fsts out of what we introduced above, one for each type in 'ts'. -data Reconstructor env ts = - Reconstructor - (Bindings UnfExpr (Tape ts : env) (Reverse (TapeUnfoldings ts))) - (Bindings Ex (Append (Reverse (TapeUnfoldings ts)) (Tape ts : env)) ts) - -ssnoc :: SList f ts -> f t -> SList f (Append ts '[t]) -ssnoc SNil a = SCons a SNil -ssnoc (SCons t ts) a = SCons t (ssnoc ts a) - -sreverse :: SList f ts -> SList f (Reverse ts) -sreverse SNil = SNil -sreverse (SCons t ts) = ssnoc (sreverse ts) t - -stapeUnfoldings :: SList STy ts -> SList STy (TapeUnfoldings ts) -stapeUnfoldings SNil = SNil -stapeUnfoldings (SCons _ ts) = SCons (tapeTy ts) (stapeUnfoldings ts) - --- Puts a 'snd' at the top of an unfolder stack and grows the context variable by one. -shiftUnfolder - :: STy t - -> SList STy ts - -> Bindings UnfExpr (Tape ts : env) list - -> Bindings UnfExpr (Tape (t : ts) : env) (Append list '[Tape ts]) -shiftUnfolder newTy ts BTop = BPush BTop (tapeTy ts, UnfExSnd newTy (tapeTy ts)) -shiftUnfolder newTy ts (BPush b (t, UnfExSnd itemTy _)) = - -- Recurse on 'b', and retype the 'snd'. We need to unfold 'b' once in order - -- to expand an 'Append' in the types so that things simplify just enough. - -- We have an equality 'Append binds x1 ~ a : x2', where 'binds' is the list - -- of bindings produced by 'b'. We want to conclude from this that - -- 'binds ~ a : x3' for some 'x3', but GHC will only do that once we know - -- that 'binds ~ y : ys' so that the 'Append' can expand one step, after - -- which 'y ~ a' as desired. The 'case' unfolds 'b' one step. - BPush (shiftUnfolder newTy ts b) (t, case b of BTop -> UnfExSnd itemTy t - BPush{} -> UnfExSnd itemTy t) - -growRecon :: forall env t ts. STy t -> SList STy ts -> Reconstructor env ts -> Reconstructor env (t : ts) -growRecon t ts (Reconstructor unfbs bs) - | Refl <- lemAppendNil @(Append (Reverse (TapeUnfoldings ts)) '[Tape ts]) - , Refl <- lemAppendAssoc @ts @(Append (Reverse (TapeUnfoldings ts)) '[Tape ts]) @(Tape (t : ts) : env) - , Refl <- lemAppendAssoc @(Reverse (TapeUnfoldings ts)) @'[Tape ts] @env - = Reconstructor - (shiftUnfolder t ts unfbs) - -- Add a 'fst' at the bottom of the builder stack. - -- First we have to weaken most of 'bs' to skip one more binding in the - -- unfolder stack above it. - (BPush (fst (weakenBindings weakenExpr - (wCopies (sappend (sreverse (stapeUnfoldings ts)) (SCons (tapeTy ts) SNil)) - (WSink :: env :> (Tape (t : ts) : env))) bs)) - (t - ,EFst ext $ EVar ext (tapeTy (SCons t ts)) $ - wSinks @(Tape (t : ts) : env) - (sappend ts - (sappend (sappend (sreverse (stapeUnfoldings ts)) - (SCons (tapeTy ts) SNil)) - SNil)) - @> IZ)) - -buildReconstructor :: SList STy ts -> Reconstructor env ts -buildReconstructor SNil = Reconstructor BTop BTop -buildReconstructor (SCons t ts) = growRecon t ts (buildReconstructor ts) - --- STRATEGY FOR reconstructBindings --- --- binds = [] --- e : () --- --- binds = [c] --- e : (c, ()) --- x0 = snd x1 : () --- y1 = fst e : c --- --- binds = [b, c] --- e : (b, (c, ())) --- x1 = snd e : (c, ()) --- x0 = snd x1 : () --- y1 = fst x1 : c --- y2 = fst x2 : b --- --- binds = [a, b, c] --- e : (a, (b, (c, ()))) --- x2 = snd e : (b, (c, ())) --- x1 = snd x2 : (c, ()) --- x0 = snd x1 : () --- y1 = fst x1 : c --- y2 = fst x2 : b --- y3 = fst x3 : a - --- Given that in 'env' we can find a 'Tape binds', i.e. a tuple containing all --- the things in the list 'binds', we want to create a let stack that extracts --- all values from that tuple and in effect "restores" the environment --- described by 'binds'. The idea is that elsewhere, we took a slice of the --- environment and saved it all in a tuple to be restored later. We --- incidentally also add a bunch of additional bindings, namely 'Reverse --- (TapeUnfoldings binds)', so the calling code just has to skip those in --- whatever it wants to do. -reconstructBindings :: SList STy binds -> Idx env (Tape binds) - -> (Bindings Ex env (Append binds (Reverse (TapeUnfoldings binds))) - ,SList STy (Reverse (TapeUnfoldings binds))) -reconstructBindings binds tape = - let Reconstructor unf build = buildReconstructor binds - in (fst $ weakenBindings weakenExpr (WIdx tape) - (bconcat (mapBindings fromUnfExpr unf) build) - ,sreverse (stapeUnfoldings binds)) - - ----------------------------------- DERIVATIVES --------------------------------- - -d1op :: SOp a t -> Ex env (D1 a) -> Ex env (D1 t) -d1op (OAdd t) e = EOp ext (OAdd t) e -d1op (OMul t) e = EOp ext (OMul t) e -d1op (ONeg t) e = EOp ext (ONeg t) e -d1op (OLt t) e = EOp ext (OLt t) e -d1op (OLe t) e = EOp ext (OLe t) e -d1op (OEq t) e = EOp ext (OEq t) e -d1op ONot e = EOp ext ONot e -d1op OAnd e = EOp ext OAnd e -d1op OOr e = EOp ext OOr e -d1op OIf e = EOp ext OIf e -d1op ORound64 e = EOp ext ORound64 e -d1op OToFl64 e = EOp ext OToFl64 e -d1op (ORecip t) e = EOp ext (ORecip t) e -d1op (OExp t) e = EOp ext (OExp t) e -d1op (OLog t) e = EOp ext (OLog t) e -d1op (OIDiv t) e = EOp ext (OIDiv t) e -d1op (OMod t) e = EOp ext (OMod t) e - --- | Both primal and dual must be duplicable expressions -data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a)) - | Nonlinear (forall env. Ex env (D1 a) -> Ex env (D2 t) -> Ex env (D2 a)) - -d2op :: SOp a t -> D2Op a t -d2op op = case op of - OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EJust ext (EPair ext d d) - OMul t -> d2opBinArrangeInt t $ Nonlinear $ \e d -> - EJust ext (EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d)) - (EOp ext (OMul t) (EPair ext (EFst ext e) d))) - ONeg t -> d2opUnArrangeInt t $ Linear $ \d -> EOp ext (ONeg t) d - OLt t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t))) - OLe t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t))) - OEq t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t))) - ONot -> Linear $ \_ -> ENil ext - OAnd -> Linear $ \_ -> ENothing ext (STPair STNil STNil) - OOr -> Linear $ \_ -> ENothing ext (STPair STNil STNil) - OIf -> Linear $ \_ -> ENil ext - ORound64 -> Linear $ \_ -> EConst ext STF64 0.0 - OToFl64 -> Linear $ \_ -> ENil ext - ORecip t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ONeg t) (EOp ext (ORecip t) (EOp ext (OMul t) (EPair ext e e)))) d) - OExp t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (OExp t) e) d) - OLog t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ORecip t) e) d) - OIDiv t -> integralD2 t $ Linear $ \_ -> ENothing ext (STPair STNil STNil) - OMod t -> integralD2 t $ Linear $ \_ -> ENothing ext (STPair STNil STNil) - where - d2opUnArrangeInt :: SScalTy a - -> (D2s a ~ TScal a => D2Op (TScal a) t) - -> D2Op (TScal a) t - d2opUnArrangeInt ty float = case ty of - STI32 -> Linear $ \_ -> ENil ext - STI64 -> Linear $ \_ -> ENil ext - STF32 -> float - STF64 -> float - STBool -> Linear $ \_ -> ENil ext - - d2opBinArrangeInt :: SScalTy a - -> (D2s a ~ TScal a => D2Op (TPair (TScal a) (TScal a)) t) - -> D2Op (TPair (TScal a) (TScal a)) t - d2opBinArrangeInt ty float = case ty of - STI32 -> Linear $ \_ -> ENothing ext (STPair STNil STNil) - STI64 -> Linear $ \_ -> ENothing ext (STPair STNil STNil) - STF32 -> float - STF64 -> float - STBool -> Linear $ \_ -> ENothing ext (STPair STNil STNil) - - floatingD2 :: ScalIsFloating a ~ True - => SScalTy a -> ((D2s a ~ TScal a, ScalIsNumeric a ~ True) => r) -> r - floatingD2 STF32 k = k - floatingD2 STF64 k = k - - integralD2 :: ScalIsIntegral a ~ True - => SScalTy a -> ((D2s a ~ TNil, ScalIsNumeric a ~ True) => r) -> r - integralD2 STI32 k = k - integralD2 STI64 k = k - -desD1E :: Descr env sto -> SList STy (D1E env) -desD1E = d1e . descrList - --- d1W :: env :> env' -> D1E env :> D1E env' --- d1W WId = WId --- d1W WSink = WSink --- d1W (WCopy w) = WCopy (d1W w) --- d1W (WPop w) = WPop (d1W w) --- d1W (WThen u w) = WThen (d1W u) (d1W w) - -conv1Idx :: Idx env t -> Idx (D1E env) (D1 t) -conv1Idx IZ = IZ -conv1Idx (IS i) = IS (conv1Idx i) - -data Idx2 env sto t - = Idx2Ac (Idx (D2AcE (Select env sto "accum")) (TAccum t)) - | Idx2Me (Idx (Select env sto "merge") t) - | Idx2Di (Idx (Select env sto "discr") t) - -conv2Idx :: Descr env sto -> Idx env t -> Idx2 env sto t -conv2Idx (DPush _ (_, _, SAccum)) IZ = Idx2Ac IZ -conv2Idx (DPush _ (_, _, SMerge)) IZ = Idx2Me IZ -conv2Idx (DPush _ (_, _, SDiscr)) IZ = Idx2Di IZ -conv2Idx (DPush des (_, _, SAccum)) (IS i) = - case conv2Idx des i of Idx2Ac j -> Idx2Ac (IS j) - Idx2Me j -> Idx2Me j - Idx2Di j -> Idx2Di j -conv2Idx (DPush des (_, _, SMerge)) (IS i) = - case conv2Idx des i of Idx2Ac j -> Idx2Ac j - Idx2Me j -> Idx2Me (IS j) - Idx2Di j -> Idx2Di j -conv2Idx (DPush des (_, _, SDiscr)) (IS i) = - case conv2Idx des i of Idx2Ac j -> Idx2Ac j - Idx2Me j -> Idx2Me j - Idx2Di j -> Idx2Di (IS j) -conv2Idx DTop i = case i of {} - - ------------------------------------- MONOIDS ----------------------------------- - -zeroTup :: SList STy env0 -> Ex env (Tup (D2E env0)) -zeroTup SNil = ENil ext -zeroTup (SCons t env) = EPair ext (zeroTup env) (EZero ext t) - - ------------------------------------- SUBENVS ----------------------------------- - -subenvPlus :: SList STy env - -> Subenv env env1 -> Subenv env env2 - -> (forall env3. Subenv env env3 - -> Subenv env3 env1 - -> Subenv env3 env2 - -> (Ex exenv (Tup (D2E env1)) - -> Ex exenv (Tup (D2E env2)) - -> Ex exenv (Tup (D2E env3))) - -> r) - -> r -subenvPlus SNil SETop SETop k = k SETop SETop SETop (\_ _ -> ENil ext) -subenvPlus (SCons _ env) (SENo sub1) (SENo sub2) k = - subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> - k (SENo sub3) s31 s32 pl -subenvPlus (SCons _ env) (SEYes sub1) (SENo sub2) k = - subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> - k (SEYes sub3) (SEYes s31) (SENo s32) $ \e1 e2 -> - ELet ext e1 $ - EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ)) - (weakenExpr WSink e2)) - (ESnd ext (EVar ext (typeOf e1) IZ)) -subenvPlus (SCons _ env) (SENo sub1) (SEYes sub2) k = - subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> - k (SEYes sub3) (SENo s31) (SEYes s32) $ \e1 e2 -> - ELet ext e2 $ - EPair ext (pl (weakenExpr WSink e1) - (EFst ext (EVar ext (typeOf e2) IZ))) - (ESnd ext (EVar ext (typeOf e2) IZ)) -subenvPlus (SCons t env) (SEYes sub1) (SEYes sub2) k = - subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> - k (SEYes sub3) (SEYes s31) (SEYes s32) $ \e1 e2 -> - ELet ext e1 $ - ELet ext (weakenExpr WSink e2) $ - EPair ext (pl (EFst ext (EVar ext (typeOf e1) (IS IZ))) - (EFst ext (EVar ext (typeOf e2) IZ))) - (EPlus ext t - (ESnd ext (EVar ext (typeOf e1) (IS IZ))) - (ESnd ext (EVar ext (typeOf e2) IZ))) - -expandSubenvZeros :: SList STy env0 -> Subenv env0 env0Merge -> Ex env (Tup (D2E env0Merge)) -> Ex env (Tup (D2E env0)) -expandSubenvZeros _ SETop _ = ENil ext -expandSubenvZeros (SCons t ts) (SEYes sub) e = - ELet ext e $ - let var = EVar ext (STPair (tTup (d2e (subList ts sub))) (d2 t)) IZ - in EPair ext (expandSubenvZeros ts sub (EFst ext var)) (ESnd ext var) -expandSubenvZeros (SCons t ts) (SENo sub) e = EPair ext (expandSubenvZeros ts sub e) (EZero ext t) - -assertSubenvEmpty :: HasCallStack => Subenv env env' -> env' :~: '[] -assertSubenvEmpty (SENo sub) | Refl <- assertSubenvEmpty sub = Refl -assertSubenvEmpty SETop = Refl -assertSubenvEmpty SEYes{} = error "assertSubenvEmpty: not empty" - - ---------------------------------- ACCUMULATORS --------------------------------- - -fromArrayValId :: Maybe (ValId t) -> Maybe Int -fromArrayValId (Just (VIArr i _)) = Just i -fromArrayValId _ = Nothing - -accumPromote :: forall dt env sto proxy r. - proxy dt - -> Descr env sto - -> (forall stoRepl envPro. - (Select env stoRepl "merge" ~ '[]) - => Descr env stoRepl - -- ^ A revised environment description that switches - -- arrays (used in the OccEnv) that are currently on - -- "merge" storage, to "accum" storage. - -> SList STy envPro - -- ^ New entries on top of the original dual environment, - -- that house the accumulators for the promoted arrays in - -- the original environment. - -> Subenv (Select env sto "merge") envPro - -- ^ The promoted entries were merge entries in the - -- original environment. - -> Subenv (D2AcE (Select env stoRepl "accum")) (D2AcE (Select env sto "accum")) - -- ^ All entries that were accumulators are still - -- accumulators. - -> VarMap Int (D2AcE (Select env stoRepl "accum")) - -- ^ Accumulator map for _only_ the the newly allocated - -- accumulators. - -> (forall shbinds. - SList STy shbinds - -> (D2 dt : Append shbinds (D2AcE (Select env stoRepl "accum"))) - :> Append (D2AcE envPro) (D2 dt : Append shbinds (D2AcE (Select env sto "accum")))) - -- ^ A weakening that converts a computation in the - -- revised environment to one in the original environment - -- extended with some accumulators. - -> r) - -> r -accumPromote _ DTop k = k DTop SNil SETop SETop VarMap.empty (\_ -> WId) -accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of - -- Accumulators are left as-is - SAccum -> - accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf -> - k (storepl `DPush` (t, vid, SAccum)) - envpro - prosub - (SEYes accrevsub) - (VarMap.sink1 accumMap) - (\shbinds -> - autoWeak (#pro (d2ace envpro) &. #d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum t)) &. #tl (d2ace (select SAccum descr))) - (#acc :++: (#pro :++: #d :++: #shb :++: #tl)) - (#pro :++: #d :++: #shb :++: #acc :++: #tl) - .> WCopy (wf shbinds) - .> autoWeak (#d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum t)) &. #tl (d2ace (select SAccum storepl))) - (#d :++: #shb :++: #acc :++: #tl) - (#acc :++: (#d :++: #shb :++: #tl))) - - SMerge -> case t of - -- Discrete values are left as-is - _ | isDiscrete t -> - accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap' wf -> - k (storepl `DPush` (t, vid, SDiscr)) - envpro - (SENo prosub) - accrevsub - accumMap' - wf - - -- Values with "merge" storage are promoted to an accumulator in envPro - _ -> - accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf -> - k (storepl `DPush` (t, vid, SAccum)) - (t `SCons` envpro) - (SEYes prosub) - (SENo accrevsub) - (let accumMap' = VarMap.sink1 accumMap - in case fromArrayValId vid of - Just i -> VarMap.insert i (STAccum t) IZ accumMap' - Nothing -> accumMap') - (\(shbinds :: SList _ shbinds) -> - let shbindsC = slistMap (\_ -> Const ()) shbinds - in - -- wf: - -- D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) - -- WCopy wf: - -- TAccum n t3 : D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) - -- WPICK: ^ THESE TWO || - -- goal: | ARE EQUAL || - -- D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) - WCopy (wf shbinds) - .> WPick @(TAccum t) @(D2 dt : shbinds) (Const () `SCons` shbindsC) - (WId @(D2AcE (Select env1 stoRepl "accum")))) - - -- Discrete values are left as-is, nothing to do - SDiscr -> - accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf -> - k (storepl `DPush` (t, vid, SDiscr)) - envpro - prosub - accrevsub - accumMap - wf - where - isDiscrete :: STy t' -> Bool - isDiscrete = \case - STNil -> True - STPair a b -> isDiscrete a && isDiscrete b - STEither a b -> isDiscrete a && isDiscrete b - STMaybe a -> isDiscrete a - STArr _ a -> isDiscrete a - STScal st -> case st of - STI32 -> True - STI64 -> True - STF32 -> False - STF64 -> False - STBool -> True - STAccum{} -> False - - ----------------------------- RETURN TRIPLE FROM CHAD --------------------------- - -data Ret env0 sto t = - forall shbinds tapebinds env0Merge. - Ret (Bindings Ex (D1E env0) shbinds) -- shared binds - (Subenv shbinds tapebinds) - (Ex (Append shbinds (D1E env0)) (D1 t)) - (Subenv (Select env0 sto "merge") env0Merge) - (Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup (D2E env0Merge))) -deriving instance Show (Ret env0 sto t) - -data RetPair env0 sto env shbinds tapebinds t = - forall env0Merge. - RetPair (Ex (Append shbinds env) (D1 t)) - (Subenv (Select env0 sto "merge") env0Merge) - (Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup (D2E env0Merge))) -deriving instance Show (RetPair env0 sto env shbinds tapebinds t) - -data Rets env0 sto env list = - forall shbinds tapebinds. - Rets (Bindings Ex env shbinds) - (Subenv shbinds tapebinds) - (SList (RetPair env0 sto env shbinds tapebinds) list) -deriving instance Show (Rets env0 sto env list) - -weakenRetPair :: SList STy shbinds -> env :> env' - -> RetPair env0 sto env shbinds tapebinds t -> RetPair env0 sto env' shbinds tapebinds t -weakenRetPair bindslist w (RetPair e1 sub e2) = RetPair (weakenExpr (weakenOver bindslist w) e1) sub e2 - -weakenRets :: env :> env' -> Rets env0 sto env list -> Rets env0 sto env' list -weakenRets w (Rets binds tapesub list) = - let (binds', _) = weakenBindings weakenExpr w binds - in Rets binds' tapesub (slistMap (weakenRetPair (bindingsBinds binds) w) list) - -rebaseRetPair :: forall env b1 b2 tapebinds1 tapebinds2 env0 sto t f. - Descr env0 sto - -> SList f b1 -> SList f b2 - -> Subenv b1 tapebinds1 -> Subenv b2 tapebinds2 - -> RetPair env0 sto (Append b1 env) b2 tapebinds2 t - -> RetPair env0 sto env (Append b2 b1) (Append tapebinds2 tapebinds1) t -rebaseRetPair descr b1 b2 subtape1 subtape2 (RetPair p sub d) - | Refl <- lemAppendAssoc @b2 @b1 @env = - RetPair p sub (weakenExpr (autoWeak - (#d (auto1 @(D2 t)) - &. #t2 (subList b2 subtape2) - &. #t1 (subList b1 subtape1) - &. #tl (d2ace (select SAccum descr))) - (#d :++: (#t2 :++: #tl)) - (#d :++: ((#t2 :++: #t1) :++: #tl))) - d) - -retConcat :: forall env0 sto list. Descr env0 sto -> SList (Ret env0 sto) list -> Rets env0 sto (D1E env0) list -retConcat _ SNil = Rets BTop SETop SNil -retConcat descr (SCons (Ret (b :: Bindings _ _ shbinds1) (subtape :: Subenv _ tapebinds1) p sub d) list) - | Rets (binds :: Bindings _ _ shbinds2) (subtape2 :: Subenv _ tapebinds2) pairs - <- weakenRets (sinkWithBindings b) (retConcat descr list) - , Refl <- lemAppendAssoc @shbinds2 @shbinds1 @(D1E env0) - , Refl <- lemAppendAssoc @tapebinds2 @tapebinds1 @(D2AcE (Select env0 sto "accum")) - = Rets (bconcat b binds) - (subenvConcat subtape subtape2) - (SCons (RetPair (weakenExpr (sinkWithBindings binds) p) - sub - (weakenExpr (WCopy (sinkWithSubenv subtape2)) d)) - (slistMap (rebaseRetPair descr (bindingsBinds b) (bindingsBinds binds) - subtape subtape2) - pairs)) - -freezeRet :: Descr env sto - -> Ret env sto t - -> Ex (D2 t : Append (D2AcE (Select env sto "accum")) (D1E env)) (TPair (D1 t) (Tup (D2E (Select env sto "merge")))) -freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ t) = - let (e0', wInsertD2Ac) = weakenBindings weakenExpr (WSink .> wSinks (d2ace (select SAccum descr))) e0 - e2' = weakenExpr (WCopy (wCopies (subList (bindingsBinds e0) subtape) (wRaiseAbove (d2ace (select SAccum descr)) (desD1E descr)))) e2 - in letBinds e0' $ - EPair ext - (weakenExpr wInsertD2Ac e1) - (ELet ext (weakenExpr (autoWeak (#d (auto1 @(D2 t)) - &. #tape (subList (bindingsBinds e0) subtape) - &. #shbinds (bindingsBinds e0) - &. #d2ace (d2ace (select SAccum descr)) - &. #tl (desD1E descr)) - (#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: #d2ace :++: #tl) - (#shbinds :++: #d :++: #d2ace :++: #tl)) - e2') $ - expandSubenvZeros (select SMerge descr) sub (EVar ext (tTup (d2e (subList (select SMerge descr) sub))) IZ)) - - ----------------------------- THE CHAD TRANSFORMATION --------------------------- - -drev :: forall env sto t. - (?config :: CHADConfig) - => Descr env sto -> VarMap Int (D2AcE (Select env sto "accum")) - -> Expr ValId env t -> Ret env sto t -drev des accumMap = \case - EVar _ t i -> - case conv2Idx des i of - Idx2Ac accI -> - Ret BTop - SETop - (EVar ext (d1 t) (conv1Idx i)) - (subenvNone (select SMerge des)) - (EAccum ext t SAPHere (ENil ext) (EVar ext (d2 t) IZ) (EVar ext (STAccum t) (IS accI))) - - Idx2Me tupI -> - Ret BTop - SETop - (EVar ext (d1 t) (conv1Idx i)) - (subenvOnehot (select SMerge des) tupI) - (EPair ext (ENil ext) (EVar ext (d2 t) IZ)) - - Idx2Di _ -> - Ret BTop - SETop - (EVar ext (d1 t) (conv1Idx i)) - (subenvNone (select SMerge des)) - (ENil ext) - - ELet _ (rhs :: Expr _ _ a) body - | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 <- drev des accumMap rhs - , ChosenStorage storage <- if chcLetArrayAccum ?config && hasArrays (typeOf rhs) then ChosenStorage SAccum else ChosenStorage SMerge - , RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody body2 <- drevScoped des accumMap (typeOf rhs) storage (Just (extOf rhs)) body - , let (body0', wbody0') = weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0 - , Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D1E env) - , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) -> - subenvPlus (select SMerge des) subRHS subBody $ \subBoth _ _ plus_RHS_Body -> - let bodyResType = STPair (tTup (d2e (subList (select SMerge des) subBody))) (d2 (typeOf rhs)) in - Ret (bconcat (rhs0 `BPush` (d1 (typeOf rhs), rhs1)) body0') - (subenvConcat (SENo @d1_a subtapeRHS) subtapeBody) - (weakenExpr wbody0' body1) - subBoth - (ELet ext (weakenExpr (autoWeak (#d (auto1 @(D2 t)) - &. #body (subList (bindingsBinds body0) subtapeBody) - &. #rhs (subList (bindingsBinds rhs0) subtapeRHS) - &. #tl (d2ace (select SAccum des))) - (#d :++: #body :++: #tl) - (#d :++: (#body :++: #rhs) :++: #tl)) - body2) $ - ELet ext - (ELet ext (ESnd ext (EVar ext bodyResType IZ)) $ - weakenExpr (WCopy (wSinks' @[_,_] .> sinkWithSubenv subtapeBody)) rhs2) $ - plus_RHS_Body - (EVar ext (tTup (d2e (subList (select SMerge des) subRHS))) IZ) - (EFst ext (EVar ext bodyResType (IS IZ)))) - - EPair _ a b - | Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil) - <- retConcat des $ drev des accumMap a `SCons` drev des accumMap b `SCons` SNil - , let dt = STPair (d2 (typeOf a)) (d2 (typeOf b)) -> - subenvPlus (select SMerge des) subA subB $ \subBoth _ _ plus_A_B -> - Ret binds - subtape - (EPair ext a1 b1) - subBoth - (EMaybe ext - (zeroTup (subList (select SMerge des) subBoth)) - (ELet ext (ELet ext (EFst ext (EVar ext dt IZ)) - (weakenExpr (WCopy (wSinks' @[_,_])) a2)) $ - ELet ext (ELet ext (ESnd ext (EVar ext dt (IS IZ))) - (weakenExpr (WCopy (wSinks' @[_,_,_])) b2)) $ - plus_A_B - (EVar ext (tTup (d2e (subList (select SMerge des) subA))) (IS IZ)) - (EVar ext (tTup (d2e (subList (select SMerge des) subB))) IZ)) - (EVar ext (STMaybe (STPair (d2 (typeOf a)) (d2 (typeOf b)))) IZ)) - - EFst _ e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e - , STPair t1 t2 <- typeOf e -> - Ret e0 - subtape - (EFst ext e1) - sub - (ELet ext (EJust ext (EPair ext (EVar ext (d2 t1) IZ) (EZero ext t2))) $ - weakenExpr (WCopy WSink) e2) - - ESnd _ e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e - , STPair t1 t2 <- typeOf e -> - Ret e0 - subtape - (ESnd ext e1) - sub - (ELet ext (EJust ext (EPair ext (EZero ext t1) (EVar ext (d2 t2) IZ))) $ - weakenExpr (WCopy WSink) e2) - - ENil _ -> Ret BTop SETop (ENil ext) (subenvNone (select SMerge des)) (ENil ext) - - EInl _ t2 e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e -> - Ret e0 - subtape - (EInl ext (d1 t2) e1) - sub - (EMaybe ext - (zeroTup (subList (select SMerge des) sub)) - (ECase ext (EVar ext (STEither (d2 (typeOf e)) (d2 t2)) IZ) - (weakenExpr (WCopy (wSinks' @[_,_])) e2) - (EError ext (tTup (d2e (subList (select SMerge des) sub))) "inl<-dinr")) - (EVar ext (STMaybe (STEither (d2 (typeOf e)) (d2 t2))) IZ)) - - EInr _ t1 e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e -> - Ret e0 - subtape - (EInr ext (d1 t1) e1) - sub - (EMaybe ext - (zeroTup (subList (select SMerge des) sub)) - (ECase ext (EVar ext (STEither (d2 t1) (d2 (typeOf e))) IZ) - (EError ext (tTup (d2e (subList (select SMerge des) sub))) "inr<-dinl") - (weakenExpr (WCopy (wSinks' @[_,_])) e2)) - (EVar ext (STMaybe (STEither (d2 t1) (d2 (typeOf e)))) IZ)) - - ECase _ e (a :: Expr _ _ t) b - | STEither t1 t2 <- typeOf e - , Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 subE e2 <- drev des accumMap e - , ChosenStorage storage1 <- if chcCaseArrayAccum ?config && hasArrays t1 then ChosenStorage SAccum else ChosenStorage SMerge - , ChosenStorage storage2 <- if chcCaseArrayAccum ?config && hasArrays t2 then ChosenStorage SAccum else ChosenStorage SMerge - , let (bindids1, bindids2) = validSplitEither (extOf e) - , RetScoped (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA a2 <- drevScoped des accumMap t1 storage1 bindids1 a - , RetScoped (b0 :: Bindings _ _ rhs_b_binds) (subtapeB :: Subenv _ rhs_b_tape) b1 subB b2 <- drevScoped des accumMap t2 storage2 bindids2 b - , Refl <- lemAppendAssoc @(Append rhs_a_binds (Reverse (TapeUnfoldings rhs_a_binds))) @(Tape rhs_a_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum")) - , Refl <- lemAppendAssoc @(Append rhs_b_binds (Reverse (TapeUnfoldings rhs_b_binds))) @(Tape rhs_b_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum")) - , let tapeA = tapeTy (subList (bindingsBinds a0) subtapeA) - , let tapeB = tapeTy (subList (bindingsBinds b0) subtapeB) - , let collectA = bindingsCollect a0 subtapeA - , let collectB = bindingsCollect b0 subtapeB - , (tPrimal :: STy t_primal_ty) <- STPair (d1 (typeOf a)) (STEither tapeA tapeB) - , let (a0', wa0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) a0 - , let (b0', wb0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) b0 - -> - subenvPlus (select SMerge des) subA subB $ \subAB sAB_A sAB_B _ -> - subenvPlus (select SMerge des) subAB subE $ \subOut _ _ plus_AB_E -> - let tCaseRet = STPair (tTup (d2e (subList (select SMerge des) subAB))) (STEither (d2 t1) (d2 t2)) in - Ret (e0 `BPush` - (tPrimal, - ECase ext e1 - (letBinds a0' (EPair ext (weakenExpr wa0' a1) (EInl ext tapeB (collectA wa0')))) - (letBinds b0' (EPair ext (weakenExpr wb0' b1) (EInr ext tapeA (collectB wb0')))))) - (SEYes subtapeE) - (EFst ext (EVar ext tPrimal IZ)) - subOut - (ELet ext - (ECase ext (ESnd ext (EVar ext tPrimal (IS IZ))) - (let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds a0) subtapeA) IZ - in letBinds rebinds $ - ELet ext - (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_a_tape : D2 t : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend (subList (bindingsBinds a0) subtapeA) prerebinds) @> IS IZ)) $ - ELet ext - (weakenExpr (autoWeak (#d (auto1 @(D2 t)) - &. #ta0 (subList (bindingsBinds a0) subtapeA) - &. #prea0 prerebinds - &. #recon (tapeA `SCons` d2 (typeOf a) `SCons` SNil) - &. #binds (tPrimal `SCons` subList (bindingsBinds e0) subtapeE) - &. #tl (d2ace (select SAccum des))) - (#d :++: #ta0 :++: #tl) - (#d :++: (#ta0 :++: #prea0) :++: #recon :++: #binds :++: #tl)) - a2) $ - EPair ext - (expandSubenvZeros (subList (select SMerge des) subAB) sAB_A $ - EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA))) (d2 t1)) IZ)) - (EInl ext (d2 t2) - (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA))) (d2 t1)) IZ)))) - (let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds b0) subtapeB) IZ - in letBinds rebinds $ - ELet ext - (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_b_tape : D2 t : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend (subList (bindingsBinds b0) subtapeB) prerebinds) @> IS IZ)) $ - ELet ext - (weakenExpr (autoWeak (#d (auto1 @(D2 t)) - &. #tb0 (subList (bindingsBinds b0) subtapeB) - &. #preb0 prerebinds - &. #recon (tapeB `SCons` d2 (typeOf a) `SCons` SNil) - &. #binds (tPrimal `SCons` subList (bindingsBinds e0) subtapeE) - &. #tl (d2ace (select SAccum des))) - (#d :++: #tb0 :++: #tl) - (#d :++: (#tb0 :++: #preb0) :++: #recon :++: #binds :++: #tl)) - b2) $ - EPair ext - (expandSubenvZeros (subList (select SMerge des) subAB) sAB_B $ - EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB))) (d2 t2)) IZ)) - (EInr ext (d2 t1) - (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB))) (d2 t2)) IZ))))) $ - ELet ext - (ELet ext (EJust ext (ESnd ext (EVar ext tCaseRet IZ))) $ - weakenExpr (WCopy (wSinks' @[_,_,_])) e2) $ - plus_AB_E - (EFst ext (EVar ext tCaseRet (IS IZ))) - (EVar ext (tTup (d2e (subList (select SMerge des) subE))) IZ)) - - EConst _ t val -> - Ret BTop - SETop - (EConst ext t val) - (subenvNone (select SMerge des)) - (ENil ext) - - EOp _ op e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e -> - case d2op op of - Linear d2opfun -> - Ret e0 - subtape - (d1op op e1) - sub - (ELet ext (d2opfun (EVar ext (d2 (opt2 op)) IZ)) - (weakenExpr (WCopy WSink) e2)) - Nonlinear d2opfun -> - Ret (e0 `BPush` (d1 (typeOf e), e1)) - (SEYes subtape) - (d1op op $ EVar ext (d1 (typeOf e)) IZ) - sub - (ELet ext (d2opfun (EVar ext (d1 (typeOf e)) (IS IZ)) - (EVar ext (d2 (opt2 op)) IZ)) - (weakenExpr (WCopy (wSinks' @[_,_])) e2)) - - ECustom _ _ _ storety _ pr du a b - -- allowed to ignore a2 because 'a' is the part of the input that is inactive - | Rets binds subtape (RetPair a1 _ _ `SCons` RetPair b1 bsub b2 `SCons` SNil) - <- retConcat des $ drev des accumMap a `SCons` drev des accumMap b `SCons` SNil -> - Ret (binds `BPush` (typeOf a1, a1) - `BPush` (typeOf b1, weakenExpr WSink b1) - `BPush` (typeOf pr, weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr)) - `BPush` (storety, ESnd ext (EVar ext (typeOf pr) IZ))) - (SEYes (SENo (SENo (SENo subtape)))) - (EFst ext (EVar ext (typeOf pr) (IS IZ))) - bsub - (ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $ - weakenExpr (WCopy (WSink .> WSink)) b2) - - EError _ t s -> - Ret BTop - SETop - (EError ext (d1 t) s) - (subenvNone (select SMerge des)) - (ENil ext) - - EConstArr _ n t val -> - Ret BTop - SETop - (EConstArr ext n t val) - (subenvNone (select SMerge des)) - (ENil ext) - - EBuild _ (ndim :: SNat ndim) she (orige :: Expr _ _ eltty) - | Ret (she0 :: Bindings _ _ she_binds) _ she1 _ _ <- drev des accumMap she -- allowed to ignore she2 here because she has a discrete result - , let eltty = typeOf orige - , shty :: STy shty <- tTup (sreplicate ndim tIx) - , Refl <- indexTupD1Id ndim -> - deleteUnused (descrList des) (occEnvPop (occCountAll orige)) $ \(usedSub :: Subenv env env') -> - let e = unsafeWeakenWithSubenv (SEYes usedSub) orige in - subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed -> - accumPromote eltty usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro -> - let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in - case drev (prodes `DPush` (shty, Nothing, SDiscr)) accumMapPro e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 sub e2 -> - case assertSubenvEmpty sub of { Refl -> - let tapety = tapeTy (subList (bindingsBinds e0) subtapeE) in - let collectexpr = bindingsCollect e0 subtapeE in - Ret (BTop `BPush` (shty, letBinds she0 she1) - `BPush` (STArr ndim (STPair (d1 eltty) tapety) - ,EBuild ext ndim - (EVar ext shty IZ) - (letBinds (fst (weakenBindings weakenExpr (autoWeak (#ix (shty `SCons` SNil) - &. #sh (shty `SCons` SNil) - &. #d1env (desD1E des) - &. #d1env' (desD1E usedDes)) - (#ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) - (#ix :++: #sh :++: #d1env)) - e0)) $ - let w = autoWeak (#ix (shty `SCons` SNil) - &. #sh (shty `SCons` SNil) - &. #e0 (bindingsBinds e0) - &. #d1env (desD1E des) - &. #d1env' (desD1E usedDes)) - (#e0 :++: #ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) - (#e0 :++: #ix :++: #sh :++: #d1env) - in EPair ext (weakenExpr w e1) (collectexpr w))) - `BPush` (STArr ndim tapety, emap (ESnd ext (EVar ext (STPair (d1 eltty) tapety) IZ)) - (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) IZ))) - (SEYes (SENo (SEYes SETop))) - (emap (EFst ext (EVar ext (STPair (d1 eltty) tapety) IZ)) - (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) (IS IZ))) - (subenvCompose subMergeUsed proSub) - (let sinkOverEnvPro = wSinks @(TArr ndim (D2 eltty) : D2 t : TArr ndim (Tape e_tape) : Tup (Replicate ndim TIx) : D2AcE (Select env sto "accum")) (d2ace envPro) in - EMaybe ext - (zeroTup envPro) - (ESnd ext $ - uninvertTup (d2e envPro) (STArr ndim STNil) $ - makeAccumulators @_ @_ @(TArr ndim TNil) envPro $ - EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS (IS IZ)))) $ - -- the cotangent for this element - ELet ext (EIdx ext (EVar ext (STArr ndim (d2 eltty)) (WSink .> sinkOverEnvPro @> IZ)) - (EVar ext shty IZ)) $ - -- the tape for this element - ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS (IS IZ))) - (EVar ext shty (IS IZ))) $ - let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds e0) subtapeE) IZ - in letBinds rebinds $ - weakenExpr (autoWeak (#d (auto1 @(D2 eltty)) - &. #pro (d2ace envPro) - &. #etape (subList (bindingsBinds e0) subtapeE) - &. #prerebinds prerebinds - &. #tape (auto1 @(Tape e_tape)) - &. #ix (auto1 @shty) - &. #darr (auto1 @(TArr ndim (D2 eltty))) - &. #mdarr (auto1 @(TMaybe (TArr ndim (D2 eltty)))) - &. #tapearr (auto1 @(TArr ndim (Tape e_tape))) - &. #sh (auto1 @shty) - &. #d2acUsed (d2ace (select SAccum usedDes)) - &. #d2acEnv (d2ace (select SAccum des))) - (#pro :++: #d :++: #etape :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) - ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #mdarr :++: #tapearr :++: #sh :++: #d2acEnv) - .> wPro (subList (bindingsBinds e0) subtapeE)) - e2) - (EVar ext (d2 (STArr ndim eltty)) IZ)) - }} - - EUnit _ e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e -> - Ret e0 - subtape - (EUnit ext e1) - sub - (EMaybe ext - (zeroTup (subList (select SMerge des) sub)) - (ELet ext (EIdx0 ext (EVar ext (STArr SZ (d2 (typeOf e))) IZ)) $ - weakenExpr (WCopy (WSink .> WSink)) e2) - (EVar ext (STMaybe (STArr SZ (d2 (typeOf e)))) IZ)) - - EReplicate1Inner _ en e - -- We're allowed to ignore en2 here because the output of 'ei' is discrete. - | Rets binds subtape (RetPair en1 _ _ `SCons` RetPair e1 sub e2 `SCons` SNil) - <- retConcat des $ drev des accumMap en `SCons` drev des accumMap e `SCons` SNil - , let STArr ndim eltty = typeOf e -> - Ret binds - subtape - (EReplicate1Inner ext en1 e1) - sub - (EMaybe ext - (zeroTup (subList (select SMerge des) sub)) - (ELet ext (EJust ext (EFold1Inner ext Commut - (EPlus ext eltty (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ)) - (EZero ext eltty) - (EVar ext (STArr (SS ndim) (d2 eltty)) IZ))) $ - weakenExpr (WCopy (WSink .> WSink)) e2) - (EVar ext (d2 (STArr (SS ndim) eltty)) IZ)) - - EIdx0 _ e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e - , STArr _ t <- typeOf e -> - Ret e0 - subtape - (EIdx0 ext e1) - sub - (ELet ext (EJust ext (EUnit ext (EVar ext (d2 t) IZ))) $ - weakenExpr (WCopy WSink) e2) - - EIdx1{} -> error "CHAD of EIdx1: Please use EIdx instead" - {- - EIdx1 _ e ei - -- We're allowed to ignore ei2 here because the output of 'ei' is discrete. - | Rets binds subtape (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil) - <- retConcat des $ drev des accumMap e `SCons` drev des accumMap ei `SCons` SNil - , STArr (SS n) eltty <- typeOf e -> - Ret (binds `BPush` (STArr (SS n) (d1 eltty), e1) - `BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) (d1 eltty)) IZ))) - (SEYes (SENo subtape)) - (EIdx1 ext (EVar ext (STArr (SS n) (d1 eltty)) (IS IZ)) - (weakenExpr (WSink .> WSink) ei1)) - sub - (ELet ext (ebuildUp1 n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ))) - (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ))) - (EVar ext (STArr n (d2 eltty)) (IS IZ))) $ - weakenExpr (WCopy (WSink .> WSink)) e2) - -} - - EIdx _ e ei - -- We're allowed to ignore ei2 here because the output of 'ei' is discrete. - | Rets binds subtape (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil) - <- retConcat des $ drev des accumMap e `SCons` drev des accumMap ei `SCons` SNil - , STArr n eltty <- typeOf e - , Refl <- indexTupD1Id n - , let tIxN = tTup (sreplicate n tIx) -> - Ret (binds `BPush` (STArr n (d1 eltty), e1) - `BPush` (tIxN, EShape ext (EVar ext (typeOf e1) IZ)) - `BPush` (tIxN, weakenExpr (WSink .> WSink) ei1)) - (SEYes (SEYes (SENo subtape))) - (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ))) - (EVar ext (tTup (sreplicate n tIx)) IZ)) - sub - (ELet ext (EOneHot ext (STArr n eltty) (SAPArrIdx SAPHere n) - (EPair ext (EPair ext (EVar ext tIxN (IS IZ)) (EVar ext tIxN (IS (IS IZ)))) - (ENil ext)) - (EVar ext (d2 eltty) IZ)) $ - weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) - - EShape _ e - -- Allowed to ignore e2 here because the output of EShape is discrete, - -- hence we'd be passing a zero cotangent to e2 anyway. - | Ret e0 subtape e1 _ _ <- drev des accumMap e - , STArr n _ <- typeOf e - , Refl <- indexTupD1Id n -> - Ret e0 - subtape - (EShape ext e1) - (subenvNone (select SMerge des)) - (ENil ext) - - ESum1Inner _ e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e - , STArr (SS n) t <- typeOf e -> - Ret (e0 `BPush` (STArr (SS n) t, e1) - `BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) t) IZ))) - (SEYes (SENo subtape)) - (ESum1Inner ext (EVar ext (STArr (SS n) t) (IS IZ))) - sub - (EMaybe ext - (zeroTup (subList (select SMerge des) sub)) - (ELet ext (EJust ext (EReplicate1Inner ext - (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS (IS IZ)))) - (EVar ext (STArr n (d2 t)) IZ))) $ - weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) - (EVar ext (d2 (STArr n t)) IZ)) - - EMaximum1Inner _ e -> deriv_extremum (EMaximum1Inner ext) e - EMinimum1Inner _ e -> deriv_extremum (EMinimum1Inner ext) e - - -- These should be the next to be implemented, I think - EFold1Inner{} -> err_unsupported "EFold1Inner" - - ENothing{} -> err_unsupported "ENothing" - EJust{} -> err_unsupported "EJust" - EMaybe{} -> err_unsupported "EMaybe" - - EWith{} -> err_accum - EAccum{} -> err_accum - EZero{} -> err_monoid - EPlus{} -> err_monoid - EOneHot{} -> err_monoid - - where - err_accum = error "Accumulator operations unsupported in the source program" - err_monoid = error "Monoid operations unsupported in the source program" - err_unsupported s = error $ "CHAD: unsupported " ++ s - - deriv_extremum :: ScalIsNumeric t' ~ True - => (forall env'. Ex env' (TArr (S n) (TScal t')) -> Ex env' (TArr n (TScal t'))) - -> Expr ValId env (TArr (S n) (TScal t')) -> Ret env sto (TArr n (TScal t')) - deriv_extremum extremum e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e - , at@(STArr (SS n) t@(STScal st)) <- typeOf e - , let at' = STArr n t - , let tIxN = tTup (sreplicate (SS n) tIx) = - Ret (e0 `BPush` (at, e1) - `BPush` (at', extremum (EVar ext at IZ))) - (SEYes (SEYes subtape)) - (EVar ext at' IZ) - sub - (EMaybe ext - (zeroTup (subList (select SMerge des) sub)) - (ELet ext (EJust ext - (EBuild ext (SS n) (EShape ext (EVar ext at (IS (IS (IS IZ))))) $ - eif (EOp ext (OEq st) (EPair ext - (EIdx ext (EVar ext at (IS (IS (IS (IS IZ))))) (EVar ext tIxN IZ)) - (EIdx ext (EVar ext at' (IS (IS (IS IZ)))) (EFst ext (EVar ext tIxN IZ))))) - (EIdx ext (EVar ext (STArr n (d2 t)) (IS IZ)) (EFst ext (EVar ext tIxN IZ))) - (EZero ext t))) $ - weakenExpr (WCopy (WSink .> WSink .> WSink .> WSink)) e2) - (EVar ext (d2 at') IZ)) - -data ChosenStorage = forall s. ((s == "discr") ~ False) => ChosenStorage (Storage s) - -data RetScoped env0 sto a s t = - forall shbinds tapebinds env0Merge. - RetScoped - (Bindings Ex (D1E (a : env0)) shbinds) -- shared binds - (Subenv shbinds tapebinds) - (Ex (Append shbinds (D1E (a : env0))) (D1 t)) - (Subenv (Select env0 sto "merge") env0Merge) - -- ^ merge contributions to the _enclosing_ merge environment - (Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum"))) - (If (s == "discr") (Tup (D2E env0Merge)) - (TPair (Tup (D2E env0Merge)) (D2 a)))) - -- ^ the merge contributions, plus the cotangent to the argument - -- (if there is any) -deriving instance Show (RetScoped env0 sto a s t) - -drevScoped :: forall a s env sto t. - (?config :: CHADConfig) - => Descr env sto -> VarMap Int (D2AcE (Select env sto "accum")) - -> STy a -> Storage s -> Maybe (ValId a) - -> Expr ValId (a : env) t - -> RetScoped env sto a s t -drevScoped des accumMap argty argsto argids expr = case argsto of - SMerge - | Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap expr -> - case sub of - SEYes sub' -> RetScoped e0 subtape e1 sub' e2 - SENo sub' -> RetScoped e0 subtape e1 sub' (EPair ext e2 (EZero ext argty)) - - SAccum - | Just (VIArr i _) <- argids - , Just (Some (VarMap.TypedIdx foundTy idx)) <- VarMap.lookup i accumMap - , Just Refl <- testEquality foundTy (STAccum argty) - , Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) (VarMap.sink1 accumMap) expr -> - RetScoped e0 subtape e1 sub $ - let wtapebinds = wSinks (subList (bindingsBinds e0) subtape) in - ELet ext (EVar ext (STAccum argty) (WSink .> wtapebinds @> idx)) $ - weakenExpr (autoWeak (#d (auto1 @(D2 t)) - &. #body (subList (bindingsBinds e0) subtape) - &. #ac (auto1 @(TAccum a)) - &. #tl (d2ace (select SAccum des))) - (#d :++: #body :++: #ac :++: #tl) - (#ac :++: #d :++: #body :++: #tl)) - -- Our contribution to the binding's cotangent _here_ is - -- zero, because we're contributing to an earlier binding - -- of the same value instead. - (EPair ext e2 (EZero ext argty)) - - | let accumMap' = case argids of - Just (VIArr i _) -> VarMap.insert i (STAccum argty) IZ (VarMap.sink1 accumMap) - _ -> VarMap.sink1 accumMap - , Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap' expr -> - RetScoped e0 subtape e1 sub $ - EWith ext argty (EZero ext argty) $ - weakenExpr (autoWeak (#d (auto1 @(D2 t)) - &. #body (subList (bindingsBinds e0) subtape) - &. #ac (auto1 @(TAccum a)) - &. #tl (d2ace (select SAccum des))) - (#d :++: #body :++: #ac :++: #tl) - (#ac :++: #d :++: #body :++: #tl)) - e2 - - SDiscr - | Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap expr -> - RetScoped e0 subtape e1 sub e2 |
