summaryrefslogtreecommitdiff
path: root/src/AST/Sparse.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/AST/Sparse.hs')
-rw-r--r--src/AST/Sparse.hs35
1 files changed, 30 insertions, 5 deletions
diff --git a/src/AST/Sparse.hs b/src/AST/Sparse.hs
index 0c5bdb0..34a398f 100644
--- a/src/AST/Sparse.hs
+++ b/src/AST/Sparse.hs
@@ -1,8 +1,9 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE ImpredicativeTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE RankNTypes #-}
-{-# LANGUAGE ScopedTypeVariables #-}
{-# OPTIONS_GHC -fmax-pmcheck-models=80 #-}
module AST.Sparse (module AST.Sparse, module AST.Sparse.Types) where
@@ -46,6 +47,24 @@ sparsePlus (SMTArr _ t) (SpArr sp) e1 e2 = ezipWith (sparsePlus t sp (evar (IS I
sparsePlus t@SMTScal{} SpScal e1 e2 = EPlus ext t e1 e2
+cheapZero :: SMTy t -> Maybe (forall env. Ex env t)
+cheapZero SMTNil = Just (ENil ext)
+cheapZero (SMTPair t1 t2)
+ | Just e1 <- cheapZero t1
+ , Just e2 <- cheapZero t2
+ = Just (EPair ext e1 e2)
+ | otherwise
+ = Nothing
+cheapZero (SMTLEither t1 t2) = Just (ELNil ext (fromSMTy t1) (fromSMTy t2))
+cheapZero (SMTMaybe t) = Just (ENothing ext (fromSMTy t))
+cheapZero SMTArr{} = Nothing
+cheapZero (SMTScal t) = case t of
+ STI32 -> Just (EConst ext t 0)
+ STI64 -> Just (EConst ext t 0)
+ STF32 -> Just (EConst ext t 0.0)
+ STF64 -> Just (EConst ext t 0.0)
+
+
data Injection sp a b where
-- | 'Inj' is purposefully also allowed when @sp@ is @False@ so that
-- 'sparsePlusS' can provide injections even if the caller doesn't require
@@ -149,12 +168,18 @@ sparsePlusS _ _ t sp1 sp2 k
-- handle absents
sparsePlusS SF _ _ SpAbsent sp2 k = k sp2 Noinj (Inj id) (\a b -> use a $ b)
-sparsePlusS ST _ t SpAbsent sp2 k =
- k (SpSparse sp2) (Inj $ \a -> use a $ ENothing ext (applySparse sp2 (fromSMTy t))) (Inj $ EJust ext) (\a b -> use a $ EJust ext b)
+sparsePlusS ST _ t SpAbsent sp2 k
+ | Just zero2 <- cheapZero (applySparse sp2 t) =
+ k sp2 (Inj $ \a -> use a $ zero2) (Inj id) (\a b -> use a $ b)
+ | otherwise =
+ k (SpSparse sp2) (Inj $ \a -> use a $ ENothing ext (applySparse sp2 (fromSMTy t))) (Inj $ EJust ext) (\a b -> use a $ EJust ext b)
sparsePlusS _ SF _ sp1 SpAbsent k = k sp1 (Inj id) Noinj (\a b -> use b $ a)
-sparsePlusS _ ST t sp1 SpAbsent k =
- k (SpSparse sp1) (Inj $ EJust ext) (Inj $ \b -> use b $ ENothing ext (applySparse sp1 (fromSMTy t))) (\a b -> use b $ EJust ext a)
+sparsePlusS _ ST t sp1 SpAbsent k
+ | Just zero1 <- cheapZero (applySparse sp1 t) =
+ k sp1 (Inj id) (Inj $ \b -> use b $ zero1) (\a b -> use b $ a)
+ | otherwise =
+ k (SpSparse sp1) (Inj $ EJust ext) (Inj $ \b -> use b $ ENothing ext (applySparse sp1 (fromSMTy t))) (\a b -> use b $ EJust ext a)
-- double sparse yields sparse
sparsePlusS _ _ t (SpSparse sp1) (SpSparse sp2) k =