diff options
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r-- | src/CHAD.hs | 14 |
1 files changed, 13 insertions, 1 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs index d05e77f..786de07 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -1037,6 +1037,19 @@ drev des = \case (ELet ext (EIdx0 ext (EVar ext (STArr SZ (d2 (typeOf e))) IZ)) $ weakenExpr (WCopy WSink) e2) + EReplicate1Inner _ en e + -- We're allowed to ignore en2 here because the output of 'ei' is discrete. + | Rets binds (RetPair en1 _ _ `SCons` RetPair e1 sub e2 `SCons` SNil) + <- retConcat des $ drev des en `SCons` drev des e `SCons` SNil + , let STArr ndim eltty = typeOf e -> + Ret (binds `BPush` (d1 (typeOf e), e1)) + (weakenExpr WSink $ EReplicate1Inner ext en1 e1) + sub + (ELet ext (EFold1Inner ext (EPlus eltty (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ)) + (EZero eltty) + (EVar ext (STArr (SS ndim) (d2 eltty)) IZ)) $ + weakenExpr (WCopy (WSink .> WSink)) e2) + EIdx0 _ e | Ret e0 e1 sub e2 <- drev des e , STArr _ t <- typeOf e -> @@ -1097,7 +1110,6 @@ drev des = \case weakenExpr (WCopy (WSink .> WSink)) e2) -- These should be the next to be implemented, I think - EReplicate1Inner{} -> err_unsupported "EReplicate1Inner" EFold1Inner{} -> err_unsupported "EFold1Inner" ENothing{} -> err_unsupported "ENothing" |