diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-10-22 22:02:06 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-10-22 22:02:06 +0200 |
commit | 79e072eddf0ec2a97ca455c27cb5ff6f2132bbab (patch) | |
tree | 2099dc7e9d9a1109d844bca73277ca82983a02c2 /src/CHAD.hs | |
parent | e7d7ac0fd8b81c1d6fae9ab7c1e4654133c631ea (diff) |
Differentiate Replicate
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r-- | src/CHAD.hs | 14 |
1 files changed, 13 insertions, 1 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs index d05e77f..786de07 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -1037,6 +1037,19 @@ drev des = \case (ELet ext (EIdx0 ext (EVar ext (STArr SZ (d2 (typeOf e))) IZ)) $ weakenExpr (WCopy WSink) e2) + EReplicate1Inner _ en e + -- We're allowed to ignore en2 here because the output of 'ei' is discrete. + | Rets binds (RetPair en1 _ _ `SCons` RetPair e1 sub e2 `SCons` SNil) + <- retConcat des $ drev des en `SCons` drev des e `SCons` SNil + , let STArr ndim eltty = typeOf e -> + Ret (binds `BPush` (d1 (typeOf e), e1)) + (weakenExpr WSink $ EReplicate1Inner ext en1 e1) + sub + (ELet ext (EFold1Inner ext (EPlus eltty (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ)) + (EZero eltty) + (EVar ext (STArr (SS ndim) (d2 eltty)) IZ)) $ + weakenExpr (WCopy (WSink .> WSink)) e2) + EIdx0 _ e | Ret e0 e1 sub e2 <- drev des e , STArr _ t <- typeOf e -> @@ -1097,7 +1110,6 @@ drev des = \case weakenExpr (WCopy (WSink .> WSink)) e2) -- These should be the next to be implemented, I think - EReplicate1Inner{} -> err_unsupported "EReplicate1Inner" EFold1Inner{} -> err_unsupported "EFold1Inner" ENothing{} -> err_unsupported "ENothing" |