aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r--src/CHAD.hs15
1 files changed, 15 insertions, 0 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 93fabf9..04c4231 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -1327,6 +1327,21 @@ drev des accumMap sd = \case
EMaximum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMaximum1Inner ext) des accumMap sd' e
EMinimum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMinimum1Inner ext) des accumMap sd' e
+ EReshape _ n esh e
+ | SpArr sd' <- sd
+ , STArr orign t <- typeOf e
+ , Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd') e
+ , Refl <- indexTupD1Id n ->
+ Ret (e0 `bpush` e1
+ `bpush` EShape ext (EVar ext (STArr orign (d1 t)) IZ))
+ (SEYesR (SENo subtape))
+ (EReshape ext n (weakenExpr (WSink .> WSink .> wSinks (bindingsBinds e0)) (drevPrimal des esh))
+ (EVar ext (STArr orign (d1 t)) (IS IZ)))
+ sub
+ (elet (EReshape ext orign (EVar ext (tTup (sreplicate orign tIx)) (IS IZ))
+ (EVar ext (STArr n (applySparse sd' (d2 t))) IZ)) $
+ weakenExpr (WCopy (WSink .> WSink)) e2)
+
ENothing{} -> err_unsupported "ENothing"
EJust{} -> err_unsupported "EJust"
EMaybe{} -> err_unsupported "EMaybe"