diff options
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r-- | src/CHAD.hs | 27 |
1 files changed, 14 insertions, 13 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs index 4694ac4..e77dbe7 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -981,8 +981,8 @@ drev des = \case (#e0 :++: #ix :++: #sh :++: #she0 :++: #d1env))))) (EBuild ext ndim (EVar ext shty (IS IZ)) - (ELet ext (EIdx ext ndim (EVar ext (STArr ndim tapety) (IS IZ)) - (EVar ext shty IZ)) $ + (ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (IS IZ)) + (EVar ext shty IZ)) $ let (rebinds, prerebinds) = reconstructBindings (bindingsBinds e0) IZ in letBinds rebinds $ weakenExpr (autoWeak (#ix (shty `SCons` SNil) @@ -1004,11 +1004,11 @@ drev des = \case makeAccumulators @_ @_ @(TArr ndim TNil) envPro $ EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS IZ))) $ -- the cotangent for this element - ELet ext (EIdx ext ndim (EVar ext (STArr ndim (d2 eltty)) (WSink .> sinkOverEnvPro @> IZ)) - (EVar ext shty IZ)) $ + ELet ext (EIdx ext (EVar ext (STArr ndim (d2 eltty)) (WSink .> sinkOverEnvPro @> IZ)) + (EVar ext shty IZ)) $ -- the tape for this element - ELet ext (EIdx ext ndim (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS IZ)) - (EVar ext shty (IS IZ))) $ + ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS IZ)) + (EVar ext shty (IS IZ))) $ let (rebinds, prerebinds) = reconstructBindings (bindingsBinds e0) IZ in letBinds rebinds $ weakenExpr (autoWeak (#d (auto1 @(D2 eltty)) @@ -1073,19 +1073,20 @@ drev des = \case (EVar ext (STArr n (d2 eltty)) (IS IZ))) $ weakenExpr (WCopy (WSink .> WSink)) e2) - EIdx _ n e ei + EIdx _ e ei -- We're allowed to ignore ei2 here because the output of 'ei' is discrete. | Rets binds (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil) <- retConcat des $ drev des e `SCons` drev des ei `SCons` SNil - , STArr _ eltty <- typeOf e + , STArr n eltty <- typeOf e , Refl <- indexTupD1Id n -> - Ret (binds `BPush` (STArr n (d1 eltty), e1)) - (EIdx ext n (EVar ext (STArr n (d1 eltty)) IZ) - (weakenExpr WSink ei1)) + Ret (binds `BPush` (STArr n (d1 eltty), e1) + `BPush` (tTup (sreplicate n tIx), EShape ext (EVar ext (typeOf e1) IZ))) + (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS IZ)) + (weakenExpr (WSink .> WSink) ei1)) sub - (ELet ext (EBuild ext n (EShape ext (EVar ext (STArr n (d1 eltty)) (IS IZ))) + (ELet ext (EBuild ext n (EVar ext (tTup (sreplicate n tIx)) (IS IZ)) (EVar ext (d2 eltty) (IS IZ))) $ - weakenExpr (WCopy (WSink .> WSink)) e2) + weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) EShape _ e -- Allowed to ignore e2 here because the output of EShape is discrete, |