summaryrefslogtreecommitdiff
path: root/src/CHAD.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-06-21 09:57:45 +0200
committerTom Smeding <tom@tomsmeding.com>2025-06-21 09:57:45 +0200
commitb5ed3d2fcc249cb410b9e86d25e9ef808c6dba97 (patch)
tree66383b16d5d95f939aaa165a783dbbfd99a57fe3 /src/CHAD.hs
parent8bbc2d2867e3d0a4a1f2810b40e92175779822e1 (diff)
parenta4b3eb76acbec30ffeae119a4dc6e4c9f64396fe (diff)
Merge branch 'sparse'HEADmaster
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r--src/CHAD.hs1068
1 files changed, 642 insertions, 426 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs
index df792ce..143376a 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -11,6 +11,7 @@
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeData #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
@@ -33,7 +34,6 @@ module CHAD (
import Data.Functor.Const
import Data.Some
-import Data.Type.Bool (If)
import Data.Type.Equality (type (==), testEquality)
import GHC.Stack (HasCallStack)
@@ -42,6 +42,7 @@ import AST
import AST.Bindings
import AST.Count
import AST.Env
+import AST.Sparse
import AST.Weaken.Auto
import CHAD.Accum
import CHAD.EnvDescr
@@ -62,15 +63,21 @@ tapeTy :: SList STy binds -> STy (Tape binds)
tapeTy SNil = STNil
tapeTy (SCons t ts) = STPair t (tapeTy ts)
-bindingsCollectTape :: Bindings f env binds -> Subenv binds tapebinds
- -> Append binds env :> env2 -> Ex env2 (Tape tapebinds)
-bindingsCollectTape BTop SETop _ = ENil ext
-bindingsCollectTape (BPush binds (t, _)) (SEYes sub) w =
+bindingsCollectTape :: SList STy binds -> Subenv binds tapebinds
+ -> binds :> env2 -> Ex env2 (Tape tapebinds)
+bindingsCollectTape SNil SETop _ = ENil ext
+bindingsCollectTape (t `SCons` binds) (SEYesR sub) w =
EPair ext (EVar ext t (w @> IZ))
(bindingsCollectTape binds sub (w .> WSink))
-bindingsCollectTape (BPush binds _) (SENo sub) w =
+bindingsCollectTape (_ `SCons` binds) (SENo sub) w =
bindingsCollectTape binds sub (w .> WSink)
+-- bindingsCollectTape' :: forall f env binds tapebinds env2. Bindings f env binds -> Subenv binds tapebinds
+-- -> Append binds env :> env2 -> Ex env2 (Tape tapebinds)
+-- bindingsCollectTape' binds sub w
+-- | Refl <- lemAppendNil @binds
+-- = bindingsCollectTape (bindingsBinds binds) sub (w .> wCopies @_ @_ @'[] (bindingsBinds binds) (WClosed @env))
+
-- In order from large to small: i.e. in reverse order from what we want,
-- because in a Bindings, the head of the list is the bottom-most entry.
type family TapeUnfoldings binds where
@@ -227,26 +234,37 @@ data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a))
d2op :: SOp a t -> D2Op a t
d2op op = case op of
- OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EJust ext (EPair ext d d)
+ OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EPair ext d d
OMul t -> d2opBinArrangeInt t $ Nonlinear $ \e d ->
- EJust ext (EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d))
- (EOp ext (OMul t) (EPair ext (EFst ext e) d)))
+ EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d))
+ (EOp ext (OMul t) (EPair ext (EFst ext e) d))
ONeg t -> d2opUnArrangeInt t $ Linear $ \d -> EOp ext (ONeg t) d
- OLt t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t)))
- OLe t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t)))
- OEq t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t)))
+ OLt t -> Linear $ \_ -> pairZero t
+ OLe t -> Linear $ \_ -> pairZero t
+ OEq t -> Linear $ \_ -> pairZero t
ONot -> Linear $ \_ -> ENil ext
- OAnd -> Linear $ \_ -> ENothing ext (STPair STNil STNil)
- OOr -> Linear $ \_ -> ENothing ext (STPair STNil STNil)
+ OAnd -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
+ OOr -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
OIf -> Linear $ \_ -> ENil ext
- ORound64 -> Linear $ \_ -> EConst ext STF64 0.0
+ ORound64 -> Linear $ \_ -> EZero ext (SMTScal STF64) (ENil ext)
OToFl64 -> Linear $ \_ -> ENil ext
ORecip t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ONeg t) (EOp ext (ORecip t) (EOp ext (OMul t) (EPair ext e e)))) d)
OExp t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (OExp t) e) d)
OLog t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ORecip t) e) d)
- OIDiv t -> integralD2 t $ Linear $ \_ -> ENothing ext (STPair STNil STNil)
- OMod t -> integralD2 t $ Linear $ \_ -> ENothing ext (STPair STNil STNil)
+ OIDiv t -> integralD2 t $ Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
+ OMod t -> integralD2 t $ Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
where
+ pairZero :: SScalTy a -> Ex env (D2 (TPair (TScal a) (TScal a)))
+ pairZero t = ziNil t $ EPair ext (EZero ext (d2M (STScal t)) (ENil ext))
+ (EZero ext (d2M (STScal t)) (ENil ext))
+ where
+ ziNil :: SScalTy a -> (ZeroInfo (D2s a) ~ TNil => r) -> r
+ ziNil STI32 k = k
+ ziNil STI64 k = k
+ ziNil STF32 k = k
+ ziNil STF64 k = k
+ ziNil STBool k = k
+
d2opUnArrangeInt :: SScalTy a
-> (D2s a ~ TScal a => D2Op (TScal a) t)
-> D2Op (TScal a) t
@@ -261,11 +279,11 @@ d2op op = case op of
-> (D2s a ~ TScal a => D2Op (TPair (TScal a) (TScal a)) t)
-> D2Op (TPair (TScal a) (TScal a)) t
d2opBinArrangeInt ty float = case ty of
- STI32 -> Linear $ \_ -> ENothing ext (STPair STNil STNil)
- STI64 -> Linear $ \_ -> ENothing ext (STPair STNil STNil)
+ STI32 -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
+ STI64 -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
STF32 -> float
STF64 -> float
- STBool -> Linear $ \_ -> ENothing ext (STPair STNil STNil)
+ STBool -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
floatingD2 :: ScalIsFloating a ~ True
=> SScalTy a -> ((D2s a ~ TScal a, ScalIsNumeric a ~ True) => r) -> r
@@ -293,7 +311,7 @@ conv1Idx (IS i) = IS (conv1Idx i)
data Idx2 env sto t
= Idx2Ac (Idx (D2AcE (Select env sto "accum")) (TAccum (D2 t)))
- | Idx2Me (Idx (Select env sto "merge") t)
+ | Idx2Me (Idx (D2E (Select env sto "merge")) (D2 t))
| Idx2Di (Idx (Select env sto "discr") t)
conv2Idx :: Descr env sto -> Idx env t -> Idx2 env sto t
@@ -314,64 +332,160 @@ conv2Idx (DPush des (_, _, SDiscr)) (IS i) =
Idx2Di j -> Idx2Di (IS j)
conv2Idx DTop i = case i of {}
-
------------------------------------- MONOIDS -----------------------------------
-
-zeroTup :: SList STy env0 -> Ex env (Tup (D2E env0))
-zeroTup SNil = ENil ext
-zeroTup (t `SCons` env) = EPair ext (zeroTup env) (ezeroD2 t)
-
-
------------------------------------- SUBENVS -----------------------------------
-
-subenvPlus :: SList STy env
- -> Subenv env env1 -> Subenv env env2
- -> (forall env3. Subenv env env3
- -> Subenv env3 env1
- -> Subenv env3 env2
- -> (Ex exenv (Tup (D2E env1))
- -> Ex exenv (Tup (D2E env2))
- -> Ex exenv (Tup (D2E env3)))
+opt2UnSparse :: SOp a b -> Sparse (D2 b) b' -> Ex env b' -> Ex env (D2 b)
+opt2UnSparse = go . opt2
+ where
+ go :: STy b -> Sparse (D2 b) b' -> Ex env b' -> Ex env (D2 b)
+ go (STScal STI32) SpAbsent = \_ -> ENil ext
+ go (STScal STI64) SpAbsent = \_ -> ENil ext
+ go (STScal STF32) SpAbsent = \_ -> EZero ext (SMTScal STF32) (ENil ext)
+ go (STScal STF64) SpAbsent = \_ -> EZero ext (SMTScal STF64) (ENil ext)
+ go (STScal STBool) SpAbsent = \_ -> ENil ext
+ go (STScal STF32) SpScal = id
+ go (STScal STF64) SpScal = id
+ go STNil _ = \_ -> ENil ext
+ go (STPair t1 t2) (SpPair s1 s2) = \e -> eunPair e $ \_ e1 e2 -> EPair ext (go t1 s1 e1) (go t2 s2 e2)
+ go t _ = error $ "Primitive operations that return " ++ show t ++ " are scary"
+
+
+----------------------------------- SPARSITY -----------------------------------
+
+expandSparse :: STy a -> Sparse (D2 a) b -> Ex env (D1 a) -> Ex env b -> Ex env (D2 a)
+expandSparse t sp _ e | Just Refl <- isDense (d2M t) sp = e
+expandSparse t (SpSparse sp) epr e =
+ EMaybe ext
+ (EZero ext (d2M t) (d2zeroInfo t epr))
+ (expandSparse t sp (weakenExpr WSink epr) (EVar ext (applySparse sp (d2 t)) IZ))
+ e
+expandSparse t SpAbsent epr _ = EZero ext (d2M t) (d2zeroInfo t epr)
+expandSparse (STPair t1 t2) (SpPair s1 s2) epr e =
+ eunPair epr $ \w1 epr1 epr2 ->
+ eunPair (weakenExpr w1 e) $ \w2 e1 e2 ->
+ EPair ext (expandSparse t1 s1 (weakenExpr w2 epr1) e1)
+ (expandSparse t2 s2 (weakenExpr w2 epr2) e2)
+expandSparse (STEither t1 t2) (SpLEither s1 s2) epr e =
+ ELCase ext e
+ (EZero ext (d2M (STEither t1 t2)) (ENil ext))
+ (ECase ext (weakenExpr WSink epr)
+ (ELInl ext (d2 t2) (expandSparse t1 s1 (EVar ext (d1 t1) IZ) (EVar ext (applySparse s1 (d2 t1)) (IS IZ))))
+ (EError ext (d2 (STEither t1 t2)) "expspa r<-dl"))
+ (ECase ext (weakenExpr WSink epr)
+ (EError ext (d2 (STEither t1 t2)) "expspa l<-dr")
+ (ELInr ext (d2 t1) (expandSparse t2 s2 (EVar ext (d1 t2) IZ) (EVar ext (applySparse s2 (d2 t2)) (IS IZ)))))
+expandSparse (STLEither t1 t2) (SpLEither s1 s2) epr e =
+ ELCase ext e
+ (EZero ext (d2M (STEither t1 t2)) (ENil ext))
+ (ELCase ext (weakenExpr WSink epr)
+ (EError ext (d2 (STEither t1 t2)) "expspa ln<-dl")
+ (ELInl ext (d2 t2) (expandSparse t1 s1 (EVar ext (d1 t1) IZ) (EVar ext (applySparse s1 (d2 t1)) (IS IZ))))
+ (EError ext (d2 (STEither t1 t2)) "expspa lr<-dl"))
+ (ELCase ext (weakenExpr WSink epr)
+ (EError ext (d2 (STEither t1 t2)) "expspa ln<-dr")
+ (EError ext (d2 (STEither t1 t2)) "expspa ll<-dr")
+ (ELInr ext (d2 t1) (expandSparse t2 s2 (EVar ext (d1 t2) IZ) (EVar ext (applySparse s2 (d2 t2)) (IS IZ)))))
+expandSparse (STMaybe t) (SpMaybe s) epr e =
+ EMaybe ext
+ (ENothing ext (d2 t))
+ (let epr' = EMaybe ext (EError ext (d1 t) "expspa n<-dj") (EVar ext (d1 t) IZ) epr
+ in EJust ext (expandSparse t s (weakenExpr WSink epr') (EVar ext (applySparse s (d2 t)) IZ)))
+ e
+expandSparse (STArr _ t) (SpArr s) epr e =
+ ezipWith (expandSparse t s (EVar ext (d1 t) (IS IZ)) (EVar ext (applySparse s (d2 t)) IZ)) epr e
+expandSparse (STScal STF32) SpScal _ e = e
+expandSparse (STScal STF64) SpScal _ e = e
+expandSparse (STAccum{}) _ _ _ = error "accumulators not allowed in source program"
+
+subenvPlus :: SBool req1 -> SBool req2
+ -> SList SMTy env
+ -> SubenvS env env1 -> SubenvS env env2
+ -> (forall env3. SubenvS env env3
+ -> Injection req1 (Tup env1) (Tup env3)
+ -> Injection req2 (Tup env2) (Tup env3)
+ -> (forall e. Ex e (Tup env1) -> Ex e (Tup env2) -> Ex e (Tup env3))
-> r)
-> r
-subenvPlus SNil SETop SETop k = k SETop SETop SETop (\_ _ -> ENil ext)
-subenvPlus (SCons _ env) (SENo sub1) (SENo sub2) k =
- subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl ->
+-- don't destroy effects!
+subenvPlus _ _ SNil SETop SETop k = k SETop (Inj id) (Inj id) (\a b -> use a $ use b $ ENil ext)
+
+subenvPlus req1 req2 (SCons _ env) (SENo sub1) (SENo sub2) k =
+ subenvPlus req1 req2 env sub1 sub2 $ \sub3 s31 s32 pl ->
k (SENo sub3) s31 s32 pl
-subenvPlus (SCons _ env) (SEYes sub1) (SENo sub2) k =
- subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl ->
- k (SEYes sub3) (SEYes s31) (SENo s32) $ \e1 e2 ->
- ELet ext e1 $
- EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ))
- (weakenExpr WSink e2))
- (ESnd ext (EVar ext (typeOf e1) IZ))
-subenvPlus (SCons _ env) (SENo sub1) (SEYes sub2) k =
- subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl ->
- k (SEYes sub3) (SENo s31) (SEYes s32) $ \e1 e2 ->
- ELet ext e2 $
- EPair ext (pl (weakenExpr WSink e1)
- (EFst ext (EVar ext (typeOf e2) IZ)))
- (ESnd ext (EVar ext (typeOf e2) IZ))
-subenvPlus (SCons t env) (SEYes sub1) (SEYes sub2) k =
- subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl ->
- k (SEYes sub3) (SEYes s31) (SEYes s32) $ \e1 e2 ->
- ELet ext e1 $
- ELet ext (weakenExpr WSink e2) $
- EPair ext (pl (EFst ext (EVar ext (typeOf e1) (IS IZ)))
- (EFst ext (EVar ext (typeOf e2) IZ)))
- (EPlus ext (d2M t)
- (ESnd ext (EVar ext (typeOf e1) (IS IZ)))
- (ESnd ext (EVar ext (typeOf e2) IZ)))
-
-expandSubenvZeros :: SList STy env0 -> Subenv env0 env0Merge -> Ex env (Tup (D2E env0Merge)) -> Ex env (Tup (D2E env0))
-expandSubenvZeros _ SETop _ = ENil ext
-expandSubenvZeros (SCons t ts) (SEYes sub) e =
- ELet ext e $
- let var = EVar ext (STPair (tTup (d2e (subList ts sub))) (d2 t)) IZ
- in EPair ext (expandSubenvZeros ts sub (EFst ext var)) (ESnd ext var)
-expandSubenvZeros (SCons t ts) (SENo sub) e = EPair ext (expandSubenvZeros ts sub e) (ezeroD2 t)
-
-assertSubenvEmpty :: HasCallStack => Subenv env env' -> env' :~: '[]
+
+subenvPlus req1 SF (SCons _ env) (SEYes sp1 sub1) (SENo sub2) k =
+ subenvPlus req1 SF env sub1 sub2 $ \sub3 minj13 _ pl ->
+ k (SEYes sp1 sub3)
+ (withInj minj13 $ \inj13 ->
+ \e1 -> eunPair e1 $ \_ e1a e1b ->
+ EPair ext (inj13 e1a) e1b)
+ Noinj
+ (\e1 e2 ->
+ ELet ext e1 $
+ EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ))
+ (weakenExpr WSink e2))
+ (ESnd ext (EVar ext (typeOf e1) IZ)))
+subenvPlus req1 ST (SCons t env) (SEYes sp1 sub1) (SENo sub2) k
+ | Just zero1 <- cheapZero (applySparse sp1 t) =
+ subenvPlus req1 ST env sub1 sub2 $ \sub3 minj13 (Inj inj23) pl ->
+ k (SEYes sp1 sub3)
+ (withInj minj13 $ \inj13 ->
+ \e1 -> eunPair e1 $ \_ e1a e1b ->
+ EPair ext (inj13 e1a) e1b)
+ (Inj $ \e2 -> EPair ext (inj23 e2) zero1)
+ (\e1 e2 ->
+ ELet ext e1 $
+ EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ))
+ (weakenExpr WSink e2))
+ (ESnd ext (EVar ext (typeOf e1) IZ)))
+ | otherwise =
+ subenvPlus req1 ST env sub1 sub2 $ \sub3 minj13 (Inj inj23) pl ->
+ k (SEYes (SpSparse sp1) sub3)
+ (withInj minj13 $ \inj13 ->
+ \e1 -> eunPair e1 $ \_ e1a e1b ->
+ EPair ext (inj13 e1a) (EJust ext e1b))
+ (Inj $ \e2 -> EPair ext (inj23 e2) (ENothing ext (applySparse sp1 (fromSMTy t))))
+ (\e1 e2 ->
+ ELet ext e1 $
+ EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ))
+ (weakenExpr WSink e2))
+ (EJust ext (ESnd ext (EVar ext (typeOf e1) IZ))))
+
+subenvPlus req1 req2 (SCons t env) sub1@SENo{} sub2@SEYes{} k =
+ subenvPlus req2 req1 (SCons t env) sub2 sub1 $ \sub3 minj23 minj13 pl ->
+ k sub3 minj13 minj23 (flip pl)
+
+subenvPlus req1 req2 (SCons t env) (SEYes sp1 sub1) (SEYes sp2 sub2) k =
+ subenvPlus req1 req2 env sub1 sub2 $ \sub3 minj13 minj23 pl ->
+ sparsePlusS req1 req2 t sp1 sp2 $ \sp3 mTinj13 mTinj23 plus ->
+ k (SEYes sp3 sub3)
+ (withInj2 minj13 mTinj13 $ \inj13 tinj13 ->
+ \e1 -> eunPair e1 $ \_ e1a e1b ->
+ EPair ext (inj13 e1a) (tinj13 e1b))
+ (withInj2 minj23 mTinj23 $ \inj23 tinj23 ->
+ \e2 -> eunPair e2 $ \_ e2a e2b ->
+ EPair ext (inj23 e2a) (tinj23 e2b))
+ (\e1 e2 ->
+ ELet ext e1 $
+ ELet ext (weakenExpr WSink e2) $
+ EPair ext (pl (EFst ext (EVar ext (typeOf e1) (IS IZ)))
+ (EFst ext (EVar ext (typeOf e2) IZ)))
+ (plus
+ (ESnd ext (EVar ext (typeOf e1) (IS IZ)))
+ (ESnd ext (EVar ext (typeOf e2) IZ))))
+
+expandSubenvZeros :: D1E env0 :> env -> SList STy env0 -> SubenvS (D2E env0) contribs
+ -> Ex env (Tup contribs) -> Ex env (Tup (D2E env0))
+expandSubenvZeros _ SNil SETop _ = ENil ext
+expandSubenvZeros w (SCons t ts) (SEYes sp sub) e =
+ eunPair e $ \w1 e1 e2 ->
+ EPair ext
+ (expandSubenvZeros (w1 .> WPop w) ts sub e1)
+ (expandSparse t sp (EVar ext (d1 t) (w1 .> w @> IZ)) e2)
+expandSubenvZeros w (SCons t ts) (SENo sub) e =
+ EPair ext
+ (expandSubenvZeros (WPop w) ts sub e)
+ (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ))))
+
+assertSubenvEmpty :: HasCallStack => Subenv' s env env' -> env' :~: '[]
assertSubenvEmpty (SENo sub) | Refl <- assertSubenvEmpty sub = Refl
assertSubenvEmpty SETop = Refl
assertSubenvEmpty SEYes{} = error "assertSubenvEmpty: not empty"
@@ -407,8 +521,8 @@ accumPromote :: forall dt env sto proxy r.
-- accumulators.
-> (forall shbinds.
SList STy shbinds
- -> (D2 dt : Append shbinds (D2AcE (Select env stoRepl "accum")))
- :> Append (D2AcE envPro) (D2 dt : Append shbinds (D2AcE (Select env sto "accum"))))
+ -> (dt : Append shbinds (D2AcE (Select env stoRepl "accum")))
+ :> Append (D2AcE envPro) (dt : Append shbinds (D2AcE (Select env sto "accum"))))
-- ^ A weakening that converts a computation in the
-- revised environment to one in the original environment
-- extended with some accumulators.
@@ -422,14 +536,14 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of
k (storepl `DPush` (t, vid, SAccum))
envpro
prosub
- (SEYes accrevsub)
+ (SEYesR accrevsub)
(VarMap.sink1 accumMap)
(\shbinds ->
- autoWeak (#pro (d2ace envpro) &. #d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum descr)))
+ autoWeak (#pro (d2ace envpro) &. #d (auto1 @dt) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum descr)))
(#acc :++: (#pro :++: #d :++: #shb :++: #tl))
(#pro :++: #d :++: #shb :++: #acc :++: #tl)
.> WCopy (wf shbinds)
- .> autoWeak (#d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum storepl)))
+ .> autoWeak (#d (auto1 @dt) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum storepl)))
(#d :++: #shb :++: #acc :++: #tl)
(#acc :++: (#d :++: #shb :++: #tl)))
@@ -449,7 +563,7 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of
accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf ->
k (storepl `DPush` (t, vid, SAccum))
(t `SCons` envpro)
- (SEYes prosub)
+ (SEYesR prosub)
(SENo accrevsub)
(let accumMap' = VarMap.sink1 accumMap
in case fromArrayValId vid of
@@ -466,7 +580,7 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of
-- goal: | ARE EQUAL ||
-- D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum")))
WCopy (wf shbinds)
- .> WPick @(TAccum (D2 t)) @(D2 dt : shbinds) (Const () `SCons` shbindsC)
+ .> WPick @(TAccum (D2 t)) @(dt : shbinds) (Const () `SCons` shbindsC)
(WId @(D2AcE (Select env1 stoRepl "accum"))))
-- Discrete values are left as-is, nothing to do
@@ -498,21 +612,41 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of
---------------------------- RETURN TRIPLE FROM CHAD ---------------------------
-data Ret env0 sto t =
- forall shbinds tapebinds env0Merge.
+data Ret env0 sto sd t =
+ forall shbinds tapebinds contribs.
Ret (Bindings Ex (D1E env0) shbinds) -- shared binds
(Subenv shbinds tapebinds)
(Ex (Append shbinds (D1E env0)) (D1 t))
- (Subenv (Select env0 sto "merge") env0Merge)
- (Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup (D2E env0Merge)))
-deriving instance Show (Ret env0 sto t)
+ (SubenvS (D2E (Select env0 sto "merge")) contribs)
+ (Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs))
+deriving instance Show (Ret env0 sto sd t)
-data RetPair env0 sto env shbinds tapebinds t =
- forall env0Merge.
- RetPair (Ex (Append shbinds env) (D1 t))
- (Subenv (Select env0 sto "merge") env0Merge)
- (Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup (D2E env0Merge)))
-deriving instance Show (RetPair env0 sto env shbinds tapebinds t)
+type data TyTyPair = MkTyTyPair Ty Ty
+
+data SingleRet env0 sto (pair :: TyTyPair) =
+ forall shbinds tapebinds.
+ SingleRet
+ (Bindings Ex (D1E env0) shbinds) -- shared binds
+ (Subenv shbinds tapebinds)
+ (RetPair env0 sto (D1E env0) shbinds tapebinds pair)
+
+-- pattern Ret1 :: forall env0 sto Bindings Ex (D1E env0) shbinds
+-- -> Subenv shbinds tapebinds
+-- -> Ex (Append shbinds (D1E env0)) (D1 t)
+-- -> SubenvS (D2E (Select env0 sto "merge")) contribs
+-- -> Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs)
+-- -> SingleRet env0 sto (MkTyTyPair sd t)
+-- pattern Ret1 e0 subtape e1 sub e2 = SingleRet e0 subtape (RetPair e1 sub e2)
+-- {-# COMPLETE Ret1 #-}
+
+data RetPair env0 sto env shbinds tapebinds (pair :: TyTyPair) where
+ RetPair :: forall sd t contribs -- existentials
+ env0 sto env shbinds tapebinds. -- universals
+ Ex (Append shbinds env) (D1 t)
+ -> SubenvS (D2E (Select env0 sto "merge")) contribs
+ -> Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs)
+ -> RetPair env0 sto env shbinds tapebinds (MkTyTyPair sd t)
+deriving instance Show (RetPair env0 sto env shbinds tapebinds pair)
data Rets env0 sto env list =
forall shbinds tapebinds.
@@ -521,8 +655,11 @@ data Rets env0 sto env list =
(SList (RetPair env0 sto env shbinds tapebinds) list)
deriving instance Show (Rets env0 sto env list)
+toSingleRet :: Ret env0 sto sd t -> SingleRet env0 sto (MkTyTyPair sd t)
+toSingleRet (Ret e0 subtape e1 sub e2) = SingleRet e0 subtape (RetPair e1 sub e2)
+
weakenRetPair :: SList STy shbinds -> env :> env'
- -> RetPair env0 sto env shbinds tapebinds t -> RetPair env0 sto env' shbinds tapebinds t
+ -> RetPair env0 sto env shbinds tapebinds pair -> RetPair env0 sto env' shbinds tapebinds pair
weakenRetPair bindslist w (RetPair e1 sub e2) = RetPair (weakenExpr (weakenOver bindslist w) e1) sub e2
weakenRets :: env :> env' -> Rets env0 sto env list -> Rets env0 sto env' list
@@ -530,104 +667,137 @@ weakenRets w (Rets binds tapesub list) =
let (binds', _) = weakenBindings weakenExpr w binds
in Rets binds' tapesub (slistMap (weakenRetPair (bindingsBinds binds) w) list)
-rebaseRetPair :: forall env b1 b2 tapebinds1 tapebinds2 env0 sto t f.
+rebaseRetPair :: forall env b1 b2 tapebinds1 tapebinds2 env0 sto pair f.
Descr env0 sto
-> SList f b1 -> SList f b2
-> Subenv b1 tapebinds1 -> Subenv b2 tapebinds2
- -> RetPair env0 sto (Append b1 env) b2 tapebinds2 t
- -> RetPair env0 sto env (Append b2 b1) (Append tapebinds2 tapebinds1) t
-rebaseRetPair descr b1 b2 subtape1 subtape2 (RetPair p sub d)
+ -> RetPair env0 sto (Append b1 env) b2 tapebinds2 pair
+ -> RetPair env0 sto env (Append b2 b1) (Append tapebinds2 tapebinds1) pair
+rebaseRetPair descr b1 b2 subtape1 subtape2 (RetPair @sd e1 sub e2)
| Refl <- lemAppendAssoc @b2 @b1 @env =
- RetPair p sub (weakenExpr (autoWeak
- (#d (auto1 @(D2 t))
- &. #t2 (subList b2 subtape2)
- &. #t1 (subList b1 subtape1)
- &. #tl (d2ace (select SAccum descr)))
- (#d :++: (#t2 :++: #tl))
- (#d :++: ((#t2 :++: #t1) :++: #tl)))
- d)
-
-retConcat :: forall env0 sto list. Descr env0 sto -> SList (Ret env0 sto) list -> Rets env0 sto (D1E env0) list
+ RetPair e1 sub
+ (weakenExpr (autoWeak
+ (#d (auto1 @sd)
+ &. #t2 (subList b2 subtape2)
+ &. #t1 (subList b1 subtape1)
+ &. #tl (d2ace (select SAccum descr)))
+ (#d :++: (#t2 :++: #tl))
+ (#d :++: ((#t2 :++: #t1) :++: #tl)))
+ e2)
+
+retConcat :: forall env0 sto list. Descr env0 sto -> SList (SingleRet env0 sto) list -> Rets env0 sto (D1E env0) list
retConcat _ SNil = Rets BTop SETop SNil
-retConcat descr (SCons (Ret (b :: Bindings _ _ shbinds1) (subtape :: Subenv _ tapebinds1) p sub d) list)
+retConcat descr (SCons (SingleRet (e0 :: Bindings _ _ shbinds1) (subtape :: Subenv _ tapebinds1) (RetPair e1 sub e2)) list)
| Rets (binds :: Bindings _ _ shbinds2) (subtape2 :: Subenv _ tapebinds2) pairs
- <- weakenRets (sinkWithBindings b) (retConcat descr list)
+ <- weakenRets (sinkWithBindings e0) (retConcat descr list)
, Refl <- lemAppendAssoc @shbinds2 @shbinds1 @(D1E env0)
, Refl <- lemAppendAssoc @tapebinds2 @tapebinds1 @(D2AcE (Select env0 sto "accum"))
- = Rets (bconcat b binds)
+ = Rets (bconcat e0 binds)
(subenvConcat subtape subtape2)
- (SCons (RetPair (weakenExpr (sinkWithBindings binds) p)
+ (SCons (RetPair (weakenExpr (sinkWithBindings binds) e1)
sub
- (weakenExpr (WCopy (sinkWithSubenv subtape2)) d))
- (slistMap (rebaseRetPair descr (bindingsBinds b) (bindingsBinds binds)
+ (weakenExpr (WCopy (sinkWithSubenv subtape2)) e2))
+ (slistMap (rebaseRetPair descr (bindingsBinds e0) (bindingsBinds binds)
subtape subtape2)
pairs))
freezeRet :: Descr env sto
- -> Ret env sto t
+ -> Ret env sto (D2 t) t
-> Ex (D2 t : Append (D2AcE (Select env sto "accum")) (D1E env)) (TPair (D1 t) (Tup (D2E (Select env sto "merge"))))
-freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ t) =
+freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ _ t) =
let (e0', wInsertD2Ac) = weakenBindings weakenExpr (WSink .> wSinks (d2ace (select SAccum descr))) e0
e2' = weakenExpr (WCopy (wCopies (subList (bindingsBinds e0) subtape) (wRaiseAbove (d2ace (select SAccum descr)) (desD1E descr)))) e2
+ tContribs = tTup (slistMap fromSMTy (subList (d2eM (select SMerge descr)) sub))
+ library = #d (auto1 @(D2 t))
+ &. #tape (subList (bindingsBinds e0) subtape)
+ &. #shbinds (bindingsBinds e0)
+ &. #d2ace (d2ace (select SAccum descr))
+ &. #tl (desD1E descr)
+ &. #contribs (SCons tContribs SNil)
in letBinds e0' $
EPair ext
(weakenExpr wInsertD2Ac e1)
- (ELet ext (weakenExpr (autoWeak (#d (auto1 @(D2 t))
- &. #tape (subList (bindingsBinds e0) subtape)
- &. #shbinds (bindingsBinds e0)
- &. #d2ace (d2ace (select SAccum descr))
- &. #tl (desD1E descr))
+ (ELet ext (weakenExpr (autoWeak library
(#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: #d2ace :++: #tl)
(#shbinds :++: #d :++: #d2ace :++: #tl))
e2') $
- expandSubenvZeros (select SMerge descr) sub (EVar ext (tTup (d2e (subList (select SMerge descr) sub))) IZ))
+ expandSubenvZeros
+ (autoWeak library #tl (#contribs :++: #shbinds :++: #d :++: #d2ace :++: #tl)
+ .> wUndoSubenv (subenvD1E (selectSub SMerge descr)))
+ (select SMerge descr) sub (EVar ext tContribs IZ))
---------------------------- THE CHAD TRANSFORMATION ---------------------------
-drev :: forall env sto t.
+drev :: forall env sto sd t.
(?config :: CHADConfig)
=> Descr env sto -> VarMap Int (D2AcE (Select env sto "accum"))
- -> Expr ValId env t -> Ret env sto t
-drev des accumMap = \case
+ -> Sparse (D2 t) sd
+ -> Expr ValId env t -> Ret env sto sd t
+drev des _ sd | isAbsent sd =
+ \e ->
+ Ret BTop
+ SETop
+ (drevPrimal des e)
+ (subenvNone (d2e (select SMerge des)))
+ (ENil ext)
+drev _ _ SpAbsent = error "Absent should be isAbsent"
+
+drev des accumMap (SpSparse sd) =
+ \e ->
+ case drev des accumMap sd e of { Ret e0 subtape e1 sub e2 ->
+ subenvPlus ST ST (d2eM (select SMerge des)) sub (subenvNone (d2e (select SMerge des))) $ \sub' (Inj inj1) (Inj inj2) _ ->
+ Ret e0
+ subtape
+ e1
+ sub'
+ (emaybe (EVar ext (STMaybe (applySparse sd (d2 (typeOf e)))) IZ)
+ (inj2 (ENil ext))
+ (inj1 (weakenExpr (WCopy WSink) e2)))
+ }
+
+drev des accumMap sd = \case
EVar _ t i ->
case conv2Idx des i of
Idx2Ac accI ->
Ret BTop
SETop
(EVar ext (d1 t) (conv1Idx i))
- (subenvNone (select SMerge des))
- (EAccum ext (d2M t) SAPHere (ENil ext) (EVar ext (d2 t) IZ) (EVar ext (STAccum (d2M t)) (IS accI)))
+ (subenvNone (d2e (select SMerge des)))
+ (let ty = applySparse sd (d2M t)
+ in EAccum ext (d2M t) SAPHere (ENil ext) sd (EVar ext (fromSMTy ty) IZ) (EVar ext (STAccum (d2M t)) (IS accI)))
Idx2Me tupI ->
Ret BTop
SETop
(EVar ext (d1 t) (conv1Idx i))
- (subenvOnehot (select SMerge des) tupI)
- (EPair ext (ENil ext) (EVar ext (d2 t) IZ))
+ (subenvOnehot (d2e (select SMerge des)) tupI sd)
+ (EPair ext (ENil ext) (EVar ext (applySparse sd (d2 t)) IZ))
Idx2Di _ ->
Ret BTop
SETop
(EVar ext (d1 t) (conv1Idx i))
- (subenvNone (select SMerge des))
+ (subenvNone (d2e (select SMerge des)))
(ENil ext)
ELet _ (rhs :: Expr _ _ a) body
- | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 <- drev des accumMap rhs
- , ChosenStorage storage <- if chcLetArrayAccum ?config && hasArrays (typeOf rhs) then ChosenStorage SAccum else ChosenStorage SMerge
- , RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody body2 <- drevScoped des accumMap (typeOf rhs) storage (Just (extOf rhs)) body
+ | ChosenStorage (storage :: Storage s) <- if chcLetArrayAccum ?config && hasArrays (typeOf rhs) then ChosenStorage SAccum else ChosenStorage SMerge
+ , RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody sdBody body2 <- drevScoped des accumMap (typeOf rhs) storage (Just (extOf rhs)) sd body
+ , Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) rhs1 subRHS rhs2 <- drev des accumMap sdBody rhs
, let (body0', wbody0') = weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0
- , Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D1E env)
- , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) ->
- subenvPlus (select SMerge des) subRHS subBody $ \subBoth _ _ plus_RHS_Body ->
- let bodyResType = STPair (tTup (d2e (subList (select SMerge des) subBody))) (d2 (typeOf rhs)) in
+ , Refl <- lemAppendAssoc @body_shbinds @'[D1 a] @rhs_shbinds
+ , Refl <- lemAppendAssoc @body_shbinds @(D1 a : rhs_shbinds) @(D1E env)
+ , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum"))
+ ->
+ subenvPlus SF SF (d2eM (select SMerge des)) subRHS subBody $ \subBoth _ _ plus_RHS_Body ->
+ let bodyResType = STPair (contribTupTy des subBody) (applySparse sdBody (d2 (typeOf rhs))) in
Ret (bconcat (rhs0 `BPush` (d1 (typeOf rhs), rhs1)) body0')
- (subenvConcat (SENo @d1_a subtapeRHS) subtapeBody)
+ (subenvConcat subtapeRHS subtapeBody)
(weakenExpr wbody0' body1)
subBoth
- (ELet ext (weakenExpr (autoWeak (#d (auto1 @(D2 t))
- &. #body (subList (bindingsBinds body0) subtapeBody)
+ (ELet ext (weakenExpr (autoWeak (#d (auto1 @sd)
+ &. #body (subList (bindingsBinds body0 `sappend` SCons (d1 (typeOf rhs)) SNil) subtapeBody)
&. #rhs (subList (bindingsBinds rhs0) subtapeRHS)
&. #tl (d2ace (select SAccum des)))
(#d :++: #body :++: #tl)
@@ -637,204 +807,225 @@ drev des accumMap = \case
(ELet ext (ESnd ext (EVar ext bodyResType IZ)) $
weakenExpr (WCopy (wSinks' @[_,_] .> sinkWithSubenv subtapeBody)) rhs2) $
plus_RHS_Body
- (EVar ext (tTup (d2e (subList (select SMerge des) subRHS))) IZ)
+ (EVar ext (contribTupTy des subRHS) IZ)
(EFst ext (EVar ext bodyResType (IS IZ))))
EPair _ a b
- | Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil)
- <- retConcat des $ drev des accumMap a `SCons` drev des accumMap b `SCons` SNil
- , let dt = STPair (d2 (typeOf a)) (d2 (typeOf b)) ->
- subenvPlus (select SMerge des) subA subB $ \subBoth _ _ plus_A_B ->
+ | SpPair sd1 sd2 <- sd
+ , Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil)
+ <- retConcat des $ toSingleRet (drev des accumMap sd1 a) `SCons` toSingleRet (drev des accumMap sd2 b) `SCons` SNil
+ , let dt = STPair (applySparse sd1 (d2 (typeOf a))) (applySparse sd2 (d2 (typeOf b))) ->
+ subenvPlus SF SF (d2eM (select SMerge des)) subA subB $ \subBoth _ _ plus_A_B ->
Ret binds
subtape
(EPair ext a1 b1)
subBoth
- (EMaybe ext
- (zeroTup (subList (select SMerge des) subBoth))
- (ELet ext (ELet ext (EFst ext (EVar ext dt IZ))
- (weakenExpr (WCopy (wSinks' @[_,_])) a2)) $
- ELet ext (ELet ext (ESnd ext (EVar ext dt (IS IZ)))
- (weakenExpr (WCopy (wSinks' @[_,_,_])) b2)) $
- plus_A_B
- (EVar ext (tTup (d2e (subList (select SMerge des) subA))) (IS IZ))
- (EVar ext (tTup (d2e (subList (select SMerge des) subB))) IZ))
- (EVar ext (STMaybe (STPair (d2 (typeOf a)) (d2 (typeOf b)))) IZ))
+ (ELet ext (ELet ext (EFst ext (EVar ext dt IZ))
+ (weakenExpr (WCopy WSink) a2)) $
+ ELet ext (ELet ext (ESnd ext (EVar ext dt (IS IZ)))
+ (weakenExpr (WCopy (WSink .> WSink)) b2)) $
+ plus_A_B
+ (EVar ext (contribTupTy des subA) (IS IZ))
+ (EVar ext (contribTupTy des subB) IZ))
EFst _ e
- | Ret e0 subtape e1 sub e2 <- drev des accumMap e
- , STPair t1 t2 <- typeOf e ->
+ | Ret e0 subtape e1 sub e2 <- drev des accumMap (SpPair sd SpAbsent) e
+ , STPair t1 _ <- typeOf e ->
Ret e0
subtape
(EFst ext e1)
sub
- (ELet ext (EJust ext (EPair ext (EVar ext (d2 t1) IZ) (ezeroD2 t2))) $
+ (ELet ext (EPair ext (EVar ext (applySparse sd (d2 t1)) IZ) (ENil ext)) $
weakenExpr (WCopy WSink) e2)
ESnd _ e
- | Ret e0 subtape e1 sub e2 <- drev des accumMap e
- , STPair t1 t2 <- typeOf e ->
+ | Ret e0 subtape e1 sub e2 <- drev des accumMap (SpPair SpAbsent sd) e
+ , STPair _ t2 <- typeOf e ->
Ret e0
subtape
(ESnd ext e1)
sub
- (ELet ext (EJust ext (EPair ext (ezeroD2 t1) (EVar ext (d2 t2) IZ))) $
+ (ELet ext (EPair ext (ENil ext) (EVar ext (applySparse sd (d2 t2)) IZ)) $
weakenExpr (WCopy WSink) e2)
- ENil _ -> Ret BTop SETop (ENil ext) (subenvNone (select SMerge des)) (ENil ext)
+ -- Don't need to handle ENil, because its cotangent is always absent!
+ -- ENil _ -> Ret BTop SETop (ENil ext) (subenvNone (d2e (select SMerge des))) (ENil ext)
EInl _ t2 e
- | Ret e0 subtape e1 sub e2 <- drev des accumMap e ->
+ | SpLEither sd1 sd2 <- sd
+ , Ret e0 subtape e1 sub e2 <- drev des accumMap sd1 e ->
+ subenvPlus ST ST (d2eM (select SMerge des)) sub (subenvNone (d2e (select SMerge des))) $ \sub' (Inj inj1) (Inj inj2) _ ->
Ret e0
subtape
(EInl ext (d1 t2) e1)
- sub
+ sub'
(ELCase ext
- (EVar ext (STLEither (d2 (typeOf e)) (d2 t2)) IZ)
- (zeroTup (subList (select SMerge des) sub))
- (weakenExpr (WCopy WSink) e2)
- (EError ext (tTup (d2e (subList (select SMerge des) sub))) "inl<-dinr"))
+ (EVar ext (STLEither (applySparse sd1 (d2 (typeOf e))) (applySparse sd2 (d2 t2))) IZ)
+ (inj2 $ ENil ext)
+ (inj1 $ weakenExpr (WCopy WSink) e2)
+ (EError ext (contribTupTy des sub') "inl<-dinr"))
EInr _ t1 e
- | Ret e0 subtape e1 sub e2 <- drev des accumMap e ->
+ | SpLEither sd1 sd2 <- sd
+ , Ret e0 subtape e1 sub e2 <- drev des accumMap sd2 e ->
+ subenvPlus ST ST (d2eM (select SMerge des)) sub (subenvNone (d2e (select SMerge des))) $ \sub' (Inj inj1) (Inj inj2) _ ->
Ret e0
subtape
(EInr ext (d1 t1) e1)
- sub
+ sub'
(ELCase ext
- (EVar ext (STLEither (d2 t1) (d2 (typeOf e))) IZ)
- (zeroTup (subList (select SMerge des) sub))
- (EError ext (tTup (d2e (subList (select SMerge des) sub))) "inr<-dinl")
- (weakenExpr (WCopy WSink) e2))
+ (EVar ext (STLEither (applySparse sd1 (d2 t1)) (applySparse sd2 (d2 (typeOf e)))) IZ)
+ (inj2 $ ENil ext)
+ (EError ext (contribTupTy des sub') "inr<-dinl")
+ (inj1 $ weakenExpr (WCopy WSink) e2))
ECase _ e (a :: Expr _ _ t) b
- | STEither t1 t2 <- typeOf e
- , Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 subE e2 <- drev des accumMap e
+ | STEither (t1 :: STy a) (t2 :: STy b) <- typeOf e
, ChosenStorage storage1 <- if chcCaseArrayAccum ?config && hasArrays t1 then ChosenStorage SAccum else ChosenStorage SMerge
, ChosenStorage storage2 <- if chcCaseArrayAccum ?config && hasArrays t2 then ChosenStorage SAccum else ChosenStorage SMerge
, let (bindids1, bindids2) = validSplitEither (extOf e)
- , RetScoped (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA a2 <- drevScoped des accumMap t1 storage1 bindids1 a
- , RetScoped (b0 :: Bindings _ _ rhs_b_binds) (subtapeB :: Subenv _ rhs_b_tape) b1 subB b2 <- drevScoped des accumMap t2 storage2 bindids2 b
+ , RetScoped (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA sd1 a2
+ <- drevScoped des accumMap t1 storage1 bindids1 sd a
+ , RetScoped (b0 :: Bindings _ _ rhs_b_binds) (subtapeB :: Subenv _ rhs_b_tape) b1 subB sd2 b2
+ <- drevScoped des accumMap t2 storage2 bindids2 sd b
+ , Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 subE e2 <- drev des accumMap (SpLEither sd1 sd2) e
, Refl <- lemAppendAssoc @(Append rhs_a_binds (Reverse (TapeUnfoldings rhs_a_binds))) @(Tape rhs_a_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum"))
, Refl <- lemAppendAssoc @(Append rhs_b_binds (Reverse (TapeUnfoldings rhs_b_binds))) @(Tape rhs_b_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum"))
- , let tapeA = tapeTy (subList (bindingsBinds a0) subtapeA)
- , let tapeB = tapeTy (subList (bindingsBinds b0) subtapeB)
- , let collectA = bindingsCollectTape a0 subtapeA
- , let collectB = bindingsCollectTape b0 subtapeB
+ , let subtapeListA = subList (sappend (bindingsBinds a0) (d1 t1 `SCons` SNil)) subtapeA
+ , let subtapeListB = subList (sappend (bindingsBinds b0) (d1 t2 `SCons` SNil)) subtapeB
+ , let tapeA = tapeTy subtapeListA
+ , let tapeB = tapeTy subtapeListB
+ , let collectA = bindingsCollectTape @_ @_ @(Append rhs_a_binds (D1 a : Append e_binds (D1E env)))
+ (sappend (bindingsBinds a0) (d1 t1 `SCons` SNil)) subtapeA
+ , let collectB = bindingsCollectTape @_ @_ @(Append rhs_b_binds (D1 b : Append e_binds (D1E env)))
+ (sappend (bindingsBinds b0) (d1 t2 `SCons` SNil)) subtapeB
, (tPrimal :: STy t_primal_ty) <- STPair (d1 (typeOf a)) (STEither tapeA tapeB)
, let (a0', wa0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) a0
, let (b0', wb0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) b0
+ , Refl <- lemAppendNil @(Append rhs_a_binds '[D1 a])
+ , Refl <- lemAppendNil @(Append rhs_b_binds '[D1 b])
+ , Refl <- lemAppendAssoc @rhs_a_binds @'[D1 a] @(D1E env)
+ , Refl <- lemAppendAssoc @rhs_b_binds @'[D1 b] @(D1E env)
+ , let wa0'' = wa0' .> wCopies (sappend (bindingsBinds a0) (d1 t1 `SCons` SNil)) (WClosed @(D1E env))
+ , let wb0'' = wb0' .> wCopies (sappend (bindingsBinds b0) (d1 t2 `SCons` SNil)) (WClosed @(D1E env))
->
- subenvPlus (select SMerge des) subA subB $ \subAB sAB_A sAB_B _ ->
- subenvPlus (select SMerge des) subAB subE $ \subOut _ _ plus_AB_E ->
- let tCaseRet = STPair (tTup (d2e (subList (select SMerge des) subAB))) (STLEither (d2 t1) (d2 t2)) in
+ subenvPlus ST ST (d2eM (select SMerge des)) subA subB $ \subAB (Inj sAB_A) (Inj sAB_B) _ ->
+ subenvPlus SF SF (d2eM (select SMerge des)) subAB subE $ \subOut _ _ plus_AB_E ->
Ret (e0 `BPush`
(tPrimal,
ECase ext e1
- (letBinds a0' (EPair ext (weakenExpr wa0' a1) (EInl ext tapeB (collectA wa0'))))
- (letBinds b0' (EPair ext (weakenExpr wb0' b1) (EInr ext tapeA (collectB wb0'))))))
- (SEYes subtapeE)
+ (letBinds a0' (EPair ext (weakenExpr wa0' a1) (EInl ext tapeB (collectA wa0''))))
+ (letBinds b0' (EPair ext (weakenExpr wb0' b1) (EInr ext tapeA (collectB wb0''))))))
+ (SEYesR subtapeE)
(EFst ext (EVar ext tPrimal IZ))
subOut
- (ELet ext
+ (elet
(ECase ext (ESnd ext (EVar ext tPrimal (IS IZ)))
- (let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds a0) subtapeA) IZ
+ (let (rebinds, prerebinds) = reconstructBindings subtapeListA IZ
in letBinds rebinds $
ELet ext
- (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_a_tape : D2 t : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend (subList (bindingsBinds a0) subtapeA) prerebinds) @> IS IZ)) $
- ELet ext
- (weakenExpr (autoWeak (#d (auto1 @(D2 t))
- &. #ta0 (subList (bindingsBinds a0) subtapeA)
+ (EVar ext (applySparse sd (d2 (typeOf a))) (wSinks @(Tape rhs_a_tape : sd : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend subtapeListA prerebinds) @> IS IZ)) $
+ elet
+ (weakenExpr (autoWeak (#d (auto1 @sd)
+ &. #ta0 subtapeListA
&. #prea0 prerebinds
- &. #recon (tapeA `SCons` d2 (typeOf a) `SCons` SNil)
+ &. #recon (tapeA `SCons` applySparse sd (d2 (typeOf a)) `SCons` SNil)
&. #binds (tPrimal `SCons` subList (bindingsBinds e0) subtapeE)
&. #tl (d2ace (select SAccum des)))
(#d :++: #ta0 :++: #tl)
(#d :++: (#ta0 :++: #prea0) :++: #recon :++: #binds :++: #tl))
a2) $
- EPair ext
- (expandSubenvZeros (subList (select SMerge des) subAB) sAB_A $
- EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA))) (d2 t1)) IZ))
- (ELInl ext (d2 t2)
- (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA))) (d2 t1)) IZ))))
- (let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds b0) subtapeB) IZ
+ EPair ext (sAB_A $ EFst ext (evar IZ))
+ (ELInl ext (applySparse sd2 (d2 t2)) (ESnd ext (evar IZ))))
+ (let (rebinds, prerebinds) = reconstructBindings subtapeListB IZ
in letBinds rebinds $
ELet ext
- (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_b_tape : D2 t : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend (subList (bindingsBinds b0) subtapeB) prerebinds) @> IS IZ)) $
- ELet ext
- (weakenExpr (autoWeak (#d (auto1 @(D2 t))
- &. #tb0 (subList (bindingsBinds b0) subtapeB)
+ (EVar ext (applySparse sd (d2 (typeOf a))) (wSinks @(Tape rhs_b_tape : sd : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend subtapeListB prerebinds) @> IS IZ)) $
+ elet
+ (weakenExpr (autoWeak (#d (auto1 @sd)
+ &. #tb0 subtapeListB
&. #preb0 prerebinds
- &. #recon (tapeB `SCons` d2 (typeOf a) `SCons` SNil)
+ &. #recon (tapeB `SCons` applySparse sd (d2 (typeOf a)) `SCons` SNil)
&. #binds (tPrimal `SCons` subList (bindingsBinds e0) subtapeE)
&. #tl (d2ace (select SAccum des)))
(#d :++: #tb0 :++: #tl)
(#d :++: (#tb0 :++: #preb0) :++: #recon :++: #binds :++: #tl))
b2) $
- EPair ext
- (expandSubenvZeros (subList (select SMerge des) subAB) sAB_B $
- EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB))) (d2 t2)) IZ))
- (ELInr ext (d2 t1)
- (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB))) (d2 t2)) IZ))))) $
- ELet ext
- (ELet ext (ESnd ext (EVar ext tCaseRet IZ)) $
- weakenExpr (WCopy (wSinks' @[_,_,_])) e2) $
+ EPair ext (sAB_B $ EFst ext (evar IZ))
+ (ELInr ext (applySparse sd1 (d2 t1)) (ESnd ext (evar IZ))))) $
plus_AB_E
- (EFst ext (EVar ext tCaseRet (IS IZ)))
- (EVar ext (tTup (d2e (subList (select SMerge des) subE))) IZ))
+ (EFst ext (evar IZ))
+ (ELet ext (ESnd ext (evar IZ)) $
+ weakenExpr (WCopy (wSinks' @[_,_,_])) e2))
EConst _ t val ->
Ret BTop
SETop
(EConst ext t val)
- (subenvNone (select SMerge des))
+ (subenvNone (d2e (select SMerge des)))
(ENil ext)
EOp _ op e
- | Ret e0 subtape e1 sub e2 <- drev des accumMap e ->
+ | Ret e0 subtape e1 sub e2 <- drev des accumMap (spDense (d2M (opt1 op))) e ->
case d2op op of
Linear d2opfun ->
Ret e0
subtape
(d1op op e1)
sub
- (ELet ext (d2opfun (EVar ext (d2 (opt2 op)) IZ))
+ (ELet ext (d2opfun (opt2UnSparse op sd (EVar ext (applySparse sd (d2 (opt2 op))) IZ)))
(weakenExpr (WCopy WSink) e2))
Nonlinear d2opfun ->
Ret (e0 `BPush` (d1 (typeOf e), e1))
- (SEYes subtape)
+ (SEYesR subtape)
(d1op op $ EVar ext (d1 (typeOf e)) IZ)
sub
(ELet ext (d2opfun (EVar ext (d1 (typeOf e)) (IS IZ))
- (EVar ext (d2 (opt2 op)) IZ))
+ (opt2UnSparse op sd (EVar ext (applySparse sd (d2 (opt2 op))) IZ)))
(weakenExpr (WCopy (wSinks' @[_,_])) e2))
- ECustom _ _ _ storety _ pr du a b
+ ECustom _ _ tb storety srce pr du a b
-- allowed to ignore a2 because 'a' is the part of the input that is inactive
- | Rets binds subtape (RetPair a1 _ _ `SCons` RetPair b1 bsub b2 `SCons` SNil)
- <- retConcat des $ drev des accumMap a `SCons` drev des accumMap b `SCons` SNil ->
- Ret (binds `BPush` (typeOf a1, a1)
- `BPush` (typeOf b1, weakenExpr WSink b1)
- `BPush` (typeOf pr, weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr))
- `BPush` (storety, ESnd ext (EVar ext (typeOf pr) IZ)))
- (SEYes (SENo (SENo (SENo subtape))))
- (EFst ext (EVar ext (typeOf pr) (IS IZ)))
- bsub
- (ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $
- weakenExpr (WCopy (WSink .> WSink)) b2)
-
- -- TODO: compute primal in direct form here instead of taking the redundantly inefficient CHAD primal
+ | Ret b0 bsubtape b1 bsub b2 <- drev des accumMap (spDense (d2M tb)) b ->
+ case isDense (d2M (typeOf srce)) sd of
+ Just Refl ->
+ Ret (b0 `BPush` (d1 (typeOf a), weakenExpr (sinkWithBindings b0) (drevPrimal des a))
+ `BPush` (typeOf b1, weakenExpr WSink b1)
+ `BPush` (typeOf pr, weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr))
+ `BPush` (storety, ESnd ext (EVar ext (typeOf pr) IZ)))
+ (SEYesR (SENo (SENo (SENo bsubtape))))
+ (EFst ext (EVar ext (typeOf pr) (IS IZ)))
+ bsub
+ (ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $
+ weakenExpr (WCopy (WSink .> WSink)) b2)
+
+ Nothing ->
+ Ret (b0 `BPush` (d1 (typeOf a), weakenExpr (sinkWithBindings b0) (drevPrimal des a))
+ `BPush` (typeOf b1, weakenExpr WSink b1)
+ `BPush` (typeOf pr, weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr)))
+ (SEYesR (SENo (SENo bsubtape)))
+ (EFst ext (EVar ext (typeOf pr) IZ))
+ bsub
+ (ELet ext (ESnd ext (EVar ext (typeOf pr) (IS IZ))) $ -- tape
+ ELet ext (expandSparse (typeOf srce) sd -- expanded incoming cotangent
+ (EFst ext (EVar ext (typeOf pr) (IS (IS IZ))))
+ (EVar ext (applySparse sd (d2 (typeOf srce))) (IS IZ))) $
+ ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $
+ weakenExpr (WCopy (WSink .> WSink .> WSink .> WSink)) b2)
+
ERecompute _ e ->
deleteUnused (descrList des) (occCountAll e) $ \usedSub ->
let smallE = unsafeWeakenWithSubenv usedSub e in
subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed ->
- case drev usedDes (VarMap.subMap subAccumUsed accumMap) smallE of { Ret e0 subtape e1 sub e2 ->
+ case drev usedDes (VarMap.subMap subAccumUsed accumMap) sd smallE of { Ret e0 subtape _ sub e2 ->
+ let subMergeUsed' = subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E subMergeUsed) in
Ret (collectBindings (desD1E des) subD1eUsed)
(subenvAll (desD1E usedDes))
- (weakenExpr (wRaiseAbove (desD1E usedDes) (desD1E des)) $ letBinds e0 e1)
- (subenvCompose subMergeUsed sub)
+ (weakenExpr (wSinks (desD1E usedDes)) $ drevPrimal des e)
+ (subenvCompose subMergeUsed' sub)
(letBinds (fst (weakenBindings weakenExpr (WSink .> wRaiseAbove (desD1E usedDes) (d2ace (select SAccum des))) e0)) $
weakenExpr
- (autoWeak (#d (auto1 @(D2 t))
+ (autoWeak (#d (auto1 @sd)
&. #shbinds (bindingsBinds e0)
&. #tape (subList (bindingsBinds e0) subtape)
&. #d1env (desD1E usedDes)
@@ -849,128 +1040,130 @@ drev des accumMap = \case
Ret BTop
SETop
(EError ext (d1 t) s)
- (subenvNone (select SMerge des))
+ (subenvNone (d2e (select SMerge des)))
(ENil ext)
EConstArr _ n t val ->
Ret BTop
SETop
(EConstArr ext n t val)
- (subenvNone (select SMerge des))
+ (subenvNone (d2e (select SMerge des)))
(ENil ext)
EBuild _ (ndim :: SNat ndim) she (orige :: Expr _ _ eltty)
- | Ret (she0 :: Bindings _ _ she_binds) _ she1 _ _ <- drev des accumMap she -- allowed to ignore she2 here because she has a discrete result
+ | SpArr @_ @sdElt sdElt <- sd
, let eltty = typeOf orige
, shty :: STy shty <- tTup (sreplicate ndim tIx)
, Refl <- indexTupD1Id ndim ->
deleteUnused (descrList des) (occEnvPop (occCountAll orige)) $ \(usedSub :: Subenv env env') ->
- let e = unsafeWeakenWithSubenv (SEYes usedSub) orige in
- subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed ->
- accumPromote eltty usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro ->
+ let e = unsafeWeakenWithSubenv (SEYesR usedSub) orige in
+ subDescr des usedSub $ \(usedDes :: Descr env' _) subMergeUsed subAccumUsed subD1eUsed ->
+ accumPromote sdElt usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro ->
let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in
- case drev (prodes `DPush` (shty, Nothing, SDiscr)) accumMapPro e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 sub e2 ->
+ case drev (prodes `DPush` (shty, Nothing, SDiscr)) accumMapPro sdElt e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 sub e2 ->
case assertSubenvEmpty sub of { Refl ->
+ case lemAppendNil @e_binds of { Refl ->
let tapety = tapeTy (subList (bindingsBinds e0) subtapeE) in
- let collectexpr = bindingsCollectTape e0 subtapeE in
- Ret (BTop `BPush` (shty, letBinds she0 she1)
- `BPush` (STArr ndim (STPair (d1 eltty) tapety)
- ,EBuild ext ndim
- (EVar ext shty IZ)
- (letBinds (fst (weakenBindings weakenExpr (autoWeak (#ix (shty `SCons` SNil)
- &. #sh (shty `SCons` SNil)
- &. #d1env (desD1E des)
- &. #d1env' (desD1E usedDes))
- (#ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))
- (#ix :++: #sh :++: #d1env))
- e0)) $
- let w = autoWeak (#ix (shty `SCons` SNil)
- &. #sh (shty `SCons` SNil)
- &. #e0 (bindingsBinds e0)
- &. #d1env (desD1E des)
- &. #d1env' (desD1E usedDes))
- (#e0 :++: #ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))
- (#e0 :++: #ix :++: #sh :++: #d1env)
- in EPair ext (weakenExpr w e1) (collectexpr w)))
- `BPush` (STArr ndim tapety, emap (ESnd ext (EVar ext (STPair (d1 eltty) tapety) IZ))
- (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) IZ)))
- (SEYes (SENo (SEYes SETop)))
- (emap (EFst ext (EVar ext (STPair (d1 eltty) tapety) IZ))
- (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) (IS IZ)))
- (subenvCompose subMergeUsed proSub)
- (let sinkOverEnvPro = wSinks @(TArr ndim (D2 eltty) : D2 t : TArr ndim (Tape e_tape) : Tup (Replicate ndim TIx) : D2AcE (Select env sto "accum")) (d2ace envPro) in
- EMaybe ext
- (zeroTup envPro)
- (ESnd ext $
- uninvertTup (d2e envPro) (STArr ndim STNil) $
- makeAccumulators @_ @_ @(TArr ndim TNil) envPro $
- EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS (IS IZ)))) $
- -- the cotangent for this element
- ELet ext (EIdx ext (EVar ext (STArr ndim (d2 eltty)) (WSink .> sinkOverEnvPro @> IZ))
- (EVar ext shty IZ)) $
- -- the tape for this element
- ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS (IS IZ)))
- (EVar ext shty (IS IZ))) $
- let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds e0) subtapeE) IZ
- in letBinds rebinds $
- weakenExpr (autoWeak (#d (auto1 @(D2 eltty))
- &. #pro (d2ace envPro)
- &. #etape (subList (bindingsBinds e0) subtapeE)
- &. #prerebinds prerebinds
- &. #tape (auto1 @(Tape e_tape))
- &. #ix (auto1 @shty)
- &. #darr (auto1 @(TArr ndim (D2 eltty)))
- &. #mdarr (auto1 @(TMaybe (TArr ndim (D2 eltty))))
- &. #tapearr (auto1 @(TArr ndim (Tape e_tape)))
- &. #sh (auto1 @shty)
- &. #d2acUsed (d2ace (select SAccum usedDes))
- &. #d2acEnv (d2ace (select SAccum des)))
- (#pro :++: #d :++: #etape :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed))
- ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #mdarr :++: #tapearr :++: #sh :++: #d2acEnv)
- .> wPro (subList (bindingsBinds e0) subtapeE))
- e2)
- (EVar ext (d2 (STArr ndim eltty)) IZ))
- }}
+ let collectexpr = bindingsCollectTape (bindingsBinds e0) subtapeE in
+ let mergePrimalSub = subenvD1E (selectSub SMerge des `subenvCompose` subMergeUsed `subenvCompose` proSub) in
+ let mergePrimalBindings = collectBindings (d1e (descrList des)) mergePrimalSub in
+ Ret (mergePrimalBindings
+ `BPush` (shty, weakenExpr (wSinks (d1e envPro)) (drevPrimal des she))
+ `BPush` (STArr ndim (STPair (d1 eltty) tapety)
+ ,EBuild ext ndim
+ (EVar ext shty IZ)
+ (letBinds (fst (weakenBindings weakenExpr (autoWeak (#ix (shty `SCons` SNil)
+ &. #sh (shty `SCons` SNil)
+ &. #propr (d1e envPro)
+ &. #d1env (desD1E des)
+ &. #d1env' (desD1E usedDes))
+ (#ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))
+ (#ix :++: #sh :++: #propr :++: #d1env))
+ e0)) $
+ let w = autoWeak (#ix (shty `SCons` SNil)
+ &. #sh (shty `SCons` SNil)
+ &. #e0 (bindingsBinds e0)
+ &. #propr (d1e envPro)
+ &. #d1env (desD1E des)
+ &. #d1env' (desD1E usedDes))
+ (#e0 :++: #ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))
+ (#e0 :++: #ix :++: #sh :++: #propr :++: #d1env)
+ w' = w .> wCopies (bindingsBinds e0) (WClosed @(shty : D1E env'))
+ in EPair ext (weakenExpr w e1) (collectexpr w')))
+ `BPush` (STArr ndim tapety, emap (ESnd ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) IZ)))
+ (SEYesR (SENo (SEYesR (subenvAll (d1e envPro)))))
+ (emap (EFst ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) (IS IZ)))
+ (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E (subenvCompose subMergeUsed proSub)))
+ (let sinkOverEnvPro = wSinks @(sd : TArr ndim (Tape e_tape) : Tup (Replicate ndim TIx) : Append (D1E envPro) (D2AcE (Select env sto "accum"))) (d2ace envPro) in
+ ESnd ext $
+ uninvertTup (d2e envPro) (STArr ndim STNil) $
+ makeAccumulators @_ @_ @(TArr ndim TNil) (WSink .> WSink .> WSink .> wRaiseAbove (d1e envPro) (d2ace (select SAccum des))) envPro $
+ EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS IZ))) $
+ -- the cotangent for this element
+ ELet ext (EIdx ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (WSink .> sinkOverEnvPro @> IZ))
+ (EVar ext shty IZ)) $
+ -- the tape for this element
+ ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS IZ))
+ (EVar ext shty (IS IZ))) $
+ let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds e0) subtapeE) IZ
+ in letBinds rebinds $
+ weakenExpr (autoWeak (#d (auto1 @sdElt)
+ &. #pro (d2ace envPro)
+ &. #etape (subList (bindingsBinds e0) subtapeE)
+ &. #prerebinds prerebinds
+ &. #tape (auto1 @(Tape e_tape))
+ &. #ix (auto1 @shty)
+ &. #darr (auto1 @(TArr ndim sdElt))
+ &. #tapearr (auto1 @(TArr ndim (Tape e_tape)))
+ &. #sh (auto1 @shty)
+ &. #propr (d1e envPro)
+ &. #d2acUsed (d2ace (select SAccum usedDes))
+ &. #d2acEnv (d2ace (select SAccum des)))
+ (#pro :++: #d :++: #etape :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed))
+ ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #sh :++: #propr :++: #d2acEnv)
+ .> wPro (subList (bindingsBinds e0) subtapeE))
+ e2)
+ }}}
EUnit _ e
- | Ret e0 subtape e1 sub e2 <- drev des accumMap e ->
+ | SpArr sdElt <- sd
+ , Ret e0 subtape e1 sub e2 <- drev des accumMap sdElt e ->
Ret e0
subtape
(EUnit ext e1)
sub
- (EMaybe ext
- (zeroTup (subList (select SMerge des) sub))
- (ELet ext (EIdx0 ext (EVar ext (STArr SZ (d2 (typeOf e))) IZ)) $
- weakenExpr (WCopy (WSink .> WSink)) e2)
- (EVar ext (STMaybe (STArr SZ (d2 (typeOf e)))) IZ))
+ (ELet ext (EIdx0 ext (EVar ext (STArr SZ (applySparse sdElt (d2 (typeOf e)))) IZ)) $
+ weakenExpr (WCopy WSink) e2)
EReplicate1Inner _ en e
- -- We're allowed to ignore en2 here because the output of 'ei' is discrete.
- | Rets binds subtape (RetPair en1 _ _ `SCons` RetPair e1 sub e2 `SCons` SNil)
- <- retConcat des $ drev des accumMap en `SCons` drev des accumMap e `SCons` SNil
+ -- We're allowed to differentiate 'en' as primal-only here because its output is discrete.
+ | SpArr sdElt <- sd
, let STArr ndim eltty = typeOf e ->
- Ret binds
- subtape
- (EReplicate1Inner ext en1 e1)
- sub
- (EMaybe ext
- (zeroTup (subList (select SMerge des) sub))
- (ELet ext (EJust ext (EFold1Inner ext Commut
- (EPlus ext (d2M eltty) (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ))
- (ezeroD2 eltty)
- (EVar ext (STArr (SS ndim) (d2 eltty)) IZ))) $
- weakenExpr (WCopy (WSink .> WSink)) e2)
- (EVar ext (d2 (STArr (SS ndim) eltty)) IZ))
+ -- This pessimistic sparsity union is because the array might have been empty, in which case we need to generate a zero.
+ sparsePlusS ST ST (d2M eltty) sdElt SpAbsent $ \sdElt' (Inj inj1) (Inj inj2) _ ->
+ case drev des accumMap (SpArr sdElt') e of { Ret binds subtape e1 sub e2 ->
+ Ret binds
+ subtape
+ (EReplicate1Inner ext (weakenExpr (wSinks (bindingsBinds binds)) (drevPrimal des en)) e1)
+ sub
+ (ELet ext (EFold1Inner ext Commut
+ (sparsePlus (d2M eltty) sdElt'
+ (EVar ext (applySparse sdElt' (d2 eltty)) (IS IZ))
+ (EVar ext (applySparse sdElt' (d2 eltty)) IZ))
+ (inj2 (ENil ext))
+ (emap (inj1 (evar IZ)) $ EVar ext (STArr (SS ndim) (applySparse sdElt (d2 eltty))) IZ)) $
+ weakenExpr (WCopy WSink) e2)
+ }
EIdx0 _ e
- | Ret e0 subtape e1 sub e2 <- drev des accumMap e
+ | Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd) e
, STArr _ t <- typeOf e ->
Ret e0
subtape
(EIdx0 ext e1)
sub
- (ELet ext (EJust ext (EUnit ext (EVar ext (d2 t) IZ))) $
- weakenExpr (WCopy WSink) e2)
+ (ELet ext (EUnit ext (EVar ext (applySparse sd (d2 t)) IZ)) $
+ weakenExpr (WCopy WSink) e2)
EIdx1{} -> error "CHAD of EIdx1: Please use EIdx instead"
{-
@@ -981,7 +1174,7 @@ drev des accumMap = \case
, STArr (SS n) eltty <- typeOf e ->
Ret (binds `BPush` (STArr (SS n) (d1 eltty), e1)
`BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) (d1 eltty)) IZ)))
- (SEYes (SENo subtape))
+ (SEYesR (SENo subtape))
(EIdx1 ext (EVar ext (STArr (SS n) (d1 eltty)) (IS IZ))
(weakenExpr (WSink .> WSink) ei1))
sub
@@ -992,57 +1185,58 @@ drev des accumMap = \case
-}
EIdx _ e ei
- -- We're allowed to ignore ei2 here because the output of 'ei' is discrete.
- | Rets binds subtape (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil)
- <- retConcat des $ drev des accumMap e `SCons` drev des accumMap ei `SCons` SNil
- , STArr n eltty <- typeOf e
+ -- We're allowed to differentiate ei as primal because its output is discrete.
+ | STArr n eltty <- typeOf e
, Refl <- indexTupD1Id n
- , Refl <- lemZeroInfoD2 eltty
- , let tIxN = tTup (sreplicate n tIx) ->
- Ret (binds `BPush` (STArr n (d1 eltty), e1)
- `BPush` (tIxN, EShape ext (EVar ext (typeOf e1) IZ))
- `BPush` (tIxN, weakenExpr (WSink .> WSink) ei1))
- (SEYes (SEYes (SENo subtape)))
- (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ)))
- (EVar ext (tTup (sreplicate n tIx)) IZ))
- sub
- (ELet ext (EOneHot ext (d2M (STArr n eltty)) (SAPJust (SAPArrIdx SAPHere))
- (EPair ext (EPair ext (EVar ext tIxN (IS IZ))
- (EBuild ext n (EVar ext tIxN (IS (IS IZ))) (ENil ext)))
- (ENil ext))
- (EVar ext (d2 eltty) IZ)) $
- weakenExpr (WCopy (WSink .> WSink .> WSink)) e2)
+ , let tIxN = tTup (sreplicate n tIx) ->
+ sparsePlusS ST ST (d2M eltty) sd SpAbsent $ \sd' (Inj inj1) (Inj inj2) _ ->
+ case drev des accumMap (SpArr sd') e of { Ret binds subtape e1 sub e2 ->
+ Ret (binds `BPush` (STArr n (d1 eltty), e1)
+ `BPush` (tIxN, EShape ext (EVar ext (typeOf e1) IZ))
+ `BPush` (tIxN, weakenExpr (WSink .> WSink .> wSinks (bindingsBinds binds)) (drevPrimal des ei)))
+ (SEYesR (SEYesR (SENo subtape)))
+ (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ)))
+ (EVar ext (tTup (sreplicate n tIx)) IZ))
+ sub
+ (ELet ext
+ (EOneHot ext (SMTArr n (applySparse sd' (d2M eltty)))
+ (SAPArrIdx SAPHere)
+ (EPair ext
+ (EPair ext (EVar ext tIxN (IS IZ))
+ (EBuild ext n (EVar ext tIxN (IS (IS IZ))) $
+ makeZeroInfo (applySparse sd' (d2M eltty)) (inj2 (ENil ext))))
+ (ENil ext))
+ (inj1 $ EVar ext (applySparse sd (d2 eltty)) IZ)) $
+ weakenExpr (WCopy (WSink .> WSink .> WSink)) e2)
+ }
EShape _ e
- -- Allowed to ignore e2 here because the output of EShape is discrete,
- -- hence we'd be passing a zero cotangent to e2 anyway.
- | Ret e0 subtape e1 _ _ <- drev des accumMap e
- , STArr n _ <- typeOf e
+ -- Allowed to differentiate e as primal because the output of EShape is
+ -- discrete, hence we'd be passing a zero cotangent to e anyway.
+ | STArr n _ <- typeOf e
, Refl <- indexTupD1Id n ->
- Ret e0
- subtape
- (EShape ext e1)
- (subenvNone (select SMerge des))
+ Ret BTop
+ SETop
+ (EShape ext (drevPrimal des e))
+ (subenvNone (d2eM (select SMerge des)))
(ENil ext)
ESum1Inner _ e
- | Ret e0 subtape e1 sub e2 <- drev des accumMap e
+ | SpArr sd' <- sd
+ , Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd') e
, STArr (SS n) t <- typeOf e ->
Ret (e0 `BPush` (STArr (SS n) t, e1)
`BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) t) IZ)))
- (SEYes (SENo subtape))
+ (SEYesR (SENo subtape))
(ESum1Inner ext (EVar ext (STArr (SS n) t) (IS IZ)))
sub
- (EMaybe ext
- (zeroTup (subList (select SMerge des) sub))
- (ELet ext (EJust ext (EReplicate1Inner ext
- (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS (IS IZ))))
- (EVar ext (STArr n (d2 t)) IZ))) $
- weakenExpr (WCopy (WSink .> WSink .> WSink)) e2)
- (EVar ext (d2 (STArr n t)) IZ))
+ (ELet ext (EReplicate1Inner ext
+ (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ)))
+ (EVar ext (STArr n (applySparse sd' (d2 t))) IZ)) $
+ weakenExpr (WCopy (WSink .> WSink)) e2)
- EMaximum1Inner _ e -> deriv_extremum (EMaximum1Inner ext) e
- EMinimum1Inner _ e -> deriv_extremum (EMinimum1Inner ext) e
+ EMaximum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMaximum1Inner ext) des accumMap sd' e
+ EMinimum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMinimum1Inner ext) des accumMap sd' e
-- These should be the next to be implemented, I think
EFold1Inner{} -> err_unsupported "EFold1Inner"
@@ -1056,8 +1250,8 @@ drev des accumMap = \case
ELCase{} -> err_unsupported "ELCase"
EWith{} -> err_accum
- EAccum{} -> err_accum
EZero{} -> err_monoid
+ EDeepZero{} -> err_monoid
EPlus{} -> err_monoid
EOneHot{} -> err_monoid
@@ -1066,94 +1260,116 @@ drev des accumMap = \case
err_monoid = error "Monoid operations unsupported in the source program"
err_unsupported s = error $ "CHAD: unsupported " ++ s
- deriv_extremum :: ScalIsNumeric t' ~ True
- => (forall env'. Ex env' (TArr (S n) (TScal t')) -> Ex env' (TArr n (TScal t')))
- -> Expr ValId env (TArr (S n) (TScal t')) -> Ret env sto (TArr n (TScal t'))
- deriv_extremum extremum e
- | Ret e0 subtape e1 sub e2 <- drev des accumMap e
- , at@(STArr (SS n) t@(STScal st)) <- typeOf e
- , let at' = STArr n t
- , let tIxN = tTup (sreplicate (SS n) tIx) =
- Ret (e0 `BPush` (at, e1)
- `BPush` (at', extremum (EVar ext at IZ)))
- (SEYes (SEYes subtape))
- (EVar ext at' IZ)
- sub
- (EMaybe ext
- (zeroTup (subList (select SMerge des) sub))
- (ELet ext (EJust ext
- (EBuild ext (SS n) (EShape ext (EVar ext at (IS (IS (IS IZ))))) $
- eif (EOp ext (OEq st) (EPair ext
- (EIdx ext (EVar ext at (IS (IS (IS (IS IZ))))) (EVar ext tIxN IZ))
- (EIdx ext (EVar ext at' (IS (IS (IS IZ)))) (EFst ext (EVar ext tIxN IZ)))))
- (EIdx ext (EVar ext (STArr n (d2 t)) (IS IZ)) (EFst ext (EVar ext tIxN IZ)))
- (ezeroD2 t))) $
- weakenExpr (WCopy (WSink .> WSink .> WSink .> WSink)) e2)
- (EVar ext (d2 at') IZ))
+ contribTupTy :: Descr env sto -> SubenvS (D2E (Select env sto "merge")) contribs -> STy (Tup contribs)
+ contribTupTy des' sub = tTup (slistMap fromSMTy (subList (d2eM (select SMerge des')) sub))
+
+deriv_extremum :: (?config :: CHADConfig, ScalIsNumeric t ~ True)
+ => (forall env'. Ex env' (TArr (S n) (TScal t)) -> Ex env' (TArr n (TScal t)))
+ -> Descr env sto -> VarMap Int (D2AcE (Select env sto "accum"))
+ -> Sparse (D2s t) sd
+ -> Expr ValId env (TArr (S n) (TScal t)) -> Ret env sto (TArr n sd) (TArr n (TScal t))
+deriv_extremum extremum des accumMap sd e
+ | at@(STArr (SS n) t@(STScal st)) <- typeOf e
+ , let at' = STArr n t
+ , let tIxN = tTup (sreplicate (SS n) tIx) =
+ sparsePlusS ST ST (d2M t) sd SpAbsent $ \sd' (Inj inj1) (Inj inj2) _ ->
+ case drev des accumMap (SpArr sd') e of { Ret e0 subtape e1 sub e2 ->
+ Ret (e0 `BPush` (at, e1)
+ `BPush` (at', extremum (EVar ext at IZ)))
+ (SEYesR (SEYesR subtape))
+ (EVar ext at' IZ)
+ sub
+ (ELet ext
+ (EBuild ext (SS n) (EShape ext (EVar ext at (IS (IS IZ)))) $
+ eif (EOp ext (OEq st) (EPair ext
+ (EIdx ext (EVar ext at (IS (IS (IS IZ)))) (EVar ext tIxN IZ))
+ (EIdx ext (EVar ext at' (IS (IS IZ))) (EFst ext (EVar ext tIxN IZ)))))
+ (inj1 $ EIdx ext (EVar ext (STArr n (applySparse sd (d2 t))) (IS IZ)) (EFst ext (EVar ext tIxN IZ)))
+ (inj2 (ENil ext))) $
+ weakenExpr (WCopy (WSink .> WSink .> WSink)) e2)
+ }
data ChosenStorage = forall s. ((s == "discr") ~ False) => ChosenStorage (Storage s)
-data RetScoped env0 sto a s t =
- forall shbinds tapebinds env0Merge.
+data RetScoped env0 sto a s sd t =
+ forall shbinds tapebinds contribs sa.
RetScoped
(Bindings Ex (D1E (a : env0)) shbinds) -- shared binds
- (Subenv shbinds tapebinds)
+ (Subenv (Append shbinds '[D1 a]) tapebinds)
(Ex (Append shbinds (D1E (a : env0))) (D1 t))
- (Subenv (Select env0 sto "merge") env0Merge)
+ (SubenvS (D2E (Select env0 sto "merge")) contribs)
-- ^ merge contributions to the _enclosing_ merge environment
- (Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum")))
- (If (s == "discr") (Tup (D2E env0Merge))
- (TPair (Tup (D2E env0Merge)) (D2 a))))
+ (Sparse (D2 a) sa)
+ -- ^ contribution to the argument
+ (Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum")))
+ (If (s == "discr") (Tup contribs)
+ (TPair (Tup contribs) sa)))
-- ^ the merge contributions, plus the cotangent to the argument
-- (if there is any)
-deriving instance Show (RetScoped env0 sto a s t)
+deriving instance Show (RetScoped env0 sto a s sd t)
-drevScoped :: forall a s env sto t.
+drevScoped :: forall a s env sto sd t.
(?config :: CHADConfig)
=> Descr env sto -> VarMap Int (D2AcE (Select env sto "accum"))
-> STy a -> Storage s -> Maybe (ValId a)
+ -> Sparse (D2 t) sd
-> Expr ValId (a : env) t
- -> RetScoped env sto a s t
-drevScoped des accumMap argty argsto argids expr = case argsto of
+ -> RetScoped env sto a s sd t
+drevScoped des accumMap argty argsto argids sd expr = case argsto of
SMerge
- | Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap expr ->
+ | Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap sd expr
+ , Refl <- lemAppendNil @tapebinds ->
case sub of
- SEYes sub' -> RetScoped e0 subtape e1 sub' e2
- SENo sub' -> RetScoped e0 subtape e1 sub' (EPair ext e2 (ezeroD2 argty))
+ SEYes sp sub' -> RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub' sp e2
+ SENo sub' -> RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub' SpAbsent (EPair ext e2 (ENil ext))
SAccum
- | Just (VIArr i _) <- argids
+ | chcSmartWith ?config
+ , Just (VIArr i _) <- argids
, Just (Some (VarMap.TypedIdx foundTy idx)) <- VarMap.lookup i accumMap
, Just Refl <- testEquality foundTy (STAccum (d2M argty))
- , Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) (VarMap.sink1 accumMap) expr ->
- RetScoped e0 subtape e1 sub $
+ , Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) (VarMap.sink1 accumMap) sd expr
+ , Refl <- lemAppendNil @tapebinds ->
+ -- Our contribution to the binding's cotangent _here_ is zero (absent),
+ -- because we're contributing to an earlier binding of the same value
+ -- instead.
+ RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub SpAbsent $
let wtapebinds = wSinks (subList (bindingsBinds e0) subtape) in
ELet ext (EVar ext (STAccum (d2M argty)) (WSink .> wtapebinds @> idx)) $
- weakenExpr (autoWeak (#d (auto1 @(D2 t))
+ weakenExpr (autoWeak (#d (auto1 @sd)
&. #body (subList (bindingsBinds e0) subtape)
&. #ac (auto1 @(TAccum (D2 a)))
&. #tl (d2ace (select SAccum des)))
(#d :++: #body :++: #ac :++: #tl)
(#ac :++: #d :++: #body :++: #tl))
- -- Our contribution to the binding's cotangent _here_ is
- -- zero, because we're contributing to an earlier binding
- -- of the same value instead.
- (EPair ext e2 (ezeroD2 argty))
+ (EPair ext e2 (ENil ext))
| let accumMap' = case argids of
Just (VIArr i _) -> VarMap.insert i (STAccum (d2M argty)) IZ (VarMap.sink1 accumMap)
_ -> VarMap.sink1 accumMap
- , Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap' expr ->
- RetScoped e0 subtape e1 sub $
- EWith ext (d2M argty) (ezeroD2 argty) $
- weakenExpr (autoWeak (#d (auto1 @(D2 t))
- &. #body (subList (bindingsBinds e0) subtape)
- &. #ac (auto1 @(TAccum (D2 a)))
- &. #tl (d2ace (select SAccum des)))
+ , Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap' sd expr ->
+ let library = #d (auto1 @sd)
+ &. #p (auto1 @(D1 a))
+ &. #body (subList (bindingsBinds e0) subtape)
+ &. #ac (auto1 @(TAccum (D2 a)))
+ &. #tl (d2ace (select SAccum des))
+ in
+ RetScoped e0 (subenvConcat (SEYesR @_ @_ @(D1 a) SETop) subtape) e1 sub (spDense (d2M argty)) $
+ let primalIdx = autoWeak library #p (#d :++: (#body :++: #p) :++: #tl) @> IZ in
+ EWith ext (d2M argty) (EDeepZero ext (d2M argty) (d2deepZeroInfo argty (EVar ext (d1 argty) primalIdx))) $
+ weakenExpr (autoWeak library
(#d :++: #body :++: #ac :++: #tl)
- (#ac :++: #d :++: #body :++: #tl))
+ (#ac :++: #d :++: (#body :++: #p) :++: #tl))
e2
SDiscr
- | Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap expr ->
- RetScoped e0 subtape e1 sub e2
+ | Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap sd expr
+ , Refl <- lemAppendNil @tapebinds ->
+ RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub SpAbsent e2
+
+-- TODO: proper primal-only transform that doesn't depend on D1 = Id
+drevPrimal :: Descr env sto -> Expr x env t -> Ex (D1E env) (D1 t)
+drevPrimal des e
+ | Refl <- d1Identity (typeOf e)
+ , Refl <- d1eIdentity (descrList des)
+ = mapExt (const ext) e