aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/CHAD.hs37
1 files changed, 17 insertions, 20 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 9da5395..a37edff 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -1043,12 +1043,12 @@ drev des accumMap sd = \case
(subenvNone (d2e (select SMerge des)))
(ENil ext)
- EBuild _ (ndim :: SNat ndim) she (orige :: Expr _ _ eltty)
+ EBuild _ (ndim :: SNat ndim) she (ef :: Expr _ _ eltty)
| SpArr @_ @sdElt sdElt <- sd
- , let eltty = typeOf orige
+ , let eltty = typeOf ef
, shty :: STy shty <- tTup (sreplicate ndim tIx)
, Refl <- indexTupD1Id ndim ->
- drevLambda des accumMap (shty, SDiscr) sdElt orige $ \(provars :: SList _ envPro) esub proPrimalBinds e0 e1 (e1tape :: Ex _ e_tape) _ wrapAccum e2 ->
+ drevLambda des accumMap (shty, SDiscr) sdElt ef $ \(provars :: SList _ envPro) esub proPrimalBinds e0 e1 (e1tape :: Ex _ e_tape) _ wrapAccum e2 ->
let library = #ix (shty `SCons` SNil)
&. #e0 (bindingsBinds e0)
&. #propr (d1e provars)
@@ -1060,15 +1060,11 @@ drev des accumMap sd = \case
&. #darr (auto1 @(TArr ndim sdElt))
&. #tapearr (auto1 @(TArr ndim e_tape)) in
Ret (proPrimalBinds
- `bpush` EBuild ext ndim
- (weakenExpr (wSinks (d1e provars)) (drevPrimal des she))
- (letBinds (fst (weakenBindingsE (autoWeak library
- (#ix :++: #d1env)
- (#ix :++: #propr :++: #d1env))
- e0)) $
- weakenExpr (autoWeak library (#e0 :++: #ix :++: #d1env)
- (#e0 :++: #ix :++: #propr :++: #d1env))
- (EPair ext e1 e1tape))
+ `bpush` weakenExpr (wSinks (d1e provars))
+ (EBuild ext ndim
+ (drevPrimal des she)
+ (letBinds e0 $
+ EPair ext e1 e1tape))
`bpush` emap (ESnd ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) (typeOf e1tape))) IZ))
(SEYesR (SENo (subenvAll (d1e provars))))
(emap (EFst ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) (typeOf e1tape))) (IS IZ)))
@@ -1094,7 +1090,7 @@ drev des accumMap sd = \case
, STArr (SS ndim) eltty :: STy (TArr (S n) elt) <- typeOf earr
, Rets bindsx₀a subtapex₀a (RetPair ex₀1 subx₀ ex₀2 `SCons` RetPair ea1 suba ea2 `SCons` SNil)
<- retConcat des $ toSingleRet (drev des accumMap (spDense (d2M eltty)) ex₀) `SCons` toSingleRet (drev des accumMap (spDense (SMTArr (SS ndim) (d2M eltty))) earr) `SCons` SNil ->
- drevLambda des accumMap (STPair eltty eltty, SMerge) (spDense (d2M eltty)) origef $ \(provars :: SList _ envPro) efsub proPrimalBinds ef0 ef1 (ef1tape :: Ex _ ef_tape) subEf wrapAccum ef2 ->
+ drevLambda des accumMap (STPair eltty eltty, SMerge) (spDense (d2M eltty)) origef $ \(provars :: SList _ envPro) efsub proPrimalBinds ef0 ef1 (ef1tape :: Ex _ ef_tape) spEf wrapAccum ef2 ->
let (proPrimalBinds', _) = weakenBindingsE (sinkWithBindings bindsx₀a) proPrimalBinds in
let bogEltTy = STPair (STPair (d1 eltty) (d1 eltty)) (typeOf ef1tape)
bogTy = STArr (SS ndim) bogEltTy
@@ -1126,12 +1122,13 @@ drev des accumMap sd = \case
`bpush` weakenExpr (WSink .> WSink .> wOverPrimalBindings) ea1
`bpush` EFold1InnerD1 ext commut
(let layout = #xy :++: #parr :++: #pzi :++: #px₀ :++: (#propr :++: #x₀abinds) :++: #d1env in
- letBinds (fst (weakenBindingsE (autoWeak library (#xy :++: #d1env) layout) ef0)) $
- EPair ext -- (out, ((in1, in2), tape)); the "additional stores" are ((in1, in2), tape)
- (weakenExpr (autoWeak library (#fbinds :++: #xy :++: #d1env) (#fbinds :++: layout)) ef1)
- (EPair ext
- (EVar ext (STPair (d1 eltty) (d1 eltty)) (autoWeak library #xy (#fbinds :++: layout) @> IZ))
- (weakenExpr (autoWeak library (#fbinds :++: #xy :++: #d1env) (#fbinds :++: layout)) ef1tape)))
+ weakenExpr (autoWeak library (#xy :++: #d1env) layout)
+ (letBinds ef0 $
+ EPair ext -- (out, ((in1, in2), tape)); the "additional stores" are ((in1, in2), tape)
+ ef1
+ (EPair ext
+ (EVar ext (STPair (d1 eltty) (d1 eltty)) (autoWeak library #xy (#fbinds :++: #xy :++: #d1env) @> IZ))
+ ef1tape)))
(EVar ext (d1 eltty) (IS (IS IZ)))
(EVar ext (STArr (SS ndim) (d1 eltty)) IZ))
(SEYesR (SEYesR (SEYesR (SENo (subenvConcat subtapex₀a (subenvAll (d1e provars)))))))
@@ -1144,7 +1141,7 @@ drev des accumMap sd = \case
EFold1InnerD2 ext commut
(elet (ESnd ext (EVar ext bogEltTy (IS IZ))) $
let layout3 = #ftape :++: #d :++: #bogelt :++: layout2 in
- expandSparse (STPair eltty eltty) subEf (EFst ext (EVar ext bogEltTy (IS (IS IZ)))) $
+ expandSparse (STPair eltty eltty) spEf (EFst ext (EVar ext bogEltTy (IS (IS IZ)))) $
weakenExpr (autoWeak library (#ftape :++: #d :++: #d2acPro :++: #d2acEnv) layout3) ef2)
(ESnd ext (EVar ext primalTy (autoWeak library #primal layout2 @> IZ)))
(ezipWith (expandSparse eltty sdElt (evar IZ) (evar (IS IZ)))