From 81d88dbc430ca6ec8390636f8b7162887b390873 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Mon, 3 Nov 2025 23:09:37 +0100 Subject: WIP map + zip --- src/CHAD.hs | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) (limited to 'src/CHAD.hs') diff --git a/src/CHAD.hs b/src/CHAD.hs index 7594a0f..67ffe12 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -1116,6 +1116,8 @@ drev des accumMap sd = \case e2) }} + EMap{} -> undefined + EFold1Inner _ commut origef ex₀ earr | SpArr @_ @sdElt sdElt <- sd , STArr (SS ndim) eltty :: STy (TArr (S n) elt) <- typeOf earr @@ -1346,6 +1348,33 @@ drev des accumMap sd = \case (EVar ext (STArr n (applySparse sd' (d2 t))) IZ)) $ weakenExpr (WCopy (WSink .> WSink)) e2) + EZip _ a b + | SpArr sd' <- sd + , STArr n t1 <- typeOf a + , STArr _ t2 <- typeOf b -> + splitSparsePair (STPair (d2 t1) (d2 t2)) sd' $ \sd1 sd2 pairSplitE -> + case retConcat des (toSingleRet (drev des accumMap (SpArr sd1) a) `SCons` + toSingleRet (drev des accumMap (SpArr sd2) b) `SCons` SNil) of + { Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil) -> + subenvPlus SF SF (d2eM (select SMerge des)) subA subB $ \subBoth _ _ plus_A_B -> + Ret binds + subtape + (EZip ext a1 b1) + subBoth + (case pairSplitE of + Left Refl -> + let t' = STArr n (STPair (applySparse sd1 (d2 t1)) (applySparse sd2 (d2 t2))) in + plus_A_B + (elet (emap (EFst ext (evar IZ)) (EVar ext t' IZ)) $ weakenExpr (WCopy WSink) a2) + (elet (emap (ESnd ext (evar IZ)) (EVar ext t' IZ)) $ weakenExpr (WCopy WSink) b2) + Right f -> f IZ $ \wrapPair pick1 pick2 -> + elet (emap (wrapPair (EPair ext pick1 pick2)) + (EVar ext (applySparse (SpArr sd') (STArr n (STPair (d2 t1) (d2 t2)))) IZ)) $ + plus_A_B + (elet (emap (EFst ext (evar IZ)) (evar IZ)) $ weakenExpr (WCopy (WSink .> WSink)) a2) + (elet (emap (ESnd ext (evar IZ)) (evar IZ)) $ weakenExpr (WCopy (WSink .> WSink)) b2)) + } + ENothing{} -> err_unsupported "ENothing" EJust{} -> err_unsupported "EJust" EMaybe{} -> err_unsupported "EMaybe" -- cgit v1.2.3-70-g09d2