From 4d456e4d34b1e4fb3725051d1b8a0c376b704692 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Thu, 30 Oct 2025 15:56:35 +0100 Subject: Implement reshape --- src/CHAD.hs | 15 +++++++++++++++ 1 file changed, 15 insertions(+) (limited to 'src/CHAD.hs') diff --git a/src/CHAD.hs b/src/CHAD.hs index 93fabf9..04c4231 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -1327,6 +1327,21 @@ drev des accumMap sd = \case EMaximum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMaximum1Inner ext) des accumMap sd' e EMinimum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMinimum1Inner ext) des accumMap sd' e + EReshape _ n esh e + | SpArr sd' <- sd + , STArr orign t <- typeOf e + , Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd') e + , Refl <- indexTupD1Id n -> + Ret (e0 `bpush` e1 + `bpush` EShape ext (EVar ext (STArr orign (d1 t)) IZ)) + (SEYesR (SENo subtape)) + (EReshape ext n (weakenExpr (WSink .> WSink .> wSinks (bindingsBinds e0)) (drevPrimal des esh)) + (EVar ext (STArr orign (d1 t)) (IS IZ))) + sub + (elet (EReshape ext orign (EVar ext (tTup (sreplicate orign tIx)) (IS IZ)) + (EVar ext (STArr n (applySparse sd' (d2 t))) IZ)) $ + weakenExpr (WCopy (WSink .> WSink)) e2) + ENothing{} -> err_unsupported "ENothing" EJust{} -> err_unsupported "EJust" EMaybe{} -> err_unsupported "EMaybe" -- cgit v1.2.3-70-g09d2