summaryrefslogtreecommitdiff
path: root/src/CHAD.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r--src/CHAD.hs14
1 files changed, 13 insertions, 1 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs
index d05e77f..786de07 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -1037,6 +1037,19 @@ drev des = \case
(ELet ext (EIdx0 ext (EVar ext (STArr SZ (d2 (typeOf e))) IZ)) $
weakenExpr (WCopy WSink) e2)
+ EReplicate1Inner _ en e
+ -- We're allowed to ignore en2 here because the output of 'ei' is discrete.
+ | Rets binds (RetPair en1 _ _ `SCons` RetPair e1 sub e2 `SCons` SNil)
+ <- retConcat des $ drev des en `SCons` drev des e `SCons` SNil
+ , let STArr ndim eltty = typeOf e ->
+ Ret (binds `BPush` (d1 (typeOf e), e1))
+ (weakenExpr WSink $ EReplicate1Inner ext en1 e1)
+ sub
+ (ELet ext (EFold1Inner ext (EPlus eltty (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ))
+ (EZero eltty)
+ (EVar ext (STArr (SS ndim) (d2 eltty)) IZ)) $
+ weakenExpr (WCopy (WSink .> WSink)) e2)
+
EIdx0 _ e
| Ret e0 e1 sub e2 <- drev des e
, STArr _ t <- typeOf e ->
@@ -1097,7 +1110,6 @@ drev des = \case
weakenExpr (WCopy (WSink .> WSink)) e2)
-- These should be the next to be implemented, I think
- EReplicate1Inner{} -> err_unsupported "EReplicate1Inner"
EFold1Inner{} -> err_unsupported "EFold1Inner"
ENothing{} -> err_unsupported "ENothing"