aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-11-03 23:09:37 +0100
committerTom Smeding <tom@tomsmeding.com>2025-11-03 23:10:23 +0100
commit81d88dbc430ca6ec8390636f8b7162887b390873 (patch)
tree849c126fad3b923c2e5b815aa5c8488907bc2318 /src/CHAD.hs
parent2ca218d2e97e521bcc49dea8f4774737ba083ede (diff)
WIP map + zip
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r--src/CHAD.hs29
1 files changed, 29 insertions, 0 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 7594a0f..67ffe12 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -1116,6 +1116,8 @@ drev des accumMap sd = \case
e2)
}}
+ EMap{} -> undefined
+
EFold1Inner _ commut origef ex₀ earr
| SpArr @_ @sdElt sdElt <- sd
, STArr (SS ndim) eltty :: STy (TArr (S n) elt) <- typeOf earr
@@ -1346,6 +1348,33 @@ drev des accumMap sd = \case
(EVar ext (STArr n (applySparse sd' (d2 t))) IZ)) $
weakenExpr (WCopy (WSink .> WSink)) e2)
+ EZip _ a b
+ | SpArr sd' <- sd
+ , STArr n t1 <- typeOf a
+ , STArr _ t2 <- typeOf b ->
+ splitSparsePair (STPair (d2 t1) (d2 t2)) sd' $ \sd1 sd2 pairSplitE ->
+ case retConcat des (toSingleRet (drev des accumMap (SpArr sd1) a) `SCons`
+ toSingleRet (drev des accumMap (SpArr sd2) b) `SCons` SNil) of
+ { Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil) ->
+ subenvPlus SF SF (d2eM (select SMerge des)) subA subB $ \subBoth _ _ plus_A_B ->
+ Ret binds
+ subtape
+ (EZip ext a1 b1)
+ subBoth
+ (case pairSplitE of
+ Left Refl ->
+ let t' = STArr n (STPair (applySparse sd1 (d2 t1)) (applySparse sd2 (d2 t2))) in
+ plus_A_B
+ (elet (emap (EFst ext (evar IZ)) (EVar ext t' IZ)) $ weakenExpr (WCopy WSink) a2)
+ (elet (emap (ESnd ext (evar IZ)) (EVar ext t' IZ)) $ weakenExpr (WCopy WSink) b2)
+ Right f -> f IZ $ \wrapPair pick1 pick2 ->
+ elet (emap (wrapPair (EPair ext pick1 pick2))
+ (EVar ext (applySparse (SpArr sd') (STArr n (STPair (d2 t1) (d2 t2)))) IZ)) $
+ plus_A_B
+ (elet (emap (EFst ext (evar IZ)) (evar IZ)) $ weakenExpr (WCopy (WSink .> WSink)) a2)
+ (elet (emap (ESnd ext (evar IZ)) (evar IZ)) $ weakenExpr (WCopy (WSink .> WSink)) b2))
+ }
+
ENothing{} -> err_unsupported "ENothing"
EJust{} -> err_unsupported "EJust"
EMaybe{} -> err_unsupported "EMaybe"