diff options
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r-- | src/CHAD.hs | 1068 |
1 files changed, 642 insertions, 426 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs index df792ce..143376a 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -11,6 +11,7 @@ {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE StandaloneKindSignatures #-} {-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeData #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} @@ -33,7 +34,6 @@ module CHAD ( import Data.Functor.Const import Data.Some -import Data.Type.Bool (If) import Data.Type.Equality (type (==), testEquality) import GHC.Stack (HasCallStack) @@ -42,6 +42,7 @@ import AST import AST.Bindings import AST.Count import AST.Env +import AST.Sparse import AST.Weaken.Auto import CHAD.Accum import CHAD.EnvDescr @@ -62,15 +63,21 @@ tapeTy :: SList STy binds -> STy (Tape binds) tapeTy SNil = STNil tapeTy (SCons t ts) = STPair t (tapeTy ts) -bindingsCollectTape :: Bindings f env binds -> Subenv binds tapebinds - -> Append binds env :> env2 -> Ex env2 (Tape tapebinds) -bindingsCollectTape BTop SETop _ = ENil ext -bindingsCollectTape (BPush binds (t, _)) (SEYes sub) w = +bindingsCollectTape :: SList STy binds -> Subenv binds tapebinds + -> binds :> env2 -> Ex env2 (Tape tapebinds) +bindingsCollectTape SNil SETop _ = ENil ext +bindingsCollectTape (t `SCons` binds) (SEYesR sub) w = EPair ext (EVar ext t (w @> IZ)) (bindingsCollectTape binds sub (w .> WSink)) -bindingsCollectTape (BPush binds _) (SENo sub) w = +bindingsCollectTape (_ `SCons` binds) (SENo sub) w = bindingsCollectTape binds sub (w .> WSink) +-- bindingsCollectTape' :: forall f env binds tapebinds env2. Bindings f env binds -> Subenv binds tapebinds +-- -> Append binds env :> env2 -> Ex env2 (Tape tapebinds) +-- bindingsCollectTape' binds sub w +-- | Refl <- lemAppendNil @binds +-- = bindingsCollectTape (bindingsBinds binds) sub (w .> wCopies @_ @_ @'[] (bindingsBinds binds) (WClosed @env)) + -- In order from large to small: i.e. in reverse order from what we want, -- because in a Bindings, the head of the list is the bottom-most entry. type family TapeUnfoldings binds where @@ -227,26 +234,37 @@ data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a)) d2op :: SOp a t -> D2Op a t d2op op = case op of - OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EJust ext (EPair ext d d) + OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EPair ext d d OMul t -> d2opBinArrangeInt t $ Nonlinear $ \e d -> - EJust ext (EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d)) - (EOp ext (OMul t) (EPair ext (EFst ext e) d))) + EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d)) + (EOp ext (OMul t) (EPair ext (EFst ext e) d)) ONeg t -> d2opUnArrangeInt t $ Linear $ \d -> EOp ext (ONeg t) d - OLt t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t))) - OLe t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t))) - OEq t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t))) + OLt t -> Linear $ \_ -> pairZero t + OLe t -> Linear $ \_ -> pairZero t + OEq t -> Linear $ \_ -> pairZero t ONot -> Linear $ \_ -> ENil ext - OAnd -> Linear $ \_ -> ENothing ext (STPair STNil STNil) - OOr -> Linear $ \_ -> ENothing ext (STPair STNil STNil) + OAnd -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) + OOr -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) OIf -> Linear $ \_ -> ENil ext - ORound64 -> Linear $ \_ -> EConst ext STF64 0.0 + ORound64 -> Linear $ \_ -> EZero ext (SMTScal STF64) (ENil ext) OToFl64 -> Linear $ \_ -> ENil ext ORecip t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ONeg t) (EOp ext (ORecip t) (EOp ext (OMul t) (EPair ext e e)))) d) OExp t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (OExp t) e) d) OLog t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ORecip t) e) d) - OIDiv t -> integralD2 t $ Linear $ \_ -> ENothing ext (STPair STNil STNil) - OMod t -> integralD2 t $ Linear $ \_ -> ENothing ext (STPair STNil STNil) + OIDiv t -> integralD2 t $ Linear $ \_ -> EPair ext (ENil ext) (ENil ext) + OMod t -> integralD2 t $ Linear $ \_ -> EPair ext (ENil ext) (ENil ext) where + pairZero :: SScalTy a -> Ex env (D2 (TPair (TScal a) (TScal a))) + pairZero t = ziNil t $ EPair ext (EZero ext (d2M (STScal t)) (ENil ext)) + (EZero ext (d2M (STScal t)) (ENil ext)) + where + ziNil :: SScalTy a -> (ZeroInfo (D2s a) ~ TNil => r) -> r + ziNil STI32 k = k + ziNil STI64 k = k + ziNil STF32 k = k + ziNil STF64 k = k + ziNil STBool k = k + d2opUnArrangeInt :: SScalTy a -> (D2s a ~ TScal a => D2Op (TScal a) t) -> D2Op (TScal a) t @@ -261,11 +279,11 @@ d2op op = case op of -> (D2s a ~ TScal a => D2Op (TPair (TScal a) (TScal a)) t) -> D2Op (TPair (TScal a) (TScal a)) t d2opBinArrangeInt ty float = case ty of - STI32 -> Linear $ \_ -> ENothing ext (STPair STNil STNil) - STI64 -> Linear $ \_ -> ENothing ext (STPair STNil STNil) + STI32 -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) + STI64 -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) STF32 -> float STF64 -> float - STBool -> Linear $ \_ -> ENothing ext (STPair STNil STNil) + STBool -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) floatingD2 :: ScalIsFloating a ~ True => SScalTy a -> ((D2s a ~ TScal a, ScalIsNumeric a ~ True) => r) -> r @@ -293,7 +311,7 @@ conv1Idx (IS i) = IS (conv1Idx i) data Idx2 env sto t = Idx2Ac (Idx (D2AcE (Select env sto "accum")) (TAccum (D2 t))) - | Idx2Me (Idx (Select env sto "merge") t) + | Idx2Me (Idx (D2E (Select env sto "merge")) (D2 t)) | Idx2Di (Idx (Select env sto "discr") t) conv2Idx :: Descr env sto -> Idx env t -> Idx2 env sto t @@ -314,64 +332,160 @@ conv2Idx (DPush des (_, _, SDiscr)) (IS i) = Idx2Di j -> Idx2Di (IS j) conv2Idx DTop i = case i of {} - ------------------------------------- MONOIDS ----------------------------------- - -zeroTup :: SList STy env0 -> Ex env (Tup (D2E env0)) -zeroTup SNil = ENil ext -zeroTup (t `SCons` env) = EPair ext (zeroTup env) (ezeroD2 t) - - ------------------------------------- SUBENVS ----------------------------------- - -subenvPlus :: SList STy env - -> Subenv env env1 -> Subenv env env2 - -> (forall env3. Subenv env env3 - -> Subenv env3 env1 - -> Subenv env3 env2 - -> (Ex exenv (Tup (D2E env1)) - -> Ex exenv (Tup (D2E env2)) - -> Ex exenv (Tup (D2E env3))) +opt2UnSparse :: SOp a b -> Sparse (D2 b) b' -> Ex env b' -> Ex env (D2 b) +opt2UnSparse = go . opt2 + where + go :: STy b -> Sparse (D2 b) b' -> Ex env b' -> Ex env (D2 b) + go (STScal STI32) SpAbsent = \_ -> ENil ext + go (STScal STI64) SpAbsent = \_ -> ENil ext + go (STScal STF32) SpAbsent = \_ -> EZero ext (SMTScal STF32) (ENil ext) + go (STScal STF64) SpAbsent = \_ -> EZero ext (SMTScal STF64) (ENil ext) + go (STScal STBool) SpAbsent = \_ -> ENil ext + go (STScal STF32) SpScal = id + go (STScal STF64) SpScal = id + go STNil _ = \_ -> ENil ext + go (STPair t1 t2) (SpPair s1 s2) = \e -> eunPair e $ \_ e1 e2 -> EPair ext (go t1 s1 e1) (go t2 s2 e2) + go t _ = error $ "Primitive operations that return " ++ show t ++ " are scary" + + +----------------------------------- SPARSITY ----------------------------------- + +expandSparse :: STy a -> Sparse (D2 a) b -> Ex env (D1 a) -> Ex env b -> Ex env (D2 a) +expandSparse t sp _ e | Just Refl <- isDense (d2M t) sp = e +expandSparse t (SpSparse sp) epr e = + EMaybe ext + (EZero ext (d2M t) (d2zeroInfo t epr)) + (expandSparse t sp (weakenExpr WSink epr) (EVar ext (applySparse sp (d2 t)) IZ)) + e +expandSparse t SpAbsent epr _ = EZero ext (d2M t) (d2zeroInfo t epr) +expandSparse (STPair t1 t2) (SpPair s1 s2) epr e = + eunPair epr $ \w1 epr1 epr2 -> + eunPair (weakenExpr w1 e) $ \w2 e1 e2 -> + EPair ext (expandSparse t1 s1 (weakenExpr w2 epr1) e1) + (expandSparse t2 s2 (weakenExpr w2 epr2) e2) +expandSparse (STEither t1 t2) (SpLEither s1 s2) epr e = + ELCase ext e + (EZero ext (d2M (STEither t1 t2)) (ENil ext)) + (ECase ext (weakenExpr WSink epr) + (ELInl ext (d2 t2) (expandSparse t1 s1 (EVar ext (d1 t1) IZ) (EVar ext (applySparse s1 (d2 t1)) (IS IZ)))) + (EError ext (d2 (STEither t1 t2)) "expspa r<-dl")) + (ECase ext (weakenExpr WSink epr) + (EError ext (d2 (STEither t1 t2)) "expspa l<-dr") + (ELInr ext (d2 t1) (expandSparse t2 s2 (EVar ext (d1 t2) IZ) (EVar ext (applySparse s2 (d2 t2)) (IS IZ))))) +expandSparse (STLEither t1 t2) (SpLEither s1 s2) epr e = + ELCase ext e + (EZero ext (d2M (STEither t1 t2)) (ENil ext)) + (ELCase ext (weakenExpr WSink epr) + (EError ext (d2 (STEither t1 t2)) "expspa ln<-dl") + (ELInl ext (d2 t2) (expandSparse t1 s1 (EVar ext (d1 t1) IZ) (EVar ext (applySparse s1 (d2 t1)) (IS IZ)))) + (EError ext (d2 (STEither t1 t2)) "expspa lr<-dl")) + (ELCase ext (weakenExpr WSink epr) + (EError ext (d2 (STEither t1 t2)) "expspa ln<-dr") + (EError ext (d2 (STEither t1 t2)) "expspa ll<-dr") + (ELInr ext (d2 t1) (expandSparse t2 s2 (EVar ext (d1 t2) IZ) (EVar ext (applySparse s2 (d2 t2)) (IS IZ))))) +expandSparse (STMaybe t) (SpMaybe s) epr e = + EMaybe ext + (ENothing ext (d2 t)) + (let epr' = EMaybe ext (EError ext (d1 t) "expspa n<-dj") (EVar ext (d1 t) IZ) epr + in EJust ext (expandSparse t s (weakenExpr WSink epr') (EVar ext (applySparse s (d2 t)) IZ))) + e +expandSparse (STArr _ t) (SpArr s) epr e = + ezipWith (expandSparse t s (EVar ext (d1 t) (IS IZ)) (EVar ext (applySparse s (d2 t)) IZ)) epr e +expandSparse (STScal STF32) SpScal _ e = e +expandSparse (STScal STF64) SpScal _ e = e +expandSparse (STAccum{}) _ _ _ = error "accumulators not allowed in source program" + +subenvPlus :: SBool req1 -> SBool req2 + -> SList SMTy env + -> SubenvS env env1 -> SubenvS env env2 + -> (forall env3. SubenvS env env3 + -> Injection req1 (Tup env1) (Tup env3) + -> Injection req2 (Tup env2) (Tup env3) + -> (forall e. Ex e (Tup env1) -> Ex e (Tup env2) -> Ex e (Tup env3)) -> r) -> r -subenvPlus SNil SETop SETop k = k SETop SETop SETop (\_ _ -> ENil ext) -subenvPlus (SCons _ env) (SENo sub1) (SENo sub2) k = - subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> +-- don't destroy effects! +subenvPlus _ _ SNil SETop SETop k = k SETop (Inj id) (Inj id) (\a b -> use a $ use b $ ENil ext) + +subenvPlus req1 req2 (SCons _ env) (SENo sub1) (SENo sub2) k = + subenvPlus req1 req2 env sub1 sub2 $ \sub3 s31 s32 pl -> k (SENo sub3) s31 s32 pl -subenvPlus (SCons _ env) (SEYes sub1) (SENo sub2) k = - subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> - k (SEYes sub3) (SEYes s31) (SENo s32) $ \e1 e2 -> - ELet ext e1 $ - EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ)) - (weakenExpr WSink e2)) - (ESnd ext (EVar ext (typeOf e1) IZ)) -subenvPlus (SCons _ env) (SENo sub1) (SEYes sub2) k = - subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> - k (SEYes sub3) (SENo s31) (SEYes s32) $ \e1 e2 -> - ELet ext e2 $ - EPair ext (pl (weakenExpr WSink e1) - (EFst ext (EVar ext (typeOf e2) IZ))) - (ESnd ext (EVar ext (typeOf e2) IZ)) -subenvPlus (SCons t env) (SEYes sub1) (SEYes sub2) k = - subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> - k (SEYes sub3) (SEYes s31) (SEYes s32) $ \e1 e2 -> - ELet ext e1 $ - ELet ext (weakenExpr WSink e2) $ - EPair ext (pl (EFst ext (EVar ext (typeOf e1) (IS IZ))) - (EFst ext (EVar ext (typeOf e2) IZ))) - (EPlus ext (d2M t) - (ESnd ext (EVar ext (typeOf e1) (IS IZ))) - (ESnd ext (EVar ext (typeOf e2) IZ))) - -expandSubenvZeros :: SList STy env0 -> Subenv env0 env0Merge -> Ex env (Tup (D2E env0Merge)) -> Ex env (Tup (D2E env0)) -expandSubenvZeros _ SETop _ = ENil ext -expandSubenvZeros (SCons t ts) (SEYes sub) e = - ELet ext e $ - let var = EVar ext (STPair (tTup (d2e (subList ts sub))) (d2 t)) IZ - in EPair ext (expandSubenvZeros ts sub (EFst ext var)) (ESnd ext var) -expandSubenvZeros (SCons t ts) (SENo sub) e = EPair ext (expandSubenvZeros ts sub e) (ezeroD2 t) - -assertSubenvEmpty :: HasCallStack => Subenv env env' -> env' :~: '[] + +subenvPlus req1 SF (SCons _ env) (SEYes sp1 sub1) (SENo sub2) k = + subenvPlus req1 SF env sub1 sub2 $ \sub3 minj13 _ pl -> + k (SEYes sp1 sub3) + (withInj minj13 $ \inj13 -> + \e1 -> eunPair e1 $ \_ e1a e1b -> + EPair ext (inj13 e1a) e1b) + Noinj + (\e1 e2 -> + ELet ext e1 $ + EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ)) + (weakenExpr WSink e2)) + (ESnd ext (EVar ext (typeOf e1) IZ))) +subenvPlus req1 ST (SCons t env) (SEYes sp1 sub1) (SENo sub2) k + | Just zero1 <- cheapZero (applySparse sp1 t) = + subenvPlus req1 ST env sub1 sub2 $ \sub3 minj13 (Inj inj23) pl -> + k (SEYes sp1 sub3) + (withInj minj13 $ \inj13 -> + \e1 -> eunPair e1 $ \_ e1a e1b -> + EPair ext (inj13 e1a) e1b) + (Inj $ \e2 -> EPair ext (inj23 e2) zero1) + (\e1 e2 -> + ELet ext e1 $ + EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ)) + (weakenExpr WSink e2)) + (ESnd ext (EVar ext (typeOf e1) IZ))) + | otherwise = + subenvPlus req1 ST env sub1 sub2 $ \sub3 minj13 (Inj inj23) pl -> + k (SEYes (SpSparse sp1) sub3) + (withInj minj13 $ \inj13 -> + \e1 -> eunPair e1 $ \_ e1a e1b -> + EPair ext (inj13 e1a) (EJust ext e1b)) + (Inj $ \e2 -> EPair ext (inj23 e2) (ENothing ext (applySparse sp1 (fromSMTy t)))) + (\e1 e2 -> + ELet ext e1 $ + EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ)) + (weakenExpr WSink e2)) + (EJust ext (ESnd ext (EVar ext (typeOf e1) IZ)))) + +subenvPlus req1 req2 (SCons t env) sub1@SENo{} sub2@SEYes{} k = + subenvPlus req2 req1 (SCons t env) sub2 sub1 $ \sub3 minj23 minj13 pl -> + k sub3 minj13 minj23 (flip pl) + +subenvPlus req1 req2 (SCons t env) (SEYes sp1 sub1) (SEYes sp2 sub2) k = + subenvPlus req1 req2 env sub1 sub2 $ \sub3 minj13 minj23 pl -> + sparsePlusS req1 req2 t sp1 sp2 $ \sp3 mTinj13 mTinj23 plus -> + k (SEYes sp3 sub3) + (withInj2 minj13 mTinj13 $ \inj13 tinj13 -> + \e1 -> eunPair e1 $ \_ e1a e1b -> + EPair ext (inj13 e1a) (tinj13 e1b)) + (withInj2 minj23 mTinj23 $ \inj23 tinj23 -> + \e2 -> eunPair e2 $ \_ e2a e2b -> + EPair ext (inj23 e2a) (tinj23 e2b)) + (\e1 e2 -> + ELet ext e1 $ + ELet ext (weakenExpr WSink e2) $ + EPair ext (pl (EFst ext (EVar ext (typeOf e1) (IS IZ))) + (EFst ext (EVar ext (typeOf e2) IZ))) + (plus + (ESnd ext (EVar ext (typeOf e1) (IS IZ))) + (ESnd ext (EVar ext (typeOf e2) IZ)))) + +expandSubenvZeros :: D1E env0 :> env -> SList STy env0 -> SubenvS (D2E env0) contribs + -> Ex env (Tup contribs) -> Ex env (Tup (D2E env0)) +expandSubenvZeros _ SNil SETop _ = ENil ext +expandSubenvZeros w (SCons t ts) (SEYes sp sub) e = + eunPair e $ \w1 e1 e2 -> + EPair ext + (expandSubenvZeros (w1 .> WPop w) ts sub e1) + (expandSparse t sp (EVar ext (d1 t) (w1 .> w @> IZ)) e2) +expandSubenvZeros w (SCons t ts) (SENo sub) e = + EPair ext + (expandSubenvZeros (WPop w) ts sub e) + (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ)))) + +assertSubenvEmpty :: HasCallStack => Subenv' s env env' -> env' :~: '[] assertSubenvEmpty (SENo sub) | Refl <- assertSubenvEmpty sub = Refl assertSubenvEmpty SETop = Refl assertSubenvEmpty SEYes{} = error "assertSubenvEmpty: not empty" @@ -407,8 +521,8 @@ accumPromote :: forall dt env sto proxy r. -- accumulators. -> (forall shbinds. SList STy shbinds - -> (D2 dt : Append shbinds (D2AcE (Select env stoRepl "accum"))) - :> Append (D2AcE envPro) (D2 dt : Append shbinds (D2AcE (Select env sto "accum")))) + -> (dt : Append shbinds (D2AcE (Select env stoRepl "accum"))) + :> Append (D2AcE envPro) (dt : Append shbinds (D2AcE (Select env sto "accum")))) -- ^ A weakening that converts a computation in the -- revised environment to one in the original environment -- extended with some accumulators. @@ -422,14 +536,14 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of k (storepl `DPush` (t, vid, SAccum)) envpro prosub - (SEYes accrevsub) + (SEYesR accrevsub) (VarMap.sink1 accumMap) (\shbinds -> - autoWeak (#pro (d2ace envpro) &. #d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum descr))) + autoWeak (#pro (d2ace envpro) &. #d (auto1 @dt) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum descr))) (#acc :++: (#pro :++: #d :++: #shb :++: #tl)) (#pro :++: #d :++: #shb :++: #acc :++: #tl) .> WCopy (wf shbinds) - .> autoWeak (#d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum storepl))) + .> autoWeak (#d (auto1 @dt) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum storepl))) (#d :++: #shb :++: #acc :++: #tl) (#acc :++: (#d :++: #shb :++: #tl))) @@ -449,7 +563,7 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf -> k (storepl `DPush` (t, vid, SAccum)) (t `SCons` envpro) - (SEYes prosub) + (SEYesR prosub) (SENo accrevsub) (let accumMap' = VarMap.sink1 accumMap in case fromArrayValId vid of @@ -466,7 +580,7 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of -- goal: | ARE EQUAL || -- D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) WCopy (wf shbinds) - .> WPick @(TAccum (D2 t)) @(D2 dt : shbinds) (Const () `SCons` shbindsC) + .> WPick @(TAccum (D2 t)) @(dt : shbinds) (Const () `SCons` shbindsC) (WId @(D2AcE (Select env1 stoRepl "accum")))) -- Discrete values are left as-is, nothing to do @@ -498,21 +612,41 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of ---------------------------- RETURN TRIPLE FROM CHAD --------------------------- -data Ret env0 sto t = - forall shbinds tapebinds env0Merge. +data Ret env0 sto sd t = + forall shbinds tapebinds contribs. Ret (Bindings Ex (D1E env0) shbinds) -- shared binds (Subenv shbinds tapebinds) (Ex (Append shbinds (D1E env0)) (D1 t)) - (Subenv (Select env0 sto "merge") env0Merge) - (Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup (D2E env0Merge))) -deriving instance Show (Ret env0 sto t) + (SubenvS (D2E (Select env0 sto "merge")) contribs) + (Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs)) +deriving instance Show (Ret env0 sto sd t) -data RetPair env0 sto env shbinds tapebinds t = - forall env0Merge. - RetPair (Ex (Append shbinds env) (D1 t)) - (Subenv (Select env0 sto "merge") env0Merge) - (Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup (D2E env0Merge))) -deriving instance Show (RetPair env0 sto env shbinds tapebinds t) +type data TyTyPair = MkTyTyPair Ty Ty + +data SingleRet env0 sto (pair :: TyTyPair) = + forall shbinds tapebinds. + SingleRet + (Bindings Ex (D1E env0) shbinds) -- shared binds + (Subenv shbinds tapebinds) + (RetPair env0 sto (D1E env0) shbinds tapebinds pair) + +-- pattern Ret1 :: forall env0 sto Bindings Ex (D1E env0) shbinds +-- -> Subenv shbinds tapebinds +-- -> Ex (Append shbinds (D1E env0)) (D1 t) +-- -> SubenvS (D2E (Select env0 sto "merge")) contribs +-- -> Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs) +-- -> SingleRet env0 sto (MkTyTyPair sd t) +-- pattern Ret1 e0 subtape e1 sub e2 = SingleRet e0 subtape (RetPair e1 sub e2) +-- {-# COMPLETE Ret1 #-} + +data RetPair env0 sto env shbinds tapebinds (pair :: TyTyPair) where + RetPair :: forall sd t contribs -- existentials + env0 sto env shbinds tapebinds. -- universals + Ex (Append shbinds env) (D1 t) + -> SubenvS (D2E (Select env0 sto "merge")) contribs + -> Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs) + -> RetPair env0 sto env shbinds tapebinds (MkTyTyPair sd t) +deriving instance Show (RetPair env0 sto env shbinds tapebinds pair) data Rets env0 sto env list = forall shbinds tapebinds. @@ -521,8 +655,11 @@ data Rets env0 sto env list = (SList (RetPair env0 sto env shbinds tapebinds) list) deriving instance Show (Rets env0 sto env list) +toSingleRet :: Ret env0 sto sd t -> SingleRet env0 sto (MkTyTyPair sd t) +toSingleRet (Ret e0 subtape e1 sub e2) = SingleRet e0 subtape (RetPair e1 sub e2) + weakenRetPair :: SList STy shbinds -> env :> env' - -> RetPair env0 sto env shbinds tapebinds t -> RetPair env0 sto env' shbinds tapebinds t + -> RetPair env0 sto env shbinds tapebinds pair -> RetPair env0 sto env' shbinds tapebinds pair weakenRetPair bindslist w (RetPair e1 sub e2) = RetPair (weakenExpr (weakenOver bindslist w) e1) sub e2 weakenRets :: env :> env' -> Rets env0 sto env list -> Rets env0 sto env' list @@ -530,104 +667,137 @@ weakenRets w (Rets binds tapesub list) = let (binds', _) = weakenBindings weakenExpr w binds in Rets binds' tapesub (slistMap (weakenRetPair (bindingsBinds binds) w) list) -rebaseRetPair :: forall env b1 b2 tapebinds1 tapebinds2 env0 sto t f. +rebaseRetPair :: forall env b1 b2 tapebinds1 tapebinds2 env0 sto pair f. Descr env0 sto -> SList f b1 -> SList f b2 -> Subenv b1 tapebinds1 -> Subenv b2 tapebinds2 - -> RetPair env0 sto (Append b1 env) b2 tapebinds2 t - -> RetPair env0 sto env (Append b2 b1) (Append tapebinds2 tapebinds1) t -rebaseRetPair descr b1 b2 subtape1 subtape2 (RetPair p sub d) + -> RetPair env0 sto (Append b1 env) b2 tapebinds2 pair + -> RetPair env0 sto env (Append b2 b1) (Append tapebinds2 tapebinds1) pair +rebaseRetPair descr b1 b2 subtape1 subtape2 (RetPair @sd e1 sub e2) | Refl <- lemAppendAssoc @b2 @b1 @env = - RetPair p sub (weakenExpr (autoWeak - (#d (auto1 @(D2 t)) - &. #t2 (subList b2 subtape2) - &. #t1 (subList b1 subtape1) - &. #tl (d2ace (select SAccum descr))) - (#d :++: (#t2 :++: #tl)) - (#d :++: ((#t2 :++: #t1) :++: #tl))) - d) - -retConcat :: forall env0 sto list. Descr env0 sto -> SList (Ret env0 sto) list -> Rets env0 sto (D1E env0) list + RetPair e1 sub + (weakenExpr (autoWeak + (#d (auto1 @sd) + &. #t2 (subList b2 subtape2) + &. #t1 (subList b1 subtape1) + &. #tl (d2ace (select SAccum descr))) + (#d :++: (#t2 :++: #tl)) + (#d :++: ((#t2 :++: #t1) :++: #tl))) + e2) + +retConcat :: forall env0 sto list. Descr env0 sto -> SList (SingleRet env0 sto) list -> Rets env0 sto (D1E env0) list retConcat _ SNil = Rets BTop SETop SNil -retConcat descr (SCons (Ret (b :: Bindings _ _ shbinds1) (subtape :: Subenv _ tapebinds1) p sub d) list) +retConcat descr (SCons (SingleRet (e0 :: Bindings _ _ shbinds1) (subtape :: Subenv _ tapebinds1) (RetPair e1 sub e2)) list) | Rets (binds :: Bindings _ _ shbinds2) (subtape2 :: Subenv _ tapebinds2) pairs - <- weakenRets (sinkWithBindings b) (retConcat descr list) + <- weakenRets (sinkWithBindings e0) (retConcat descr list) , Refl <- lemAppendAssoc @shbinds2 @shbinds1 @(D1E env0) , Refl <- lemAppendAssoc @tapebinds2 @tapebinds1 @(D2AcE (Select env0 sto "accum")) - = Rets (bconcat b binds) + = Rets (bconcat e0 binds) (subenvConcat subtape subtape2) - (SCons (RetPair (weakenExpr (sinkWithBindings binds) p) + (SCons (RetPair (weakenExpr (sinkWithBindings binds) e1) sub - (weakenExpr (WCopy (sinkWithSubenv subtape2)) d)) - (slistMap (rebaseRetPair descr (bindingsBinds b) (bindingsBinds binds) + (weakenExpr (WCopy (sinkWithSubenv subtape2)) e2)) + (slistMap (rebaseRetPair descr (bindingsBinds e0) (bindingsBinds binds) subtape subtape2) pairs)) freezeRet :: Descr env sto - -> Ret env sto t + -> Ret env sto (D2 t) t -> Ex (D2 t : Append (D2AcE (Select env sto "accum")) (D1E env)) (TPair (D1 t) (Tup (D2E (Select env sto "merge")))) -freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ t) = +freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ _ t) = let (e0', wInsertD2Ac) = weakenBindings weakenExpr (WSink .> wSinks (d2ace (select SAccum descr))) e0 e2' = weakenExpr (WCopy (wCopies (subList (bindingsBinds e0) subtape) (wRaiseAbove (d2ace (select SAccum descr)) (desD1E descr)))) e2 + tContribs = tTup (slistMap fromSMTy (subList (d2eM (select SMerge descr)) sub)) + library = #d (auto1 @(D2 t)) + &. #tape (subList (bindingsBinds e0) subtape) + &. #shbinds (bindingsBinds e0) + &. #d2ace (d2ace (select SAccum descr)) + &. #tl (desD1E descr) + &. #contribs (SCons tContribs SNil) in letBinds e0' $ EPair ext (weakenExpr wInsertD2Ac e1) - (ELet ext (weakenExpr (autoWeak (#d (auto1 @(D2 t)) - &. #tape (subList (bindingsBinds e0) subtape) - &. #shbinds (bindingsBinds e0) - &. #d2ace (d2ace (select SAccum descr)) - &. #tl (desD1E descr)) + (ELet ext (weakenExpr (autoWeak library (#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: #d2ace :++: #tl) (#shbinds :++: #d :++: #d2ace :++: #tl)) e2') $ - expandSubenvZeros (select SMerge descr) sub (EVar ext (tTup (d2e (subList (select SMerge descr) sub))) IZ)) + expandSubenvZeros + (autoWeak library #tl (#contribs :++: #shbinds :++: #d :++: #d2ace :++: #tl) + .> wUndoSubenv (subenvD1E (selectSub SMerge descr))) + (select SMerge descr) sub (EVar ext tContribs IZ)) ---------------------------- THE CHAD TRANSFORMATION --------------------------- -drev :: forall env sto t. +drev :: forall env sto sd t. (?config :: CHADConfig) => Descr env sto -> VarMap Int (D2AcE (Select env sto "accum")) - -> Expr ValId env t -> Ret env sto t -drev des accumMap = \case + -> Sparse (D2 t) sd + -> Expr ValId env t -> Ret env sto sd t +drev des _ sd | isAbsent sd = + \e -> + Ret BTop + SETop + (drevPrimal des e) + (subenvNone (d2e (select SMerge des))) + (ENil ext) +drev _ _ SpAbsent = error "Absent should be isAbsent" + +drev des accumMap (SpSparse sd) = + \e -> + case drev des accumMap sd e of { Ret e0 subtape e1 sub e2 -> + subenvPlus ST ST (d2eM (select SMerge des)) sub (subenvNone (d2e (select SMerge des))) $ \sub' (Inj inj1) (Inj inj2) _ -> + Ret e0 + subtape + e1 + sub' + (emaybe (EVar ext (STMaybe (applySparse sd (d2 (typeOf e)))) IZ) + (inj2 (ENil ext)) + (inj1 (weakenExpr (WCopy WSink) e2))) + } + +drev des accumMap sd = \case EVar _ t i -> case conv2Idx des i of Idx2Ac accI -> Ret BTop SETop (EVar ext (d1 t) (conv1Idx i)) - (subenvNone (select SMerge des)) - (EAccum ext (d2M t) SAPHere (ENil ext) (EVar ext (d2 t) IZ) (EVar ext (STAccum (d2M t)) (IS accI))) + (subenvNone (d2e (select SMerge des))) + (let ty = applySparse sd (d2M t) + in EAccum ext (d2M t) SAPHere (ENil ext) sd (EVar ext (fromSMTy ty) IZ) (EVar ext (STAccum (d2M t)) (IS accI))) Idx2Me tupI -> Ret BTop SETop (EVar ext (d1 t) (conv1Idx i)) - (subenvOnehot (select SMerge des) tupI) - (EPair ext (ENil ext) (EVar ext (d2 t) IZ)) + (subenvOnehot (d2e (select SMerge des)) tupI sd) + (EPair ext (ENil ext) (EVar ext (applySparse sd (d2 t)) IZ)) Idx2Di _ -> Ret BTop SETop (EVar ext (d1 t) (conv1Idx i)) - (subenvNone (select SMerge des)) + (subenvNone (d2e (select SMerge des))) (ENil ext) ELet _ (rhs :: Expr _ _ a) body - | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 <- drev des accumMap rhs - , ChosenStorage storage <- if chcLetArrayAccum ?config && hasArrays (typeOf rhs) then ChosenStorage SAccum else ChosenStorage SMerge - , RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody body2 <- drevScoped des accumMap (typeOf rhs) storage (Just (extOf rhs)) body + | ChosenStorage (storage :: Storage s) <- if chcLetArrayAccum ?config && hasArrays (typeOf rhs) then ChosenStorage SAccum else ChosenStorage SMerge + , RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody sdBody body2 <- drevScoped des accumMap (typeOf rhs) storage (Just (extOf rhs)) sd body + , Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) rhs1 subRHS rhs2 <- drev des accumMap sdBody rhs , let (body0', wbody0') = weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0 - , Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D1E env) - , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) -> - subenvPlus (select SMerge des) subRHS subBody $ \subBoth _ _ plus_RHS_Body -> - let bodyResType = STPair (tTup (d2e (subList (select SMerge des) subBody))) (d2 (typeOf rhs)) in + , Refl <- lemAppendAssoc @body_shbinds @'[D1 a] @rhs_shbinds + , Refl <- lemAppendAssoc @body_shbinds @(D1 a : rhs_shbinds) @(D1E env) + , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) + -> + subenvPlus SF SF (d2eM (select SMerge des)) subRHS subBody $ \subBoth _ _ plus_RHS_Body -> + let bodyResType = STPair (contribTupTy des subBody) (applySparse sdBody (d2 (typeOf rhs))) in Ret (bconcat (rhs0 `BPush` (d1 (typeOf rhs), rhs1)) body0') - (subenvConcat (SENo @d1_a subtapeRHS) subtapeBody) + (subenvConcat subtapeRHS subtapeBody) (weakenExpr wbody0' body1) subBoth - (ELet ext (weakenExpr (autoWeak (#d (auto1 @(D2 t)) - &. #body (subList (bindingsBinds body0) subtapeBody) + (ELet ext (weakenExpr (autoWeak (#d (auto1 @sd) + &. #body (subList (bindingsBinds body0 `sappend` SCons (d1 (typeOf rhs)) SNil) subtapeBody) &. #rhs (subList (bindingsBinds rhs0) subtapeRHS) &. #tl (d2ace (select SAccum des))) (#d :++: #body :++: #tl) @@ -637,204 +807,225 @@ drev des accumMap = \case (ELet ext (ESnd ext (EVar ext bodyResType IZ)) $ weakenExpr (WCopy (wSinks' @[_,_] .> sinkWithSubenv subtapeBody)) rhs2) $ plus_RHS_Body - (EVar ext (tTup (d2e (subList (select SMerge des) subRHS))) IZ) + (EVar ext (contribTupTy des subRHS) IZ) (EFst ext (EVar ext bodyResType (IS IZ)))) EPair _ a b - | Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil) - <- retConcat des $ drev des accumMap a `SCons` drev des accumMap b `SCons` SNil - , let dt = STPair (d2 (typeOf a)) (d2 (typeOf b)) -> - subenvPlus (select SMerge des) subA subB $ \subBoth _ _ plus_A_B -> + | SpPair sd1 sd2 <- sd + , Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil) + <- retConcat des $ toSingleRet (drev des accumMap sd1 a) `SCons` toSingleRet (drev des accumMap sd2 b) `SCons` SNil + , let dt = STPair (applySparse sd1 (d2 (typeOf a))) (applySparse sd2 (d2 (typeOf b))) -> + subenvPlus SF SF (d2eM (select SMerge des)) subA subB $ \subBoth _ _ plus_A_B -> Ret binds subtape (EPair ext a1 b1) subBoth - (EMaybe ext - (zeroTup (subList (select SMerge des) subBoth)) - (ELet ext (ELet ext (EFst ext (EVar ext dt IZ)) - (weakenExpr (WCopy (wSinks' @[_,_])) a2)) $ - ELet ext (ELet ext (ESnd ext (EVar ext dt (IS IZ))) - (weakenExpr (WCopy (wSinks' @[_,_,_])) b2)) $ - plus_A_B - (EVar ext (tTup (d2e (subList (select SMerge des) subA))) (IS IZ)) - (EVar ext (tTup (d2e (subList (select SMerge des) subB))) IZ)) - (EVar ext (STMaybe (STPair (d2 (typeOf a)) (d2 (typeOf b)))) IZ)) + (ELet ext (ELet ext (EFst ext (EVar ext dt IZ)) + (weakenExpr (WCopy WSink) a2)) $ + ELet ext (ELet ext (ESnd ext (EVar ext dt (IS IZ))) + (weakenExpr (WCopy (WSink .> WSink)) b2)) $ + plus_A_B + (EVar ext (contribTupTy des subA) (IS IZ)) + (EVar ext (contribTupTy des subB) IZ)) EFst _ e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e - , STPair t1 t2 <- typeOf e -> + | Ret e0 subtape e1 sub e2 <- drev des accumMap (SpPair sd SpAbsent) e + , STPair t1 _ <- typeOf e -> Ret e0 subtape (EFst ext e1) sub - (ELet ext (EJust ext (EPair ext (EVar ext (d2 t1) IZ) (ezeroD2 t2))) $ + (ELet ext (EPair ext (EVar ext (applySparse sd (d2 t1)) IZ) (ENil ext)) $ weakenExpr (WCopy WSink) e2) ESnd _ e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e - , STPair t1 t2 <- typeOf e -> + | Ret e0 subtape e1 sub e2 <- drev des accumMap (SpPair SpAbsent sd) e + , STPair _ t2 <- typeOf e -> Ret e0 subtape (ESnd ext e1) sub - (ELet ext (EJust ext (EPair ext (ezeroD2 t1) (EVar ext (d2 t2) IZ))) $ + (ELet ext (EPair ext (ENil ext) (EVar ext (applySparse sd (d2 t2)) IZ)) $ weakenExpr (WCopy WSink) e2) - ENil _ -> Ret BTop SETop (ENil ext) (subenvNone (select SMerge des)) (ENil ext) + -- Don't need to handle ENil, because its cotangent is always absent! + -- ENil _ -> Ret BTop SETop (ENil ext) (subenvNone (d2e (select SMerge des))) (ENil ext) EInl _ t2 e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e -> + | SpLEither sd1 sd2 <- sd + , Ret e0 subtape e1 sub e2 <- drev des accumMap sd1 e -> + subenvPlus ST ST (d2eM (select SMerge des)) sub (subenvNone (d2e (select SMerge des))) $ \sub' (Inj inj1) (Inj inj2) _ -> Ret e0 subtape (EInl ext (d1 t2) e1) - sub + sub' (ELCase ext - (EVar ext (STLEither (d2 (typeOf e)) (d2 t2)) IZ) - (zeroTup (subList (select SMerge des) sub)) - (weakenExpr (WCopy WSink) e2) - (EError ext (tTup (d2e (subList (select SMerge des) sub))) "inl<-dinr")) + (EVar ext (STLEither (applySparse sd1 (d2 (typeOf e))) (applySparse sd2 (d2 t2))) IZ) + (inj2 $ ENil ext) + (inj1 $ weakenExpr (WCopy WSink) e2) + (EError ext (contribTupTy des sub') "inl<-dinr")) EInr _ t1 e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e -> + | SpLEither sd1 sd2 <- sd + , Ret e0 subtape e1 sub e2 <- drev des accumMap sd2 e -> + subenvPlus ST ST (d2eM (select SMerge des)) sub (subenvNone (d2e (select SMerge des))) $ \sub' (Inj inj1) (Inj inj2) _ -> Ret e0 subtape (EInr ext (d1 t1) e1) - sub + sub' (ELCase ext - (EVar ext (STLEither (d2 t1) (d2 (typeOf e))) IZ) - (zeroTup (subList (select SMerge des) sub)) - (EError ext (tTup (d2e (subList (select SMerge des) sub))) "inr<-dinl") - (weakenExpr (WCopy WSink) e2)) + (EVar ext (STLEither (applySparse sd1 (d2 t1)) (applySparse sd2 (d2 (typeOf e)))) IZ) + (inj2 $ ENil ext) + (EError ext (contribTupTy des sub') "inr<-dinl") + (inj1 $ weakenExpr (WCopy WSink) e2)) ECase _ e (a :: Expr _ _ t) b - | STEither t1 t2 <- typeOf e - , Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 subE e2 <- drev des accumMap e + | STEither (t1 :: STy a) (t2 :: STy b) <- typeOf e , ChosenStorage storage1 <- if chcCaseArrayAccum ?config && hasArrays t1 then ChosenStorage SAccum else ChosenStorage SMerge , ChosenStorage storage2 <- if chcCaseArrayAccum ?config && hasArrays t2 then ChosenStorage SAccum else ChosenStorage SMerge , let (bindids1, bindids2) = validSplitEither (extOf e) - , RetScoped (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA a2 <- drevScoped des accumMap t1 storage1 bindids1 a - , RetScoped (b0 :: Bindings _ _ rhs_b_binds) (subtapeB :: Subenv _ rhs_b_tape) b1 subB b2 <- drevScoped des accumMap t2 storage2 bindids2 b + , RetScoped (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA sd1 a2 + <- drevScoped des accumMap t1 storage1 bindids1 sd a + , RetScoped (b0 :: Bindings _ _ rhs_b_binds) (subtapeB :: Subenv _ rhs_b_tape) b1 subB sd2 b2 + <- drevScoped des accumMap t2 storage2 bindids2 sd b + , Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 subE e2 <- drev des accumMap (SpLEither sd1 sd2) e , Refl <- lemAppendAssoc @(Append rhs_a_binds (Reverse (TapeUnfoldings rhs_a_binds))) @(Tape rhs_a_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum")) , Refl <- lemAppendAssoc @(Append rhs_b_binds (Reverse (TapeUnfoldings rhs_b_binds))) @(Tape rhs_b_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum")) - , let tapeA = tapeTy (subList (bindingsBinds a0) subtapeA) - , let tapeB = tapeTy (subList (bindingsBinds b0) subtapeB) - , let collectA = bindingsCollectTape a0 subtapeA - , let collectB = bindingsCollectTape b0 subtapeB + , let subtapeListA = subList (sappend (bindingsBinds a0) (d1 t1 `SCons` SNil)) subtapeA + , let subtapeListB = subList (sappend (bindingsBinds b0) (d1 t2 `SCons` SNil)) subtapeB + , let tapeA = tapeTy subtapeListA + , let tapeB = tapeTy subtapeListB + , let collectA = bindingsCollectTape @_ @_ @(Append rhs_a_binds (D1 a : Append e_binds (D1E env))) + (sappend (bindingsBinds a0) (d1 t1 `SCons` SNil)) subtapeA + , let collectB = bindingsCollectTape @_ @_ @(Append rhs_b_binds (D1 b : Append e_binds (D1E env))) + (sappend (bindingsBinds b0) (d1 t2 `SCons` SNil)) subtapeB , (tPrimal :: STy t_primal_ty) <- STPair (d1 (typeOf a)) (STEither tapeA tapeB) , let (a0', wa0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) a0 , let (b0', wb0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) b0 + , Refl <- lemAppendNil @(Append rhs_a_binds '[D1 a]) + , Refl <- lemAppendNil @(Append rhs_b_binds '[D1 b]) + , Refl <- lemAppendAssoc @rhs_a_binds @'[D1 a] @(D1E env) + , Refl <- lemAppendAssoc @rhs_b_binds @'[D1 b] @(D1E env) + , let wa0'' = wa0' .> wCopies (sappend (bindingsBinds a0) (d1 t1 `SCons` SNil)) (WClosed @(D1E env)) + , let wb0'' = wb0' .> wCopies (sappend (bindingsBinds b0) (d1 t2 `SCons` SNil)) (WClosed @(D1E env)) -> - subenvPlus (select SMerge des) subA subB $ \subAB sAB_A sAB_B _ -> - subenvPlus (select SMerge des) subAB subE $ \subOut _ _ plus_AB_E -> - let tCaseRet = STPair (tTup (d2e (subList (select SMerge des) subAB))) (STLEither (d2 t1) (d2 t2)) in + subenvPlus ST ST (d2eM (select SMerge des)) subA subB $ \subAB (Inj sAB_A) (Inj sAB_B) _ -> + subenvPlus SF SF (d2eM (select SMerge des)) subAB subE $ \subOut _ _ plus_AB_E -> Ret (e0 `BPush` (tPrimal, ECase ext e1 - (letBinds a0' (EPair ext (weakenExpr wa0' a1) (EInl ext tapeB (collectA wa0')))) - (letBinds b0' (EPair ext (weakenExpr wb0' b1) (EInr ext tapeA (collectB wb0')))))) - (SEYes subtapeE) + (letBinds a0' (EPair ext (weakenExpr wa0' a1) (EInl ext tapeB (collectA wa0'')))) + (letBinds b0' (EPair ext (weakenExpr wb0' b1) (EInr ext tapeA (collectB wb0'')))))) + (SEYesR subtapeE) (EFst ext (EVar ext tPrimal IZ)) subOut - (ELet ext + (elet (ECase ext (ESnd ext (EVar ext tPrimal (IS IZ))) - (let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds a0) subtapeA) IZ + (let (rebinds, prerebinds) = reconstructBindings subtapeListA IZ in letBinds rebinds $ ELet ext - (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_a_tape : D2 t : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend (subList (bindingsBinds a0) subtapeA) prerebinds) @> IS IZ)) $ - ELet ext - (weakenExpr (autoWeak (#d (auto1 @(D2 t)) - &. #ta0 (subList (bindingsBinds a0) subtapeA) + (EVar ext (applySparse sd (d2 (typeOf a))) (wSinks @(Tape rhs_a_tape : sd : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend subtapeListA prerebinds) @> IS IZ)) $ + elet + (weakenExpr (autoWeak (#d (auto1 @sd) + &. #ta0 subtapeListA &. #prea0 prerebinds - &. #recon (tapeA `SCons` d2 (typeOf a) `SCons` SNil) + &. #recon (tapeA `SCons` applySparse sd (d2 (typeOf a)) `SCons` SNil) &. #binds (tPrimal `SCons` subList (bindingsBinds e0) subtapeE) &. #tl (d2ace (select SAccum des))) (#d :++: #ta0 :++: #tl) (#d :++: (#ta0 :++: #prea0) :++: #recon :++: #binds :++: #tl)) a2) $ - EPair ext - (expandSubenvZeros (subList (select SMerge des) subAB) sAB_A $ - EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA))) (d2 t1)) IZ)) - (ELInl ext (d2 t2) - (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA))) (d2 t1)) IZ)))) - (let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds b0) subtapeB) IZ + EPair ext (sAB_A $ EFst ext (evar IZ)) + (ELInl ext (applySparse sd2 (d2 t2)) (ESnd ext (evar IZ)))) + (let (rebinds, prerebinds) = reconstructBindings subtapeListB IZ in letBinds rebinds $ ELet ext - (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_b_tape : D2 t : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend (subList (bindingsBinds b0) subtapeB) prerebinds) @> IS IZ)) $ - ELet ext - (weakenExpr (autoWeak (#d (auto1 @(D2 t)) - &. #tb0 (subList (bindingsBinds b0) subtapeB) + (EVar ext (applySparse sd (d2 (typeOf a))) (wSinks @(Tape rhs_b_tape : sd : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend subtapeListB prerebinds) @> IS IZ)) $ + elet + (weakenExpr (autoWeak (#d (auto1 @sd) + &. #tb0 subtapeListB &. #preb0 prerebinds - &. #recon (tapeB `SCons` d2 (typeOf a) `SCons` SNil) + &. #recon (tapeB `SCons` applySparse sd (d2 (typeOf a)) `SCons` SNil) &. #binds (tPrimal `SCons` subList (bindingsBinds e0) subtapeE) &. #tl (d2ace (select SAccum des))) (#d :++: #tb0 :++: #tl) (#d :++: (#tb0 :++: #preb0) :++: #recon :++: #binds :++: #tl)) b2) $ - EPair ext - (expandSubenvZeros (subList (select SMerge des) subAB) sAB_B $ - EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB))) (d2 t2)) IZ)) - (ELInr ext (d2 t1) - (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB))) (d2 t2)) IZ))))) $ - ELet ext - (ELet ext (ESnd ext (EVar ext tCaseRet IZ)) $ - weakenExpr (WCopy (wSinks' @[_,_,_])) e2) $ + EPair ext (sAB_B $ EFst ext (evar IZ)) + (ELInr ext (applySparse sd1 (d2 t1)) (ESnd ext (evar IZ))))) $ plus_AB_E - (EFst ext (EVar ext tCaseRet (IS IZ))) - (EVar ext (tTup (d2e (subList (select SMerge des) subE))) IZ)) + (EFst ext (evar IZ)) + (ELet ext (ESnd ext (evar IZ)) $ + weakenExpr (WCopy (wSinks' @[_,_,_])) e2)) EConst _ t val -> Ret BTop SETop (EConst ext t val) - (subenvNone (select SMerge des)) + (subenvNone (d2e (select SMerge des))) (ENil ext) EOp _ op e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e -> + | Ret e0 subtape e1 sub e2 <- drev des accumMap (spDense (d2M (opt1 op))) e -> case d2op op of Linear d2opfun -> Ret e0 subtape (d1op op e1) sub - (ELet ext (d2opfun (EVar ext (d2 (opt2 op)) IZ)) + (ELet ext (d2opfun (opt2UnSparse op sd (EVar ext (applySparse sd (d2 (opt2 op))) IZ))) (weakenExpr (WCopy WSink) e2)) Nonlinear d2opfun -> Ret (e0 `BPush` (d1 (typeOf e), e1)) - (SEYes subtape) + (SEYesR subtape) (d1op op $ EVar ext (d1 (typeOf e)) IZ) sub (ELet ext (d2opfun (EVar ext (d1 (typeOf e)) (IS IZ)) - (EVar ext (d2 (opt2 op)) IZ)) + (opt2UnSparse op sd (EVar ext (applySparse sd (d2 (opt2 op))) IZ))) (weakenExpr (WCopy (wSinks' @[_,_])) e2)) - ECustom _ _ _ storety _ pr du a b + ECustom _ _ tb storety srce pr du a b -- allowed to ignore a2 because 'a' is the part of the input that is inactive - | Rets binds subtape (RetPair a1 _ _ `SCons` RetPair b1 bsub b2 `SCons` SNil) - <- retConcat des $ drev des accumMap a `SCons` drev des accumMap b `SCons` SNil -> - Ret (binds `BPush` (typeOf a1, a1) - `BPush` (typeOf b1, weakenExpr WSink b1) - `BPush` (typeOf pr, weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr)) - `BPush` (storety, ESnd ext (EVar ext (typeOf pr) IZ))) - (SEYes (SENo (SENo (SENo subtape)))) - (EFst ext (EVar ext (typeOf pr) (IS IZ))) - bsub - (ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $ - weakenExpr (WCopy (WSink .> WSink)) b2) - - -- TODO: compute primal in direct form here instead of taking the redundantly inefficient CHAD primal + | Ret b0 bsubtape b1 bsub b2 <- drev des accumMap (spDense (d2M tb)) b -> + case isDense (d2M (typeOf srce)) sd of + Just Refl -> + Ret (b0 `BPush` (d1 (typeOf a), weakenExpr (sinkWithBindings b0) (drevPrimal des a)) + `BPush` (typeOf b1, weakenExpr WSink b1) + `BPush` (typeOf pr, weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr)) + `BPush` (storety, ESnd ext (EVar ext (typeOf pr) IZ))) + (SEYesR (SENo (SENo (SENo bsubtape)))) + (EFst ext (EVar ext (typeOf pr) (IS IZ))) + bsub + (ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $ + weakenExpr (WCopy (WSink .> WSink)) b2) + + Nothing -> + Ret (b0 `BPush` (d1 (typeOf a), weakenExpr (sinkWithBindings b0) (drevPrimal des a)) + `BPush` (typeOf b1, weakenExpr WSink b1) + `BPush` (typeOf pr, weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr))) + (SEYesR (SENo (SENo bsubtape))) + (EFst ext (EVar ext (typeOf pr) IZ)) + bsub + (ELet ext (ESnd ext (EVar ext (typeOf pr) (IS IZ))) $ -- tape + ELet ext (expandSparse (typeOf srce) sd -- expanded incoming cotangent + (EFst ext (EVar ext (typeOf pr) (IS (IS IZ)))) + (EVar ext (applySparse sd (d2 (typeOf srce))) (IS IZ))) $ + ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $ + weakenExpr (WCopy (WSink .> WSink .> WSink .> WSink)) b2) + ERecompute _ e -> deleteUnused (descrList des) (occCountAll e) $ \usedSub -> let smallE = unsafeWeakenWithSubenv usedSub e in subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed -> - case drev usedDes (VarMap.subMap subAccumUsed accumMap) smallE of { Ret e0 subtape e1 sub e2 -> + case drev usedDes (VarMap.subMap subAccumUsed accumMap) sd smallE of { Ret e0 subtape _ sub e2 -> + let subMergeUsed' = subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E subMergeUsed) in Ret (collectBindings (desD1E des) subD1eUsed) (subenvAll (desD1E usedDes)) - (weakenExpr (wRaiseAbove (desD1E usedDes) (desD1E des)) $ letBinds e0 e1) - (subenvCompose subMergeUsed sub) + (weakenExpr (wSinks (desD1E usedDes)) $ drevPrimal des e) + (subenvCompose subMergeUsed' sub) (letBinds (fst (weakenBindings weakenExpr (WSink .> wRaiseAbove (desD1E usedDes) (d2ace (select SAccum des))) e0)) $ weakenExpr - (autoWeak (#d (auto1 @(D2 t)) + (autoWeak (#d (auto1 @sd) &. #shbinds (bindingsBinds e0) &. #tape (subList (bindingsBinds e0) subtape) &. #d1env (desD1E usedDes) @@ -849,128 +1040,130 @@ drev des accumMap = \case Ret BTop SETop (EError ext (d1 t) s) - (subenvNone (select SMerge des)) + (subenvNone (d2e (select SMerge des))) (ENil ext) EConstArr _ n t val -> Ret BTop SETop (EConstArr ext n t val) - (subenvNone (select SMerge des)) + (subenvNone (d2e (select SMerge des))) (ENil ext) EBuild _ (ndim :: SNat ndim) she (orige :: Expr _ _ eltty) - | Ret (she0 :: Bindings _ _ she_binds) _ she1 _ _ <- drev des accumMap she -- allowed to ignore she2 here because she has a discrete result + | SpArr @_ @sdElt sdElt <- sd , let eltty = typeOf orige , shty :: STy shty <- tTup (sreplicate ndim tIx) , Refl <- indexTupD1Id ndim -> deleteUnused (descrList des) (occEnvPop (occCountAll orige)) $ \(usedSub :: Subenv env env') -> - let e = unsafeWeakenWithSubenv (SEYes usedSub) orige in - subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed -> - accumPromote eltty usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro -> + let e = unsafeWeakenWithSubenv (SEYesR usedSub) orige in + subDescr des usedSub $ \(usedDes :: Descr env' _) subMergeUsed subAccumUsed subD1eUsed -> + accumPromote sdElt usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro -> let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in - case drev (prodes `DPush` (shty, Nothing, SDiscr)) accumMapPro e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 sub e2 -> + case drev (prodes `DPush` (shty, Nothing, SDiscr)) accumMapPro sdElt e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 sub e2 -> case assertSubenvEmpty sub of { Refl -> + case lemAppendNil @e_binds of { Refl -> let tapety = tapeTy (subList (bindingsBinds e0) subtapeE) in - let collectexpr = bindingsCollectTape e0 subtapeE in - Ret (BTop `BPush` (shty, letBinds she0 she1) - `BPush` (STArr ndim (STPair (d1 eltty) tapety) - ,EBuild ext ndim - (EVar ext shty IZ) - (letBinds (fst (weakenBindings weakenExpr (autoWeak (#ix (shty `SCons` SNil) - &. #sh (shty `SCons` SNil) - &. #d1env (desD1E des) - &. #d1env' (desD1E usedDes)) - (#ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) - (#ix :++: #sh :++: #d1env)) - e0)) $ - let w = autoWeak (#ix (shty `SCons` SNil) - &. #sh (shty `SCons` SNil) - &. #e0 (bindingsBinds e0) - &. #d1env (desD1E des) - &. #d1env' (desD1E usedDes)) - (#e0 :++: #ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) - (#e0 :++: #ix :++: #sh :++: #d1env) - in EPair ext (weakenExpr w e1) (collectexpr w))) - `BPush` (STArr ndim tapety, emap (ESnd ext (EVar ext (STPair (d1 eltty) tapety) IZ)) - (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) IZ))) - (SEYes (SENo (SEYes SETop))) - (emap (EFst ext (EVar ext (STPair (d1 eltty) tapety) IZ)) - (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) (IS IZ))) - (subenvCompose subMergeUsed proSub) - (let sinkOverEnvPro = wSinks @(TArr ndim (D2 eltty) : D2 t : TArr ndim (Tape e_tape) : Tup (Replicate ndim TIx) : D2AcE (Select env sto "accum")) (d2ace envPro) in - EMaybe ext - (zeroTup envPro) - (ESnd ext $ - uninvertTup (d2e envPro) (STArr ndim STNil) $ - makeAccumulators @_ @_ @(TArr ndim TNil) envPro $ - EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS (IS IZ)))) $ - -- the cotangent for this element - ELet ext (EIdx ext (EVar ext (STArr ndim (d2 eltty)) (WSink .> sinkOverEnvPro @> IZ)) - (EVar ext shty IZ)) $ - -- the tape for this element - ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS (IS IZ))) - (EVar ext shty (IS IZ))) $ - let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds e0) subtapeE) IZ - in letBinds rebinds $ - weakenExpr (autoWeak (#d (auto1 @(D2 eltty)) - &. #pro (d2ace envPro) - &. #etape (subList (bindingsBinds e0) subtapeE) - &. #prerebinds prerebinds - &. #tape (auto1 @(Tape e_tape)) - &. #ix (auto1 @shty) - &. #darr (auto1 @(TArr ndim (D2 eltty))) - &. #mdarr (auto1 @(TMaybe (TArr ndim (D2 eltty)))) - &. #tapearr (auto1 @(TArr ndim (Tape e_tape))) - &. #sh (auto1 @shty) - &. #d2acUsed (d2ace (select SAccum usedDes)) - &. #d2acEnv (d2ace (select SAccum des))) - (#pro :++: #d :++: #etape :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) - ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #mdarr :++: #tapearr :++: #sh :++: #d2acEnv) - .> wPro (subList (bindingsBinds e0) subtapeE)) - e2) - (EVar ext (d2 (STArr ndim eltty)) IZ)) - }} + let collectexpr = bindingsCollectTape (bindingsBinds e0) subtapeE in + let mergePrimalSub = subenvD1E (selectSub SMerge des `subenvCompose` subMergeUsed `subenvCompose` proSub) in + let mergePrimalBindings = collectBindings (d1e (descrList des)) mergePrimalSub in + Ret (mergePrimalBindings + `BPush` (shty, weakenExpr (wSinks (d1e envPro)) (drevPrimal des she)) + `BPush` (STArr ndim (STPair (d1 eltty) tapety) + ,EBuild ext ndim + (EVar ext shty IZ) + (letBinds (fst (weakenBindings weakenExpr (autoWeak (#ix (shty `SCons` SNil) + &. #sh (shty `SCons` SNil) + &. #propr (d1e envPro) + &. #d1env (desD1E des) + &. #d1env' (desD1E usedDes)) + (#ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) + (#ix :++: #sh :++: #propr :++: #d1env)) + e0)) $ + let w = autoWeak (#ix (shty `SCons` SNil) + &. #sh (shty `SCons` SNil) + &. #e0 (bindingsBinds e0) + &. #propr (d1e envPro) + &. #d1env (desD1E des) + &. #d1env' (desD1E usedDes)) + (#e0 :++: #ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) + (#e0 :++: #ix :++: #sh :++: #propr :++: #d1env) + w' = w .> wCopies (bindingsBinds e0) (WClosed @(shty : D1E env')) + in EPair ext (weakenExpr w e1) (collectexpr w'))) + `BPush` (STArr ndim tapety, emap (ESnd ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) IZ))) + (SEYesR (SENo (SEYesR (subenvAll (d1e envPro))))) + (emap (EFst ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) (IS IZ))) + (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E (subenvCompose subMergeUsed proSub))) + (let sinkOverEnvPro = wSinks @(sd : TArr ndim (Tape e_tape) : Tup (Replicate ndim TIx) : Append (D1E envPro) (D2AcE (Select env sto "accum"))) (d2ace envPro) in + ESnd ext $ + uninvertTup (d2e envPro) (STArr ndim STNil) $ + makeAccumulators @_ @_ @(TArr ndim TNil) (WSink .> WSink .> WSink .> wRaiseAbove (d1e envPro) (d2ace (select SAccum des))) envPro $ + EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS IZ))) $ + -- the cotangent for this element + ELet ext (EIdx ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (WSink .> sinkOverEnvPro @> IZ)) + (EVar ext shty IZ)) $ + -- the tape for this element + ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS IZ)) + (EVar ext shty (IS IZ))) $ + let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds e0) subtapeE) IZ + in letBinds rebinds $ + weakenExpr (autoWeak (#d (auto1 @sdElt) + &. #pro (d2ace envPro) + &. #etape (subList (bindingsBinds e0) subtapeE) + &. #prerebinds prerebinds + &. #tape (auto1 @(Tape e_tape)) + &. #ix (auto1 @shty) + &. #darr (auto1 @(TArr ndim sdElt)) + &. #tapearr (auto1 @(TArr ndim (Tape e_tape))) + &. #sh (auto1 @shty) + &. #propr (d1e envPro) + &. #d2acUsed (d2ace (select SAccum usedDes)) + &. #d2acEnv (d2ace (select SAccum des))) + (#pro :++: #d :++: #etape :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) + ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #sh :++: #propr :++: #d2acEnv) + .> wPro (subList (bindingsBinds e0) subtapeE)) + e2) + }}} EUnit _ e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e -> + | SpArr sdElt <- sd + , Ret e0 subtape e1 sub e2 <- drev des accumMap sdElt e -> Ret e0 subtape (EUnit ext e1) sub - (EMaybe ext - (zeroTup (subList (select SMerge des) sub)) - (ELet ext (EIdx0 ext (EVar ext (STArr SZ (d2 (typeOf e))) IZ)) $ - weakenExpr (WCopy (WSink .> WSink)) e2) - (EVar ext (STMaybe (STArr SZ (d2 (typeOf e)))) IZ)) + (ELet ext (EIdx0 ext (EVar ext (STArr SZ (applySparse sdElt (d2 (typeOf e)))) IZ)) $ + weakenExpr (WCopy WSink) e2) EReplicate1Inner _ en e - -- We're allowed to ignore en2 here because the output of 'ei' is discrete. - | Rets binds subtape (RetPair en1 _ _ `SCons` RetPair e1 sub e2 `SCons` SNil) - <- retConcat des $ drev des accumMap en `SCons` drev des accumMap e `SCons` SNil + -- We're allowed to differentiate 'en' as primal-only here because its output is discrete. + | SpArr sdElt <- sd , let STArr ndim eltty = typeOf e -> - Ret binds - subtape - (EReplicate1Inner ext en1 e1) - sub - (EMaybe ext - (zeroTup (subList (select SMerge des) sub)) - (ELet ext (EJust ext (EFold1Inner ext Commut - (EPlus ext (d2M eltty) (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ)) - (ezeroD2 eltty) - (EVar ext (STArr (SS ndim) (d2 eltty)) IZ))) $ - weakenExpr (WCopy (WSink .> WSink)) e2) - (EVar ext (d2 (STArr (SS ndim) eltty)) IZ)) + -- This pessimistic sparsity union is because the array might have been empty, in which case we need to generate a zero. + sparsePlusS ST ST (d2M eltty) sdElt SpAbsent $ \sdElt' (Inj inj1) (Inj inj2) _ -> + case drev des accumMap (SpArr sdElt') e of { Ret binds subtape e1 sub e2 -> + Ret binds + subtape + (EReplicate1Inner ext (weakenExpr (wSinks (bindingsBinds binds)) (drevPrimal des en)) e1) + sub + (ELet ext (EFold1Inner ext Commut + (sparsePlus (d2M eltty) sdElt' + (EVar ext (applySparse sdElt' (d2 eltty)) (IS IZ)) + (EVar ext (applySparse sdElt' (d2 eltty)) IZ)) + (inj2 (ENil ext)) + (emap (inj1 (evar IZ)) $ EVar ext (STArr (SS ndim) (applySparse sdElt (d2 eltty))) IZ)) $ + weakenExpr (WCopy WSink) e2) + } EIdx0 _ e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e + | Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd) e , STArr _ t <- typeOf e -> Ret e0 subtape (EIdx0 ext e1) sub - (ELet ext (EJust ext (EUnit ext (EVar ext (d2 t) IZ))) $ - weakenExpr (WCopy WSink) e2) + (ELet ext (EUnit ext (EVar ext (applySparse sd (d2 t)) IZ)) $ + weakenExpr (WCopy WSink) e2) EIdx1{} -> error "CHAD of EIdx1: Please use EIdx instead" {- @@ -981,7 +1174,7 @@ drev des accumMap = \case , STArr (SS n) eltty <- typeOf e -> Ret (binds `BPush` (STArr (SS n) (d1 eltty), e1) `BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) (d1 eltty)) IZ))) - (SEYes (SENo subtape)) + (SEYesR (SENo subtape)) (EIdx1 ext (EVar ext (STArr (SS n) (d1 eltty)) (IS IZ)) (weakenExpr (WSink .> WSink) ei1)) sub @@ -992,57 +1185,58 @@ drev des accumMap = \case -} EIdx _ e ei - -- We're allowed to ignore ei2 here because the output of 'ei' is discrete. - | Rets binds subtape (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil) - <- retConcat des $ drev des accumMap e `SCons` drev des accumMap ei `SCons` SNil - , STArr n eltty <- typeOf e + -- We're allowed to differentiate ei as primal because its output is discrete. + | STArr n eltty <- typeOf e , Refl <- indexTupD1Id n - , Refl <- lemZeroInfoD2 eltty - , let tIxN = tTup (sreplicate n tIx) -> - Ret (binds `BPush` (STArr n (d1 eltty), e1) - `BPush` (tIxN, EShape ext (EVar ext (typeOf e1) IZ)) - `BPush` (tIxN, weakenExpr (WSink .> WSink) ei1)) - (SEYes (SEYes (SENo subtape))) - (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ))) - (EVar ext (tTup (sreplicate n tIx)) IZ)) - sub - (ELet ext (EOneHot ext (d2M (STArr n eltty)) (SAPJust (SAPArrIdx SAPHere)) - (EPair ext (EPair ext (EVar ext tIxN (IS IZ)) - (EBuild ext n (EVar ext tIxN (IS (IS IZ))) (ENil ext))) - (ENil ext)) - (EVar ext (d2 eltty) IZ)) $ - weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) + , let tIxN = tTup (sreplicate n tIx) -> + sparsePlusS ST ST (d2M eltty) sd SpAbsent $ \sd' (Inj inj1) (Inj inj2) _ -> + case drev des accumMap (SpArr sd') e of { Ret binds subtape e1 sub e2 -> + Ret (binds `BPush` (STArr n (d1 eltty), e1) + `BPush` (tIxN, EShape ext (EVar ext (typeOf e1) IZ)) + `BPush` (tIxN, weakenExpr (WSink .> WSink .> wSinks (bindingsBinds binds)) (drevPrimal des ei))) + (SEYesR (SEYesR (SENo subtape))) + (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ))) + (EVar ext (tTup (sreplicate n tIx)) IZ)) + sub + (ELet ext + (EOneHot ext (SMTArr n (applySparse sd' (d2M eltty))) + (SAPArrIdx SAPHere) + (EPair ext + (EPair ext (EVar ext tIxN (IS IZ)) + (EBuild ext n (EVar ext tIxN (IS (IS IZ))) $ + makeZeroInfo (applySparse sd' (d2M eltty)) (inj2 (ENil ext)))) + (ENil ext)) + (inj1 $ EVar ext (applySparse sd (d2 eltty)) IZ)) $ + weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) + } EShape _ e - -- Allowed to ignore e2 here because the output of EShape is discrete, - -- hence we'd be passing a zero cotangent to e2 anyway. - | Ret e0 subtape e1 _ _ <- drev des accumMap e - , STArr n _ <- typeOf e + -- Allowed to differentiate e as primal because the output of EShape is + -- discrete, hence we'd be passing a zero cotangent to e anyway. + | STArr n _ <- typeOf e , Refl <- indexTupD1Id n -> - Ret e0 - subtape - (EShape ext e1) - (subenvNone (select SMerge des)) + Ret BTop + SETop + (EShape ext (drevPrimal des e)) + (subenvNone (d2eM (select SMerge des))) (ENil ext) ESum1Inner _ e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e + | SpArr sd' <- sd + , Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd') e , STArr (SS n) t <- typeOf e -> Ret (e0 `BPush` (STArr (SS n) t, e1) `BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) t) IZ))) - (SEYes (SENo subtape)) + (SEYesR (SENo subtape)) (ESum1Inner ext (EVar ext (STArr (SS n) t) (IS IZ))) sub - (EMaybe ext - (zeroTup (subList (select SMerge des) sub)) - (ELet ext (EJust ext (EReplicate1Inner ext - (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS (IS IZ)))) - (EVar ext (STArr n (d2 t)) IZ))) $ - weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) - (EVar ext (d2 (STArr n t)) IZ)) + (ELet ext (EReplicate1Inner ext + (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ))) + (EVar ext (STArr n (applySparse sd' (d2 t))) IZ)) $ + weakenExpr (WCopy (WSink .> WSink)) e2) - EMaximum1Inner _ e -> deriv_extremum (EMaximum1Inner ext) e - EMinimum1Inner _ e -> deriv_extremum (EMinimum1Inner ext) e + EMaximum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMaximum1Inner ext) des accumMap sd' e + EMinimum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMinimum1Inner ext) des accumMap sd' e -- These should be the next to be implemented, I think EFold1Inner{} -> err_unsupported "EFold1Inner" @@ -1056,8 +1250,8 @@ drev des accumMap = \case ELCase{} -> err_unsupported "ELCase" EWith{} -> err_accum - EAccum{} -> err_accum EZero{} -> err_monoid + EDeepZero{} -> err_monoid EPlus{} -> err_monoid EOneHot{} -> err_monoid @@ -1066,94 +1260,116 @@ drev des accumMap = \case err_monoid = error "Monoid operations unsupported in the source program" err_unsupported s = error $ "CHAD: unsupported " ++ s - deriv_extremum :: ScalIsNumeric t' ~ True - => (forall env'. Ex env' (TArr (S n) (TScal t')) -> Ex env' (TArr n (TScal t'))) - -> Expr ValId env (TArr (S n) (TScal t')) -> Ret env sto (TArr n (TScal t')) - deriv_extremum extremum e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e - , at@(STArr (SS n) t@(STScal st)) <- typeOf e - , let at' = STArr n t - , let tIxN = tTup (sreplicate (SS n) tIx) = - Ret (e0 `BPush` (at, e1) - `BPush` (at', extremum (EVar ext at IZ))) - (SEYes (SEYes subtape)) - (EVar ext at' IZ) - sub - (EMaybe ext - (zeroTup (subList (select SMerge des) sub)) - (ELet ext (EJust ext - (EBuild ext (SS n) (EShape ext (EVar ext at (IS (IS (IS IZ))))) $ - eif (EOp ext (OEq st) (EPair ext - (EIdx ext (EVar ext at (IS (IS (IS (IS IZ))))) (EVar ext tIxN IZ)) - (EIdx ext (EVar ext at' (IS (IS (IS IZ)))) (EFst ext (EVar ext tIxN IZ))))) - (EIdx ext (EVar ext (STArr n (d2 t)) (IS IZ)) (EFst ext (EVar ext tIxN IZ))) - (ezeroD2 t))) $ - weakenExpr (WCopy (WSink .> WSink .> WSink .> WSink)) e2) - (EVar ext (d2 at') IZ)) + contribTupTy :: Descr env sto -> SubenvS (D2E (Select env sto "merge")) contribs -> STy (Tup contribs) + contribTupTy des' sub = tTup (slistMap fromSMTy (subList (d2eM (select SMerge des')) sub)) + +deriv_extremum :: (?config :: CHADConfig, ScalIsNumeric t ~ True) + => (forall env'. Ex env' (TArr (S n) (TScal t)) -> Ex env' (TArr n (TScal t))) + -> Descr env sto -> VarMap Int (D2AcE (Select env sto "accum")) + -> Sparse (D2s t) sd + -> Expr ValId env (TArr (S n) (TScal t)) -> Ret env sto (TArr n sd) (TArr n (TScal t)) +deriv_extremum extremum des accumMap sd e + | at@(STArr (SS n) t@(STScal st)) <- typeOf e + , let at' = STArr n t + , let tIxN = tTup (sreplicate (SS n) tIx) = + sparsePlusS ST ST (d2M t) sd SpAbsent $ \sd' (Inj inj1) (Inj inj2) _ -> + case drev des accumMap (SpArr sd') e of { Ret e0 subtape e1 sub e2 -> + Ret (e0 `BPush` (at, e1) + `BPush` (at', extremum (EVar ext at IZ))) + (SEYesR (SEYesR subtape)) + (EVar ext at' IZ) + sub + (ELet ext + (EBuild ext (SS n) (EShape ext (EVar ext at (IS (IS IZ)))) $ + eif (EOp ext (OEq st) (EPair ext + (EIdx ext (EVar ext at (IS (IS (IS IZ)))) (EVar ext tIxN IZ)) + (EIdx ext (EVar ext at' (IS (IS IZ))) (EFst ext (EVar ext tIxN IZ))))) + (inj1 $ EIdx ext (EVar ext (STArr n (applySparse sd (d2 t))) (IS IZ)) (EFst ext (EVar ext tIxN IZ))) + (inj2 (ENil ext))) $ + weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) + } data ChosenStorage = forall s. ((s == "discr") ~ False) => ChosenStorage (Storage s) -data RetScoped env0 sto a s t = - forall shbinds tapebinds env0Merge. +data RetScoped env0 sto a s sd t = + forall shbinds tapebinds contribs sa. RetScoped (Bindings Ex (D1E (a : env0)) shbinds) -- shared binds - (Subenv shbinds tapebinds) + (Subenv (Append shbinds '[D1 a]) tapebinds) (Ex (Append shbinds (D1E (a : env0))) (D1 t)) - (Subenv (Select env0 sto "merge") env0Merge) + (SubenvS (D2E (Select env0 sto "merge")) contribs) -- ^ merge contributions to the _enclosing_ merge environment - (Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum"))) - (If (s == "discr") (Tup (D2E env0Merge)) - (TPair (Tup (D2E env0Merge)) (D2 a)))) + (Sparse (D2 a) sa) + -- ^ contribution to the argument + (Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) + (If (s == "discr") (Tup contribs) + (TPair (Tup contribs) sa))) -- ^ the merge contributions, plus the cotangent to the argument -- (if there is any) -deriving instance Show (RetScoped env0 sto a s t) +deriving instance Show (RetScoped env0 sto a s sd t) -drevScoped :: forall a s env sto t. +drevScoped :: forall a s env sto sd t. (?config :: CHADConfig) => Descr env sto -> VarMap Int (D2AcE (Select env sto "accum")) -> STy a -> Storage s -> Maybe (ValId a) + -> Sparse (D2 t) sd -> Expr ValId (a : env) t - -> RetScoped env sto a s t -drevScoped des accumMap argty argsto argids expr = case argsto of + -> RetScoped env sto a s sd t +drevScoped des accumMap argty argsto argids sd expr = case argsto of SMerge - | Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap expr -> + | Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap sd expr + , Refl <- lemAppendNil @tapebinds -> case sub of - SEYes sub' -> RetScoped e0 subtape e1 sub' e2 - SENo sub' -> RetScoped e0 subtape e1 sub' (EPair ext e2 (ezeroD2 argty)) + SEYes sp sub' -> RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub' sp e2 + SENo sub' -> RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub' SpAbsent (EPair ext e2 (ENil ext)) SAccum - | Just (VIArr i _) <- argids + | chcSmartWith ?config + , Just (VIArr i _) <- argids , Just (Some (VarMap.TypedIdx foundTy idx)) <- VarMap.lookup i accumMap , Just Refl <- testEquality foundTy (STAccum (d2M argty)) - , Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) (VarMap.sink1 accumMap) expr -> - RetScoped e0 subtape e1 sub $ + , Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) (VarMap.sink1 accumMap) sd expr + , Refl <- lemAppendNil @tapebinds -> + -- Our contribution to the binding's cotangent _here_ is zero (absent), + -- because we're contributing to an earlier binding of the same value + -- instead. + RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub SpAbsent $ let wtapebinds = wSinks (subList (bindingsBinds e0) subtape) in ELet ext (EVar ext (STAccum (d2M argty)) (WSink .> wtapebinds @> idx)) $ - weakenExpr (autoWeak (#d (auto1 @(D2 t)) + weakenExpr (autoWeak (#d (auto1 @sd) &. #body (subList (bindingsBinds e0) subtape) &. #ac (auto1 @(TAccum (D2 a))) &. #tl (d2ace (select SAccum des))) (#d :++: #body :++: #ac :++: #tl) (#ac :++: #d :++: #body :++: #tl)) - -- Our contribution to the binding's cotangent _here_ is - -- zero, because we're contributing to an earlier binding - -- of the same value instead. - (EPair ext e2 (ezeroD2 argty)) + (EPair ext e2 (ENil ext)) | let accumMap' = case argids of Just (VIArr i _) -> VarMap.insert i (STAccum (d2M argty)) IZ (VarMap.sink1 accumMap) _ -> VarMap.sink1 accumMap - , Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap' expr -> - RetScoped e0 subtape e1 sub $ - EWith ext (d2M argty) (ezeroD2 argty) $ - weakenExpr (autoWeak (#d (auto1 @(D2 t)) - &. #body (subList (bindingsBinds e0) subtape) - &. #ac (auto1 @(TAccum (D2 a))) - &. #tl (d2ace (select SAccum des))) + , Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap' sd expr -> + let library = #d (auto1 @sd) + &. #p (auto1 @(D1 a)) + &. #body (subList (bindingsBinds e0) subtape) + &. #ac (auto1 @(TAccum (D2 a))) + &. #tl (d2ace (select SAccum des)) + in + RetScoped e0 (subenvConcat (SEYesR @_ @_ @(D1 a) SETop) subtape) e1 sub (spDense (d2M argty)) $ + let primalIdx = autoWeak library #p (#d :++: (#body :++: #p) :++: #tl) @> IZ in + EWith ext (d2M argty) (EDeepZero ext (d2M argty) (d2deepZeroInfo argty (EVar ext (d1 argty) primalIdx))) $ + weakenExpr (autoWeak library (#d :++: #body :++: #ac :++: #tl) - (#ac :++: #d :++: #body :++: #tl)) + (#ac :++: #d :++: (#body :++: #p) :++: #tl)) e2 SDiscr - | Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap expr -> - RetScoped e0 subtape e1 sub e2 + | Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap sd expr + , Refl <- lemAppendNil @tapebinds -> + RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub SpAbsent e2 + +-- TODO: proper primal-only transform that doesn't depend on D1 = Id +drevPrimal :: Descr env sto -> Expr x env t -> Ex (D1E env) (D1 t) +drevPrimal des e + | Refl <- d1Identity (typeOf e) + , Refl <- d1eIdentity (descrList des) + = mapExt (const ext) e |