summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-06-18 10:29:16 +0200
committerTom Smeding <tom@tomsmeding.com>2025-06-18 10:29:16 +0200
commit5a0a5e9ef69926265289ae5229e68060a7c77a27 (patch)
treee2b39aea96f5bb0baadf187506a44d47ab1dfdd6
parentfe80b31555c27f038b20eb84eb1e747781d7c76b (diff)
Don't introduce sparsity if zero is cheap
-rw-r--r--src/AST/Sparse.hs35
-rw-r--r--src/CHAD.hs37
2 files changed, 55 insertions, 17 deletions
diff --git a/src/AST/Sparse.hs b/src/AST/Sparse.hs
index 0c5bdb0..34a398f 100644
--- a/src/AST/Sparse.hs
+++ b/src/AST/Sparse.hs
@@ -1,8 +1,9 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE ImpredicativeTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE RankNTypes #-}
-{-# LANGUAGE ScopedTypeVariables #-}
{-# OPTIONS_GHC -fmax-pmcheck-models=80 #-}
module AST.Sparse (module AST.Sparse, module AST.Sparse.Types) where
@@ -46,6 +47,24 @@ sparsePlus (SMTArr _ t) (SpArr sp) e1 e2 = ezipWith (sparsePlus t sp (evar (IS I
sparsePlus t@SMTScal{} SpScal e1 e2 = EPlus ext t e1 e2
+cheapZero :: SMTy t -> Maybe (forall env. Ex env t)
+cheapZero SMTNil = Just (ENil ext)
+cheapZero (SMTPair t1 t2)
+ | Just e1 <- cheapZero t1
+ , Just e2 <- cheapZero t2
+ = Just (EPair ext e1 e2)
+ | otherwise
+ = Nothing
+cheapZero (SMTLEither t1 t2) = Just (ELNil ext (fromSMTy t1) (fromSMTy t2))
+cheapZero (SMTMaybe t) = Just (ENothing ext (fromSMTy t))
+cheapZero SMTArr{} = Nothing
+cheapZero (SMTScal t) = case t of
+ STI32 -> Just (EConst ext t 0)
+ STI64 -> Just (EConst ext t 0)
+ STF32 -> Just (EConst ext t 0.0)
+ STF64 -> Just (EConst ext t 0.0)
+
+
data Injection sp a b where
-- | 'Inj' is purposefully also allowed when @sp@ is @False@ so that
-- 'sparsePlusS' can provide injections even if the caller doesn't require
@@ -149,12 +168,18 @@ sparsePlusS _ _ t sp1 sp2 k
-- handle absents
sparsePlusS SF _ _ SpAbsent sp2 k = k sp2 Noinj (Inj id) (\a b -> use a $ b)
-sparsePlusS ST _ t SpAbsent sp2 k =
- k (SpSparse sp2) (Inj $ \a -> use a $ ENothing ext (applySparse sp2 (fromSMTy t))) (Inj $ EJust ext) (\a b -> use a $ EJust ext b)
+sparsePlusS ST _ t SpAbsent sp2 k
+ | Just zero2 <- cheapZero (applySparse sp2 t) =
+ k sp2 (Inj $ \a -> use a $ zero2) (Inj id) (\a b -> use a $ b)
+ | otherwise =
+ k (SpSparse sp2) (Inj $ \a -> use a $ ENothing ext (applySparse sp2 (fromSMTy t))) (Inj $ EJust ext) (\a b -> use a $ EJust ext b)
sparsePlusS _ SF _ sp1 SpAbsent k = k sp1 (Inj id) Noinj (\a b -> use b $ a)
-sparsePlusS _ ST t sp1 SpAbsent k =
- k (SpSparse sp1) (Inj $ EJust ext) (Inj $ \b -> use b $ ENothing ext (applySparse sp1 (fromSMTy t))) (\a b -> use b $ EJust ext a)
+sparsePlusS _ ST t sp1 SpAbsent k
+ | Just zero1 <- cheapZero (applySparse sp1 t) =
+ k sp1 (Inj id) (Inj $ \b -> use b $ zero1) (\a b -> use b $ a)
+ | otherwise =
+ k (SpSparse sp1) (Inj $ EJust ext) (Inj $ \b -> use b $ ENothing ext (applySparse sp1 (fromSMTy t))) (\a b -> use b $ EJust ext a)
-- double sparse yields sparse
sparsePlusS _ _ t (SpSparse sp1) (SpSparse sp2) k =
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 9a08457..143376a 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -423,18 +423,31 @@ subenvPlus req1 SF (SCons _ env) (SEYes sp1 sub1) (SENo sub2) k =
EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ))
(weakenExpr WSink e2))
(ESnd ext (EVar ext (typeOf e1) IZ)))
-subenvPlus req1 ST (SCons t env) (SEYes sp1 sub1) (SENo sub2) k =
- subenvPlus req1 ST env sub1 sub2 $ \sub3 minj13 (Inj inj23) pl ->
- k (SEYes (SpSparse sp1) sub3)
- (withInj minj13 $ \inj13 ->
- \e1 -> eunPair e1 $ \_ e1a e1b ->
- EPair ext (inj13 e1a) (EJust ext e1b))
- (Inj $ \e2 -> EPair ext (inj23 e2) (ENothing ext (applySparse sp1 (fromSMTy t))))
- (\e1 e2 ->
- ELet ext e1 $
- EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ))
- (weakenExpr WSink e2))
- (EJust ext (ESnd ext (EVar ext (typeOf e1) IZ))))
+subenvPlus req1 ST (SCons t env) (SEYes sp1 sub1) (SENo sub2) k
+ | Just zero1 <- cheapZero (applySparse sp1 t) =
+ subenvPlus req1 ST env sub1 sub2 $ \sub3 minj13 (Inj inj23) pl ->
+ k (SEYes sp1 sub3)
+ (withInj minj13 $ \inj13 ->
+ \e1 -> eunPair e1 $ \_ e1a e1b ->
+ EPair ext (inj13 e1a) e1b)
+ (Inj $ \e2 -> EPair ext (inj23 e2) zero1)
+ (\e1 e2 ->
+ ELet ext e1 $
+ EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ))
+ (weakenExpr WSink e2))
+ (ESnd ext (EVar ext (typeOf e1) IZ)))
+ | otherwise =
+ subenvPlus req1 ST env sub1 sub2 $ \sub3 minj13 (Inj inj23) pl ->
+ k (SEYes (SpSparse sp1) sub3)
+ (withInj minj13 $ \inj13 ->
+ \e1 -> eunPair e1 $ \_ e1a e1b ->
+ EPair ext (inj13 e1a) (EJust ext e1b))
+ (Inj $ \e2 -> EPair ext (inj23 e2) (ENothing ext (applySparse sp1 (fromSMTy t))))
+ (\e1 e2 ->
+ ELet ext e1 $
+ EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ))
+ (weakenExpr WSink e2))
+ (EJust ext (ESnd ext (EVar ext (typeOf e1) IZ))))
subenvPlus req1 req2 (SCons t env) sub1@SENo{} sub2@SEYes{} k =
subenvPlus req2 req1 (SCons t env) sub2 sub1 $ \sub3 minj23 minj13 pl ->