summaryrefslogtreecommitdiff
path: root/src/CHAD.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r--src/CHAD.hs27
1 files changed, 14 insertions, 13 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 4694ac4..e77dbe7 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -981,8 +981,8 @@ drev des = \case
(#e0 :++: #ix :++: #sh :++: #she0 :++: #d1env)))))
(EBuild ext ndim
(EVar ext shty (IS IZ))
- (ELet ext (EIdx ext ndim (EVar ext (STArr ndim tapety) (IS IZ))
- (EVar ext shty IZ)) $
+ (ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (IS IZ))
+ (EVar ext shty IZ)) $
let (rebinds, prerebinds) = reconstructBindings (bindingsBinds e0) IZ
in letBinds rebinds $
weakenExpr (autoWeak (#ix (shty `SCons` SNil)
@@ -1004,11 +1004,11 @@ drev des = \case
makeAccumulators @_ @_ @(TArr ndim TNil) envPro $
EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS IZ))) $
-- the cotangent for this element
- ELet ext (EIdx ext ndim (EVar ext (STArr ndim (d2 eltty)) (WSink .> sinkOverEnvPro @> IZ))
- (EVar ext shty IZ)) $
+ ELet ext (EIdx ext (EVar ext (STArr ndim (d2 eltty)) (WSink .> sinkOverEnvPro @> IZ))
+ (EVar ext shty IZ)) $
-- the tape for this element
- ELet ext (EIdx ext ndim (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS IZ))
- (EVar ext shty (IS IZ))) $
+ ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS IZ))
+ (EVar ext shty (IS IZ))) $
let (rebinds, prerebinds) = reconstructBindings (bindingsBinds e0) IZ
in letBinds rebinds $
weakenExpr (autoWeak (#d (auto1 @(D2 eltty))
@@ -1073,19 +1073,20 @@ drev des = \case
(EVar ext (STArr n (d2 eltty)) (IS IZ))) $
weakenExpr (WCopy (WSink .> WSink)) e2)
- EIdx _ n e ei
+ EIdx _ e ei
-- We're allowed to ignore ei2 here because the output of 'ei' is discrete.
| Rets binds (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil)
<- retConcat des $ drev des e `SCons` drev des ei `SCons` SNil
- , STArr _ eltty <- typeOf e
+ , STArr n eltty <- typeOf e
, Refl <- indexTupD1Id n ->
- Ret (binds `BPush` (STArr n (d1 eltty), e1))
- (EIdx ext n (EVar ext (STArr n (d1 eltty)) IZ)
- (weakenExpr WSink ei1))
+ Ret (binds `BPush` (STArr n (d1 eltty), e1)
+ `BPush` (tTup (sreplicate n tIx), EShape ext (EVar ext (typeOf e1) IZ)))
+ (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS IZ))
+ (weakenExpr (WSink .> WSink) ei1))
sub
- (ELet ext (EBuild ext n (EShape ext (EVar ext (STArr n (d1 eltty)) (IS IZ)))
+ (ELet ext (EBuild ext n (EVar ext (tTup (sreplicate n tIx)) (IS IZ))
(EVar ext (d2 eltty) (IS IZ))) $
- weakenExpr (WCopy (WSink .> WSink)) e2)
+ weakenExpr (WCopy (WSink .> WSink .> WSink)) e2)
EShape _ e
-- Allowed to ignore e2 here because the output of EShape is discrete,