diff options
Diffstat (limited to 'src/AST/Sparse.hs')
-rw-r--r-- | src/AST/Sparse.hs | 35 |
1 files changed, 30 insertions, 5 deletions
diff --git a/src/AST/Sparse.hs b/src/AST/Sparse.hs index 0c5bdb0..34a398f 100644 --- a/src/AST/Sparse.hs +++ b/src/AST/Sparse.hs @@ -1,8 +1,9 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE ImpredicativeTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} {-# OPTIONS_GHC -fmax-pmcheck-models=80 #-} module AST.Sparse (module AST.Sparse, module AST.Sparse.Types) where @@ -46,6 +47,24 @@ sparsePlus (SMTArr _ t) (SpArr sp) e1 e2 = ezipWith (sparsePlus t sp (evar (IS I sparsePlus t@SMTScal{} SpScal e1 e2 = EPlus ext t e1 e2 +cheapZero :: SMTy t -> Maybe (forall env. Ex env t) +cheapZero SMTNil = Just (ENil ext) +cheapZero (SMTPair t1 t2) + | Just e1 <- cheapZero t1 + , Just e2 <- cheapZero t2 + = Just (EPair ext e1 e2) + | otherwise + = Nothing +cheapZero (SMTLEither t1 t2) = Just (ELNil ext (fromSMTy t1) (fromSMTy t2)) +cheapZero (SMTMaybe t) = Just (ENothing ext (fromSMTy t)) +cheapZero SMTArr{} = Nothing +cheapZero (SMTScal t) = case t of + STI32 -> Just (EConst ext t 0) + STI64 -> Just (EConst ext t 0) + STF32 -> Just (EConst ext t 0.0) + STF64 -> Just (EConst ext t 0.0) + + data Injection sp a b where -- | 'Inj' is purposefully also allowed when @sp@ is @False@ so that -- 'sparsePlusS' can provide injections even if the caller doesn't require @@ -149,12 +168,18 @@ sparsePlusS _ _ t sp1 sp2 k -- handle absents sparsePlusS SF _ _ SpAbsent sp2 k = k sp2 Noinj (Inj id) (\a b -> use a $ b) -sparsePlusS ST _ t SpAbsent sp2 k = - k (SpSparse sp2) (Inj $ \a -> use a $ ENothing ext (applySparse sp2 (fromSMTy t))) (Inj $ EJust ext) (\a b -> use a $ EJust ext b) +sparsePlusS ST _ t SpAbsent sp2 k + | Just zero2 <- cheapZero (applySparse sp2 t) = + k sp2 (Inj $ \a -> use a $ zero2) (Inj id) (\a b -> use a $ b) + | otherwise = + k (SpSparse sp2) (Inj $ \a -> use a $ ENothing ext (applySparse sp2 (fromSMTy t))) (Inj $ EJust ext) (\a b -> use a $ EJust ext b) sparsePlusS _ SF _ sp1 SpAbsent k = k sp1 (Inj id) Noinj (\a b -> use b $ a) -sparsePlusS _ ST t sp1 SpAbsent k = - k (SpSparse sp1) (Inj $ EJust ext) (Inj $ \b -> use b $ ENothing ext (applySparse sp1 (fromSMTy t))) (\a b -> use b $ EJust ext a) +sparsePlusS _ ST t sp1 SpAbsent k + | Just zero1 <- cheapZero (applySparse sp1 t) = + k sp1 (Inj id) (Inj $ \b -> use b $ zero1) (\a b -> use b $ a) + | otherwise = + k (SpSparse sp1) (Inj $ EJust ext) (Inj $ \b -> use b $ ENothing ext (applySparse sp1 (fromSMTy t))) (\a b -> use b $ EJust ext a) -- double sparse yields sparse sparsePlusS _ _ t (SpSparse sp1) (SpSparse sp2) k = |