summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-06-18 10:11:12 +0200
committerTom Smeding <tom@tomsmeding.com>2025-06-18 10:11:12 +0200
commitfe80b31555c27f038b20eb84eb1e747781d7c76b (patch)
treebd6db261f391459ca638557e74cb101560ee2aab
parent58f68a4d077c2d58c3974ad12853207512277a33 (diff)
Don't destroy effects in sparse plus
-rw-r--r--src/AST/Sparse.hs20
-rw-r--r--src/CHAD.hs3
-rw-r--r--test/Main.hs15
3 files changed, 23 insertions, 15 deletions
diff --git a/src/AST/Sparse.hs b/src/AST/Sparse.hs
index f0a1f2a..0c5bdb0 100644
--- a/src/AST/Sparse.hs
+++ b/src/AST/Sparse.hs
@@ -66,6 +66,9 @@ withInj2 (Inj f) (Inj g) k = Inj (k f g)
withInj2 Noinj _ _ = Noinj
withInj2 _ Noinj _ = Noinj
+use :: Ex env a -> Ex env b -> Ex env b
+use a b = elet a $ weakenExpr WSink b
+
-- | This function produces quadratically-sized code in the presence of nested
-- dynamic sparsity. TODO can this be improved?
sparsePlusS
@@ -77,16 +80,17 @@ sparsePlusS
-> (forall e. Ex e t1 -> Ex e t2 -> Ex e t3)
-> r)
-> r
--- nil override
-sparsePlusS _ _ SMTNil _ _ k = k SpAbsent (Inj $ \_ -> ENil ext) (Inj $ \_ -> ENil ext) (\_ _ -> ENil ext)
+-- nil override (but don't destroy effects!)
+sparsePlusS _ _ SMTNil _ _ k =
+ k SpAbsent (Inj $ \a -> use a $ ENil ext) (Inj $ \b -> use b $ ENil ext) (\a b -> use a $ use b $ ENil ext)
-- simplifications
sparsePlusS req1 req2 t (SpSparse SpAbsent) sp2 k =
sparsePlusS req1 req2 t SpAbsent sp2 $ \sp3 minj1 minj2 plus ->
- k sp3 (withInj minj1 $ \inj1 -> \_ -> inj1 (ENil ext)) minj2 (\_ b -> plus (ENil ext) b)
+ k sp3 (withInj minj1 $ \inj1 -> \a -> use a $ inj1 (ENil ext)) minj2 (\a b -> use a $ plus (ENil ext) b)
sparsePlusS req1 req2 t sp1 (SpSparse SpAbsent) k =
sparsePlusS req1 req2 t sp1 SpAbsent $ \sp3 minj1 minj2 plus ->
- k sp3 minj1 (withInj minj2 $ \inj2 -> \_ -> inj2 (ENil ext)) (\a _ -> plus a (ENil ext))
+ k sp3 minj1 (withInj minj2 $ \inj2 -> \b -> use b $ inj2 (ENil ext)) (\a b -> use b $ plus a (ENil ext))
sparsePlusS req1 req2 t (SpSparse (SpSparse sp1)) sp2 k =
let ta = applySparse sp1 (fromSMTy t) in
@@ -144,13 +148,13 @@ sparsePlusS _ _ t sp1 sp2 k
= k (spDense t) (Inj id) (Inj id) (\a b -> EPlus ext t a b)
-- handle absents
-sparsePlusS SF _ _ SpAbsent sp2 k = k sp2 Noinj (Inj id) (\_ b -> b)
+sparsePlusS SF _ _ SpAbsent sp2 k = k sp2 Noinj (Inj id) (\a b -> use a $ b)
sparsePlusS ST _ t SpAbsent sp2 k =
- k (SpSparse sp2) (Inj $ \_ -> ENothing ext (applySparse sp2 (fromSMTy t))) (Inj $ EJust ext) (\_ b -> EJust ext b)
+ k (SpSparse sp2) (Inj $ \a -> use a $ ENothing ext (applySparse sp2 (fromSMTy t))) (Inj $ EJust ext) (\a b -> use a $ EJust ext b)
-sparsePlusS _ SF _ sp1 SpAbsent k = k sp1 (Inj id) Noinj (\a _ -> a)
+sparsePlusS _ SF _ sp1 SpAbsent k = k sp1 (Inj id) Noinj (\a b -> use b $ a)
sparsePlusS _ ST t sp1 SpAbsent k =
- k (SpSparse sp1) (Inj $ EJust ext) (Inj $ \_ -> ENothing ext (applySparse sp1 (fromSMTy t))) (\a _ -> EJust ext a)
+ k (SpSparse sp1) (Inj $ EJust ext) (Inj $ \b -> use b $ ENothing ext (applySparse sp1 (fromSMTy t))) (\a b -> use b $ EJust ext a)
-- double sparse yields sparse
sparsePlusS _ _ t (SpSparse sp1) (SpSparse sp2) k =
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 3399de2..9a08457 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -404,7 +404,8 @@ subenvPlus :: SBool req1 -> SBool req2
-> (forall e. Ex e (Tup env1) -> Ex e (Tup env2) -> Ex e (Tup env3))
-> r)
-> r
-subenvPlus _ _ SNil SETop SETop k = k SETop (Inj id) (Inj id) (\_ _ -> ENil ext)
+-- don't destroy effects!
+subenvPlus _ _ SNil SETop SETop k = k SETop (Inj id) (Inj id) (\a b -> use a $ use b $ ENil ext)
subenvPlus req1 req2 (SCons _ env) (SENo sub1) (SENo sub2) k =
subenvPlus req1 req2 env sub1 sub2 $ \sub3 s31 s32 pl ->
diff --git a/test/Main.hs b/test/Main.hs
index 8da7598..5ec9dbc 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -451,6 +451,14 @@ term_build1_idx = fromNamed $ lambda @(TVec _) #x $ body $
idx0 $ sum1i $
build1 (#n `idiv` 2) (#i :-> #x ! pair nil (2 * #i))
+term_idx_coprod :: Ex '[TVec (TEither R R)] R
+term_idx_coprod = fromNamed $ lambda @(TVec (TEither R R)) #x $ body $
+ let_ #n (snd_ (shape #x)) $
+ idx0 $ sum1i $ build1 #n $ #i :->
+ case_ (#x ! pair nil #i)
+ (#a :-> #a * 2)
+ (#b :-> #b * 3)
+
term_pairs :: Ex [R, R] R
term_pairs = fromNamed $ lambda #x $ lambda #y $ body $
let_ #p (pair #x #y) $
@@ -583,12 +591,7 @@ tests_AD = testGroup "AD"
let_ #p (#x ! pair nil #i) $
3 * fst_ #p + 2 * snd_ #p
- ,adTest "idx-coprod" $ fromNamed $ lambda @(TVec (TEither R R)) #x $ body $
- let_ #n (snd_ (shape #x)) $
- idx0 $ sum1i $ build1 #n $ #i :->
- case_ (#x ! pair nil #i)
- (#a :-> #a * 2)
- (#b :-> #b * 3)
+ ,adTest "idx-coprod" $ term_idx_coprod
,adTest "idx-coprod-pair" $ fromNamed $ lambda @(TVec R) #arr $ body $
let_ #n (snd_ (shape #arr)) $