aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r--src/CHAD.hs183
1 files changed, 100 insertions, 83 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 72ce36d..298d964 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -1043,12 +1043,12 @@ drev des accumMap sd = \case
(subenvNone (d2e (select SMerge des)))
(ENil ext)
- EBuild _ (ndim :: SNat ndim) she (orige :: Expr _ _ eltty)
+ EBuild _ (ndim :: SNat ndim) she (ef :: Expr _ _ eltty)
| SpArr @_ @sdElt sdElt <- sd
- , let eltty = typeOf orige
+ , let eltty = typeOf ef
, shty :: STy shty <- tTup (sreplicate ndim tIx)
, Refl <- indexTupD1Id ndim ->
- drevLambda des accumMap (shty, SDiscr) sdElt orige $ \(provars :: SList _ envPro) esub proPrimalBinds e0 e1 (e1tape :: Ex _ e_tape) _ wrapAccum e2 ->
+ drevLambda des accumMap (shty, SDiscr) sdElt ef $ \(provars :: SList _ envPro) esub proPrimalBinds e0 e1 (e1tape :: Ex _ e_tape) _ wrapAccum e2 ->
let library = #ix (shty `SCons` SNil)
&. #e0 (bindingsBinds e0)
&. #propr (d1e provars)
@@ -1060,15 +1060,11 @@ drev des accumMap sd = \case
&. #darr (auto1 @(TArr ndim sdElt))
&. #tapearr (auto1 @(TArr ndim e_tape)) in
Ret (proPrimalBinds
- `bpush` EBuild ext ndim
- (weakenExpr (wSinks (d1e provars)) (drevPrimal des she))
- (letBinds (fst (weakenBindingsE (autoWeak library
- (#ix :++: #d1env)
- (#ix :++: #propr :++: #d1env))
- e0)) $
- weakenExpr (autoWeak library (#e0 :++: #ix :++: #d1env)
- (#e0 :++: #ix :++: #propr :++: #d1env))
- (EPair ext e1 e1tape))
+ `bpush` weakenExpr (wSinks (d1e provars))
+ (EBuild ext ndim
+ (drevPrimal des she)
+ (letBinds e0 $
+ EPair ext e1 e1tape))
`bpush` emap (ESnd ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) (typeOf e1tape))) IZ))
(SEYesR (SENo (subenvAll (d1e provars))))
(emap (EFst ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) (typeOf e1tape))) (IS IZ)))
@@ -1077,37 +1073,77 @@ drev des accumMap sd = \case
ESnd ext $
wrapAccum (WSink .> WSink .> wRaiseAbove (d1e provars) (d2ace (select SAccum des))) $
EBuild ext ndim (EShape ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (sinkOverEnvPro @> IZ))) $
- -- the tape for this element
- ELet ext (EIdx ext (EVar ext (STArr ndim (typeOf e1tape)) (WSink .> sinkOverEnvPro @> IS IZ))
- (EVar ext shty IZ)) $
-- the cotangent for this element
- ELet ext (EIdx ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (WSink .> WSink .> sinkOverEnvPro @> IZ))
+ ELet ext (EIdx ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (WSink .> sinkOverEnvPro @> IZ))
+ (EVar ext shty IZ)) $
+ -- the tape for this element
+ ELet ext (EIdx ext (EVar ext (STArr ndim (typeOf e1tape)) (WSink .> WSink .> sinkOverEnvPro @> IS IZ))
(EVar ext shty (IS IZ))) $
- weakenExpr (autoWeak library (#d :++: #tape :++: #pro :++: #d2acEnv)
- (#d :++: #tape :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #propr :++: #d2acEnv))
+ weakenExpr (autoWeak library (#tape :++: #d :++: #pro :++: #d2acEnv)
+ (#tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #propr :++: #d2acEnv))
e2)
- EMap{} -> undefined
+ EMap _ ef (earr :: Expr _ _ (TArr n a))
+ | SpArr sdElt <- sd
+ , let STArr ndim t1 = typeOf earr
+ t2 = typeOf ef ->
+ drevLambda des accumMap (t1, SMerge) sdElt ef $ \provars efsub proPrimalBinds ef0 ef1 ef1tape spEf wrapAccum ef2 ->
+ case drev des accumMap (SpArr spEf) earr of { Ret ea0 easubtape ea1 easub ea2 ->
+ let (proPrimalBinds', _) = weakenBindingsE (sinkWithBindings ea0) proPrimalBinds
+ ttape = typeOf ef1tape
+ library = #d1env (desD1E des)
+ &. #a0 (bindingsBinds ea0)
+ &. #atapebinds (subList (bindingsBinds ea0) easubtape)
+ &. #propr (d1e provars)
+ &. #x (d1 t1 `SCons` SNil)
+ &. #parr (STArr ndim (d1 t1) `SCons` SNil)
+ &. #tapearr (STArr ndim ttape `SCons` SNil)
+ &. #darr (STArr ndim (applySparse sdElt (d2 t2)) `SCons` SNil)
+ &. #dy (applySparse sdElt (d2 t2) `SCons` SNil)
+ &. #tape (ttape `SCons` SNil)
+ &. #dytape (STPair (applySparse sdElt (d2 t2)) ttape `SCons` SNil)
+ &. #d2acEnv (d2ace (select SAccum des))
+ &. #pro (d2ace provars)
+ in
+ subenvPlus SF SF (d2eM (select SMerge des)) (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) efsub) easub $ \subfa _ _ plus_f_a ->
+ Ret (bconcat ea0 proPrimalBinds'
+ `bpush` weakenExpr (autoWeak library (#a0 :++: #d1env) ((#propr :++: #a0) :++: #d1env)) ea1
+ `bpush` emap (weakenExpr (autoWeak library (#x :++: #d1env) (#x :++: #parr :++: (#propr :++: #a0) :++: #d1env))
+ (letBinds ef0 $
+ EPair ext ef1 ef1tape))
+ (EVar ext (STArr ndim (d1 t1)) IZ)
+ `bpush` emap (ESnd ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 t2) ttape)) IZ))
+ (SEYesR (SENo (SENo (subenvConcat easubtape (subenvAll (d1e provars))))))
+ (emap (EFst ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 t2) ttape)) (IS IZ)))
+ subfa
+ (let layout = #darr :++: #tapearr :++: (#propr :++: #atapebinds) :++: #d2acEnv in
+ elet
+ (wrapAccum (autoWeak library #propr layout) $
+ emap (elet (EFst ext (EVar ext (STPair (applySparse sdElt (d2 t2)) ttape) IZ)) $
+ elet (ESnd ext (EVar ext (STPair (applySparse sdElt (d2 t2)) ttape) (IS IZ))) $
+ weakenExpr (autoWeak library (#tape :++: #dy :++: #pro :++: #d2acEnv)
+ (#tape :++: #dy :++: #dytape :++: #pro :++: layout))
+ ef2)
+ (ezip (EVar ext (STArr ndim (applySparse sdElt (d2 t2))) (autoWeak library #darr (#pro :++: layout) @> IZ))
+ (EVar ext (STArr ndim ttape) (autoWeak library #tapearr (#pro :++: layout) @> IZ)))) $
+ plus_f_a
+ (ESnd ext (evar IZ))
+ (weakenExpr (WCopy (autoWeak library (#atapebinds :++: #d2acEnv) layout))
+ (subst0 (EFst ext (EVar ext (STPair (STArr ndim (typeOf ef2)) (tTup (d2e provars))) IZ))
+ ea2)))
+ }
EFold1Inner _ commut origef ex₀ earr
| SpArr @_ @sdElt sdElt <- sd
, STArr (SS ndim) eltty :: STy (TArr (S n) elt) <- typeOf earr
, Rets bindsx₀a subtapex₀a (RetPair ex₀1 subx₀ ex₀2 `SCons` RetPair ea1 suba ea2 `SCons` SNil)
<- retConcat des $ toSingleRet (drev des accumMap (spDense (d2M eltty)) ex₀) `SCons` toSingleRet (drev des accumMap (spDense (SMTArr (SS ndim) (d2M eltty))) earr) `SCons` SNil ->
- deleteUnused (descrList des) (occEnvPopSome (occEnvPopSome (occCountAll origef))) $ \(usedSub :: Subenv env env') ->
- let ef = unsafeWeakenWithSubenv (SEYesR (SEYesR usedSub)) origef in
- subDescr des usedSub $ \(usedDes :: Descr env' _) subMergeUsed subAccumUsed subD1eUsed ->
- accumPromote (d2 eltty) usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro ->
- let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in
- let mergePrimalSub = subenvD1E (selectSub SMerge des `subenvCompose` subMergeUsed `subenvCompose` proSub) in
- let mergePrimalBindings = collectBindings (d1e (descrList des)) mergePrimalSub in
- let (mergePrimalBindings', _) = weakenBindingsE (sinkWithBindings bindsx₀a) mergePrimalBindings in
- case drev (prodes `DPush` (eltty, Nothing, SMerge) `DPush` (eltty, Nothing, SMerge)) accumMapPro (spDense (d2M eltty)) ef of { Ret (ef0 :: Bindings _ _ e_binds) (subtapeEf :: Subenv _ e_tape) ef1 subEf ef2 ->
- let (efRebinds, efPrerebinds) = reconstructBindings (subList (bindingsBinds ef0) subtapeEf) in
- let bogTy = STArr (SS ndim) (STPair (d1 eltty) (tapeTy (subList (bindingsBinds ef0) subtapeEf)))
+ drevLambda des accumMap (STPair eltty eltty, SMerge) (spDense (d2M eltty)) origef $ \(provars :: SList _ envPro) efsub proPrimalBinds ef0 ef1 (ef1tape :: Ex _ ef_tape) spEf wrapAccum ef2 ->
+ let (proPrimalBinds', _) = weakenBindingsE (sinkWithBindings bindsx₀a) proPrimalBinds in
+ let bogEltTy = STPair (STPair (d1 eltty) (d1 eltty)) (typeOf ef1tape)
+ bogTy = STArr (SS ndim) bogEltTy
primalTy = STPair (STArr ndim (d1 eltty)) bogTy
- zipPrimalTy = STPair (d1 eltty) (STPair (d1 eltty) (tapeTy (subList (bindingsBinds ef0) subtapeEf)))
- library = #xy (d1 eltty `SCons` d1 eltty `SCons` SNil)
+ library = #xy (STPair (d1 eltty) (d1 eltty) `SCons` SNil)
&. #parr (auto1 @(TArr (S n) (D1 elt)))
&. #px₀ (auto1 @(D1 elt))
&. #px (auto1 @(D1 elt))
@@ -1118,70 +1154,53 @@ drev des accumMap sd = \case
&. #x₀abinds (bindingsBinds bindsx₀a)
&. #fbinds (bindingsBinds ef0)
&. #x₀atapebinds (subList (bindingsBinds bindsx₀a) subtapex₀a)
- &. #ftapebinds (subList (bindingsBinds ef0) subtapeEf)
- &. #ftape (auto1 @(Tape e_tape))
- &. #primalzip (zipPrimalTy `SCons` SNil)
- &. #efPrerebinds efPrerebinds
- &. #propr (d1e envPro)
+ &. #ftape (auto1 @ef_tape)
+ &. #bogelt (bogEltTy `SCons` SNil)
+ &. #propr (d1e provars)
&. #d1env (desD1E des)
- &. #d1env' (desD1E usedDes)
- &. #d2acUsed (d2ace (select SAccum usedDes))
&. #d2acEnv (d2ace (select SAccum des))
- &. #d2acPro (d2ace envPro)
+ &. #d2acPro (d2ace provars)
&. #foldd2res (auto1 @(TPair (TPair (D2 elt) (TArr (S n) (D2 elt))) (Tup (D2E envPro))))
wOverPrimalBindings = autoWeak library (#x₀abinds :++: #d1env) ((#propr :++: #x₀abinds) :++: #d1env) in
subenvPlus SF SF (d2eM (select SMerge des)) subx₀ suba $ \subx₀a _ _ plus_x₀_a ->
- subenvPlus SF SF (d2eM (select SMerge des)) subx₀a (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E (subenvCompose subMergeUsed proSub))) $ \subx₀af _ _ plus_x₀a_f ->
- Ret (bconcat bindsx₀a mergePrimalBindings'
+ subenvPlus SF SF (d2eM (select SMerge des)) subx₀a (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) efsub) $ \subx₀af _ _ plus_x₀a_f ->
+ Ret (bconcat bindsx₀a proPrimalBinds'
`bpush` weakenExpr wOverPrimalBindings ex₀1
`bpush` d2zeroInfo eltty (EVar ext (d1 eltty) IZ)
`bpush` weakenExpr (WSink .> WSink .> wOverPrimalBindings) ea1
`bpush` EFold1InnerD1 ext commut
(let layout = #xy :++: #parr :++: #pzi :++: #px₀ :++: (#propr :++: #x₀abinds) :++: #d1env in
- letBinds (fst (weakenBindingsE (autoWeak library
- (#xy :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))
- layout)
- ef0)) $
- elet (weakenExpr (autoWeak library (#fbinds :++: #xy :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))
- (#fbinds :++: layout))
- ef1) $
- EPair ext
- (evar IZ)
- (EPair ext
- (evar IZ)
- (bindingsCollectTape (bindingsBinds ef0) subtapeEf (autoWeak library #fbinds (#px :++: #fbinds :++: layout)))))
+ weakenExpr (autoWeak library (#xy :++: #d1env) layout)
+ (letBinds ef0 $
+ EPair ext -- (out, ((in1, in2), tape)); the "additional stores" are ((in1, in2), tape)
+ ef1
+ (EPair ext
+ (EVar ext (STPair (d1 eltty) (d1 eltty)) (autoWeak library #xy (#fbinds :++: #xy :++: #d1env) @> IZ))
+ ef1tape)))
(EVar ext (d1 eltty) (IS (IS IZ)))
(EVar ext (STArr (SS ndim) (d1 eltty)) IZ))
- (SEYesR (SEYesR (SEYesR (SENo (subenvConcat subtapex₀a (subenvAll (d1e envPro)))))))
+ (SEYesR (SEYesR (SEYesR (SENo (subenvConcat subtapex₀a (subenvAll (d1e provars)))))))
(EFst ext (EVar ext primalTy IZ))
subx₀af
(let layout1 = #darr :++: #primal :++: #parr :++: #pzi :++: (#propr :++: #x₀atapebinds) :++: #d2acEnv in
elet
- (uninvertTup (d2e envPro) (STPair (STArr ndim (d2 eltty)) (STArr (SS ndim) (d2 eltty))) $
- makeAccumulators (autoWeak library #propr layout1) envPro $
- let layout2 = #d2acPro :++: layout1 in
- EFold1InnerD2 ext commut
- (elet (ESnd ext (ESnd ext (EVar ext zipPrimalTy (IS IZ)))) $
- elet (EFst ext (ESnd ext (EVar ext zipPrimalTy (IS (IS IZ))))) $
- elet (EFst ext (EVar ext zipPrimalTy (IS (IS (IS IZ))))) $
- letBinds (efRebinds (IS (IS IZ))) $
- let layout3 = (#ftapebinds :++: #efPrerebinds) :++: #xy :++: #ftape :++: #d :++: #primalzip :++: layout2 in
- elet (expandSubenvZeros (autoWeak library #xy layout3) (eltty `SCons` eltty `SCons` SNil) subEf $
- weakenExpr (autoWeak library (#d2acPro :++: #d :++: #ftapebinds :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) layout3
- .> wPro (subList (bindingsBinds ef0) subtapeEf))
- ef2) $
- EPair ext (ESnd ext (EFst ext (evar IZ))) (ESnd ext (evar IZ)))
- (ezip
- (EVar ext (STArr (SS ndim) (d1 eltty)) (autoWeak library #parr layout2 @> IZ))
- (ESnd ext $ EVar ext primalTy (autoWeak library #primal layout2 @> IZ)))
- (ezipWith (expandSparse eltty sdElt (evar IZ) (evar (IS IZ)))
- (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (autoWeak library #darr layout2 @> IZ))
- (EFst ext $ EVar ext primalTy (autoWeak library #primal layout2 @> IZ)))) $
+ (wrapAccum (autoWeak library #propr layout1) $
+ let layout2 = #d2acPro :++: layout1 in
+ EFold1InnerD2 ext commut
+ (elet (ESnd ext (EVar ext bogEltTy (IS IZ))) $
+ let layout3 = #ftape :++: #d :++: #bogelt :++: layout2 in
+ expandSparse (STPair eltty eltty) spEf (EFst ext (EVar ext bogEltTy (IS (IS IZ)))) $
+ weakenExpr (autoWeak library (#ftape :++: #d :++: #d2acPro :++: #d2acEnv) layout3) ef2)
+ (ESnd ext (EVar ext primalTy (autoWeak library #primal layout2 @> IZ)))
+ (ezipWith (expandSparse eltty sdElt (evar IZ) (evar (IS IZ)))
+ (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (autoWeak library #darr layout2 @> IZ))
+ (EFst ext (EVar ext primalTy (autoWeak library #primal layout2 @> IZ))))) $
plus_x₀a_f
(plus_x₀_a
(elet (EIdx0 ext
(EFold1Inner ext Commut
- (EPlus ext (d2M eltty) (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ))
+ (let t = STPair (d2 eltty) (d2 eltty)
+ in EPlus ext (d2M eltty) (EFst ext (EVar ext t IZ)) (ESnd ext (EVar ext t IZ)))
(EZero ext (d2M eltty) (EVar ext (tZeroInfo (d2M eltty)) (WSink .> autoWeak library #pzi layout1 @> IZ)))
(eflatten (EFst ext (EFst ext (evar IZ)))))) $
weakenExpr (WCopy (WSink .> autoWeak library (#x₀atapebinds :++: #d2acEnv) layout1))
@@ -1189,7 +1208,6 @@ drev des accumMap sd = \case
(weakenExpr (WCopy (autoWeak library (#x₀atapebinds :++: #d2acEnv) layout1)) $
subst0 (ESnd ext (EFst ext (evar IZ))) ea2))
(ESnd ext (evar IZ)))
- }
EUnit _ e
| SpArr sdElt <- sd
@@ -1213,9 +1231,8 @@ drev des accumMap sd = \case
(EReplicate1Inner ext (weakenExpr (wSinks (bindingsBinds binds)) (drevPrimal des en)) e1)
sub
(ELet ext (EFold1Inner ext Commut
- (sparsePlus (d2M eltty) sdElt'
- (EVar ext (applySparse sdElt' (d2 eltty)) (IS IZ))
- (EVar ext (applySparse sdElt' (d2 eltty)) IZ))
+ (let t = STPair (applySparse sdElt' (d2 eltty)) (applySparse sdElt' (d2 eltty))
+ in sparsePlus (d2M eltty) sdElt' (EFst ext (EVar ext t IZ)) (ESnd ext (EVar ext t IZ)))
(inj2 (ENil ext))
(emap (inj1 (evar IZ)) $ EVar ext (STArr (SS ndim) (applySparse sdElt (d2 eltty))) IZ)) $
weakenExpr (WCopy WSink) e2)
@@ -1494,7 +1511,7 @@ drevLambda :: (?config :: CHADConfig, (s == "accum") ~ False)
D1E provars :> env'
-> Ex (Append (D2AcE provars) env') b
-> Ex ( env') (TPair b (Tup (D2E provars))))
- -> Ex (dt : tape : Append (D2AcE provars) (D2AcE (Select env sto "accum"))) d2a'
+ -> Ex (tape : dt : Append (D2AcE provars) (D2AcE (Select env sto "accum"))) d2a'
-> r)
-> r
drevLambda des accumMap (argty, argsto) sd origef k =
@@ -1535,10 +1552,10 @@ drevLambda des accumMap (argty, argsto) sd origef k =
uninvertTup (d2e envPro) (typeOf body) $
makeAccumulators wpro1 envPro $
body)
- (letBinds (efRebinds (IS IZ)) $
+ (letBinds (efRebinds IZ) $
weakenExpr
(autoWeak library (#d2acPro :++: #d :++: #ftapebinds :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed))
- ((#ftapebinds :++: #efPrerebinds) :++: #d :++: #ftape :++: #d2acPro :++: #d2acEnv)
+ ((#ftapebinds :++: #efPrerebinds) :++: #ftape :++: #d :++: #d2acPro :++: #d2acEnv)
.> wPro (subList (bindingsBinds ef0) subtapeEf))
(getSparseArg ef2))
}}