summaryrefslogtreecommitdiff
path: root/src/CHAD.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r--src/CHAD.hs56
1 files changed, 29 insertions, 27 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 8b9f17a..45fcc82 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -1016,33 +1016,35 @@ drev des = \case
(EVar ext (STArr ndim (STPair (d1 eltty) tapety)) (IS IZ)))
(subenvCompose subMergeUsed proSub)
(let sinkOverEnvPro = wSinks @(D2 t : TArr ndim (Tape e_tape) : Tup (Replicate ndim TIx) : D2AcE (Select env sto "accum")) (d2ace envPro) in
- ESnd ext $
- uninvertTup (d2e envPro) (STArr ndim STNil) $
- makeAccumulators @_ @_ @(TArr ndim TNil) envPro $
- EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS IZ))) $
- -- the cotangent for this element
- ELet ext (EIdx ext (EVar ext (STArr ndim (d2 eltty)) (WSink .> sinkOverEnvPro @> IZ))
- (EVar ext shty IZ)) $
- -- the tape for this element
- ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS IZ))
- (EVar ext shty (IS IZ))) $
- let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds e0) subtapeE) IZ
- in letBinds rebinds $
- weakenExpr (autoWeak (#d (auto1 @(D2 eltty))
- &. #pro (d2ace envPro)
- &. #etape (subList (bindingsBinds e0) subtapeE)
- &. #prerebinds prerebinds
- &. #tape (tapety `SCons` SNil)
- &. #ix (shty `SCons` SNil)
- &. #darr (STArr ndim (d2 eltty) `SCons` SNil)
- &. #tapearr (STArr ndim tapety `SCons` SNil)
- &. #sh (shty `SCons` SNil)
- &. #d2acUsed (d2ace (select SAccum usedDes))
- &. #d2acEnv (d2ace (select SAccum des)))
- (#pro :++: #d :++: #etape :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed))
- ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #sh :++: #d2acEnv)
- .> wPro (subList (bindingsBinds e0) subtapeE))
- e2)
+ eif (eshapeEmpty ndim (EShape ext (EVar ext (STArr ndim (d2 eltty)) IZ)))
+ (zeroTup envPro)
+ (ESnd ext $
+ uninvertTup (d2e envPro) (STArr ndim STNil) $
+ makeAccumulators @_ @_ @(TArr ndim TNil) envPro $
+ EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS IZ))) $
+ -- the cotangent for this element
+ ELet ext (EIdx ext (EVar ext (STArr ndim (d2 eltty)) (WSink .> sinkOverEnvPro @> IZ))
+ (EVar ext shty IZ)) $
+ -- the tape for this element
+ ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS IZ))
+ (EVar ext shty (IS IZ))) $
+ let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds e0) subtapeE) IZ
+ in letBinds rebinds $
+ weakenExpr (autoWeak (#d (auto1 @(D2 eltty))
+ &. #pro (d2ace envPro)
+ &. #etape (subList (bindingsBinds e0) subtapeE)
+ &. #prerebinds prerebinds
+ &. #tape (tapety `SCons` SNil)
+ &. #ix (shty `SCons` SNil)
+ &. #darr (STArr ndim (d2 eltty) `SCons` SNil)
+ &. #tapearr (STArr ndim tapety `SCons` SNil)
+ &. #sh (shty `SCons` SNil)
+ &. #d2acUsed (d2ace (select SAccum usedDes))
+ &. #d2acEnv (d2ace (select SAccum des)))
+ (#pro :++: #d :++: #etape :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed))
+ ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #sh :++: #d2acEnv)
+ .> wPro (subList (bindingsBinds e0) subtapeE))
+ e2))
}}
EUnit _ e