aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-11-13 21:38:52 +0100
committerTom Smeding <tom@tomsmeding.com>2025-11-13 21:38:52 +0100
commit239fc967b4dfd0eba6ca5b0a9e0d9a2d29e6ad5e (patch)
tree80efa7b82ac62b4847f1e56b5a8604e4ab03e38b /src/CHAD
parentdc61318a22e3492774ab6f6345c9a369222ef2f6 (diff)
Sparse: Maybe prevent another SpSparse introduction
Diffstat (limited to 'src/CHAD')
-rw-r--r--src/CHAD/AST/Sparse.hs31
1 files changed, 21 insertions, 10 deletions
diff --git a/src/CHAD/AST/Sparse.hs b/src/CHAD/AST/Sparse.hs
index 1cd5031..85f2882 100644
--- a/src/CHAD/AST/Sparse.hs
+++ b/src/CHAD/AST/Sparse.hs
@@ -201,16 +201,27 @@ sparsePlusS SF _ t (SpSparse sp1) sp2 k =
emaybe (weakenExpr WSink a)
(inj2 (evar IZ))
(plus (evar IZ) (evar (IS IZ))))
-sparsePlusS ST _ t (SpSparse sp1) sp2 k =
- sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus ->
- k (SpSparse sp3)
- (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ))))
- (Inj $ \b -> EJust ext (inj2 b))
- (\a b ->
- elet b $
- emaybe (weakenExpr WSink a)
- (EJust ext (inj2 (evar IZ)))
- (EJust ext (plus (evar IZ) (evar (IS IZ)))))
+sparsePlusS ST _ t (SpSparse sp1) sp2 k
+ | Just zero2 <- cheapZero (applySparse sp2 t) =
+ sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus ->
+ k sp3
+ (Inj $ \a -> emaybe a (inj2 zero2) (inj1 (evar IZ)))
+ (Inj inj2)
+ (\a b ->
+ elet b $
+ emaybe (weakenExpr WSink a)
+ (inj2 (evar IZ))
+ (plus (evar IZ) (evar (IS IZ))))
+ | otherwise =
+ sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus ->
+ k (SpSparse sp3)
+ (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ))))
+ (Inj $ \b -> EJust ext (inj2 b))
+ (\a b ->
+ elet b $
+ emaybe (weakenExpr WSink a)
+ (EJust ext (inj2 (evar IZ)))
+ (EJust ext (plus (evar IZ) (evar (IS IZ)))))
sparsePlusS req1 req2 t sp1 (SpSparse sp2) k =
sparsePlusS req2 req1 t (SpSparse sp2) sp1 $ \sp3 inj1 inj2 plus ->
k sp3 inj2 inj1 (flip plus)