diff options
Diffstat (limited to 'src/CHAD.hs')
| -rw-r--r-- | src/CHAD.hs | 15 | 
1 files changed, 15 insertions, 0 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs index 93fabf9..04c4231 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -1327,6 +1327,21 @@ drev des accumMap sd = \case    EMaximum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMaximum1Inner ext) des accumMap sd' e    EMinimum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMinimum1Inner ext) des accumMap sd' e +  EReshape _ n esh e +    | SpArr sd' <- sd +    , STArr orign t <- typeOf e +    , Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd') e +    , Refl <- indexTupD1Id n -> +    Ret (e0 `bpush` e1 +            `bpush` EShape ext (EVar ext (STArr orign (d1 t)) IZ)) +        (SEYesR (SENo subtape)) +        (EReshape ext n (weakenExpr (WSink .> WSink .> wSinks (bindingsBinds e0)) (drevPrimal des esh)) +                        (EVar ext (STArr orign (d1 t)) (IS IZ))) +        sub +        (elet (EReshape ext orign (EVar ext (tTup (sreplicate orign tIx)) (IS IZ)) +                                  (EVar ext (STArr n (applySparse sd' (d2 t))) IZ)) $ +          weakenExpr (WCopy (WSink .> WSink)) e2) +    ENothing{} -> err_unsupported "ENothing"    EJust{} -> err_unsupported "EJust"    EMaybe{} -> err_unsupported "EMaybe"  | 
