summaryrefslogtreecommitdiff
path: root/src/AST/Sparse.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/AST/Sparse.hs')
-rw-r--r--src/AST/Sparse.hs190
1 files changed, 74 insertions, 116 deletions
diff --git a/src/AST/Sparse.hs b/src/AST/Sparse.hs
index ddae7fe..93258b7 100644
--- a/src/AST/Sparse.hs
+++ b/src/AST/Sparse.hs
@@ -1,126 +1,74 @@
+{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
-{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE ImpredicativeTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
-{-# LANGUAGE TypeFamilies #-}
-{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE RankNTypes #-}
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE ScopedTypeVariables #-}
{-# OPTIONS_GHC -fmax-pmcheck-models=80 #-}
-module AST.Sparse where
+module AST.Sparse (module AST.Sparse, module AST.Sparse.Types) where
-import Data.Kind (Constraint, Type)
import Data.Type.Equality
import AST
+import AST.Sparse.Types
+import Data (SBool(..))
-data Sparse t t' where
- SpSparse :: Sparse t t' -> Sparse t (TMaybe t')
- SpAbsent :: Sparse t TNil
-
- SpPair :: Sparse a a' -> Sparse b b' -> Sparse (TPair a b) (TPair a' b')
- SpLEither :: Sparse a a' -> Sparse b b' -> Sparse (TLEither a b) (TLEither a' b')
- SpMaybe :: Sparse t t' -> Sparse (TMaybe t) (TMaybe t')
- SpArr :: Sparse t t' -> Sparse (TArr n t) (TArr n t')
- SpScal :: Sparse (TScal t) (TScal t)
-deriving instance Show (Sparse t t')
-
-class ApplySparse f where
- applySparse :: Sparse t t' -> f t -> f t'
-
-instance ApplySparse STy where
- applySparse (SpSparse s) t = STMaybe (applySparse s t)
- applySparse SpAbsent _ = STNil
- applySparse (SpPair s1 s2) (STPair t1 t2) = STPair (applySparse s1 t1) (applySparse s2 t2)
- applySparse (SpLEither s1 s2) (STLEither t1 t2) = STLEither (applySparse s1 t1) (applySparse s2 t2)
- applySparse (SpMaybe s) (STMaybe t) = STMaybe (applySparse s t)
- applySparse (SpArr s) (STArr n t) = STArr n (applySparse s t)
- applySparse SpScal t = t
-
-instance ApplySparse SMTy where
- applySparse (SpSparse s) t = SMTMaybe (applySparse s t)
- applySparse SpAbsent _ = SMTNil
- applySparse (SpPair s1 s2) (SMTPair t1 t2) = SMTPair (applySparse s1 t1) (applySparse s2 t2)
- applySparse (SpLEither s1 s2) (SMTLEither t1 t2) = SMTLEither (applySparse s1 t1) (applySparse s2 t2)
- applySparse (SpMaybe s) (SMTMaybe t) = SMTMaybe (applySparse s t)
- applySparse (SpArr s) (SMTArr n t) = SMTArr n (applySparse s t)
- applySparse SpScal t = t
-
+sparsePlus :: SMTy t -> Sparse t t' -> Ex env t' -> Ex env t' -> Ex env t'
+sparsePlus _ SpAbsent e1 e2 = use e1 $ use e2 $ ENil ext
+sparsePlus t sp e1 e2 | Just Refl <- isDense t sp = EPlus ext t e1 e2
+sparsePlus t (SpSparse sp) e1 e2 = sparsePlus (SMTMaybe t) (SpMaybe sp) e1 e2 -- heh
+sparsePlus (SMTPair t1 t2) (SpPair sp1 sp2) e1 e2 =
+ eunPair e1 $ \w1 e1a e1b ->
+ eunPair (weakenExpr w1 e2) $ \w2 e2a e2b ->
+ EPair ext (sparsePlus t1 sp1 (weakenExpr w2 e1a) e2a)
+ (sparsePlus t2 sp2 (weakenExpr w2 e1b) e2b)
+sparsePlus (SMTLEither t1 t2) (SpLEither sp1 sp2) e1 e2 =
+ elet e2 $
+ elcase (weakenExpr WSink e1)
+ (evar IZ)
+ (elcase (evar (IS IZ))
+ (ELInl ext (applySparse sp2 (fromSMTy t2)) (evar IZ))
+ (ELInl ext (applySparse sp2 (fromSMTy t2)) (sparsePlus t1 sp1 (evar (IS IZ)) (evar IZ)))
+ (EError ext (fromSMTy (applySparse (SpLEither sp1 sp2) (SMTLEither t1 t2))) "splus ll+lr"))
+ (elcase (evar (IS IZ))
+ (ELInr ext (applySparse sp1 (fromSMTy t1)) (evar IZ))
+ (EError ext (fromSMTy (applySparse (SpLEither sp1 sp2) (SMTLEither t1 t2))) "splus lr+ll")
+ (ELInr ext (applySparse sp1 (fromSMTy t1)) (sparsePlus t2 sp2 (evar (IS IZ)) (evar IZ))))
+sparsePlus (SMTMaybe t) (SpMaybe sp) e1 e2 =
+ elet e2 $
+ emaybe (weakenExpr WSink e1)
+ (evar IZ)
+ (emaybe (evar (IS IZ))
+ (EJust ext (evar IZ))
+ (EJust ext (sparsePlus t sp (evar (IS IZ)) (evar IZ))))
+sparsePlus (SMTArr _ t) (SpArr sp) e1 e2 = ezipWith (sparsePlus t sp (evar (IS IZ)) (evar IZ)) e1 e2
+sparsePlus t@SMTScal{} SpScal e1 e2 = EPlus ext t e1 e2
-class IsSubType s where
- type IsSubTypeSubject (s :: k -> k -> Type) (f :: k -> Type) :: Constraint
- subtApply :: IsSubTypeSubject s f => s t t' -> f t -> f t'
- subtTrans :: s a b -> s b c -> s a c
- subtFull :: IsSubTypeSubject s f => f t -> s t t
-instance IsSubType (:~:) where
- type IsSubTypeSubject (:~:) f = ()
- subtApply = gcastWith
- subtTrans = trans
- subtFull _ = Refl
+cheapZero :: SMTy t -> Maybe (forall env. Ex env t)
+cheapZero SMTNil = Just (ENil ext)
+cheapZero (SMTPair t1 t2)
+ | Just e1 <- cheapZero t1
+ , Just e2 <- cheapZero t2
+ = Just (EPair ext e1 e2)
+ | otherwise
+ = Nothing
+cheapZero (SMTLEither t1 t2) = Just (ELNil ext (fromSMTy t1) (fromSMTy t2))
+cheapZero (SMTMaybe t) = Just (ENothing ext (fromSMTy t))
+cheapZero SMTArr{} = Nothing
+cheapZero (SMTScal t) = case t of
+ STI32 -> Just (EConst ext t 0)
+ STI64 -> Just (EConst ext t 0)
+ STF32 -> Just (EConst ext t 0.0)
+ STF64 -> Just (EConst ext t 0.0)
-instance IsSubType Sparse where
- type IsSubTypeSubject Sparse f = f ~ SMTy
- subtApply = applySparse
-
- subtTrans s1 (SpSparse s2) = SpSparse (subtTrans s1 s2)
- subtTrans _ SpAbsent = SpAbsent
- subtTrans (SpPair s1a s1b) (SpPair s2a s2b) = SpPair (subtTrans s1a s2a) (subtTrans s1b s2b)
- subtTrans (SpLEither s1a s1b) (SpLEither s2a s2b) = SpLEither (subtTrans s1a s2a) (subtTrans s1b s2b)
- subtTrans (SpSparse s1) (SpMaybe s2) = SpSparse (subtTrans s1 s2)
- subtTrans (SpMaybe s1) (SpMaybe s2) = SpMaybe (subtTrans s1 s2)
- subtTrans (SpArr s1) (SpArr s2) = SpArr (subtTrans s1 s2)
- subtTrans SpScal SpScal = SpScal
-
- subtFull = spDense
-
-spDense :: SMTy t -> Sparse t t
-spDense SMTNil = SpAbsent
-spDense (SMTPair t1 t2) = SpPair (spDense t1) (spDense t2)
-spDense (SMTLEither t1 t2) = SpLEither (spDense t1) (spDense t2)
-spDense (SMTMaybe t) = SpMaybe (spDense t)
-spDense (SMTArr _ t) = SpArr (spDense t)
-spDense (SMTScal _) = SpScal
-
-isDense :: SMTy t -> Sparse t t' -> Maybe (t :~: t')
-isDense SMTNil SpAbsent = Just Refl
-isDense _ SpSparse{} = Nothing
-isDense _ SpAbsent = Nothing
-isDense (SMTPair t1 t2) (SpPair s1 s2)
- | Just Refl <- isDense t1 s1, Just Refl <- isDense t2 s2 = Just Refl
- | otherwise = Nothing
-isDense (SMTLEither t1 t2) (SpLEither s1 s2)
- | Just Refl <- isDense t1 s1, Just Refl <- isDense t2 s2 = Just Refl
- | otherwise = Nothing
-isDense (SMTMaybe t) (SpMaybe s)
- | Just Refl <- isDense t s = Just Refl
- | otherwise = Nothing
-isDense (SMTArr _ t) (SpArr s)
- | Just Refl <- isDense t s = Just Refl
- | otherwise = Nothing
-isDense (SMTScal _) SpScal = Just Refl
-
-isAbsent :: Sparse t t' -> Bool
-isAbsent (SpSparse s) = isAbsent s
-isAbsent SpAbsent = True
-isAbsent (SpPair s1 s2) = isAbsent s1 && isAbsent s2
-isAbsent (SpLEither s1 s2) = isAbsent s1 && isAbsent s2
-isAbsent (SpMaybe s) = isAbsent s
-isAbsent (SpArr s) = isAbsent s
-isAbsent SpScal = False
-
-
-data SBool b where
- SF :: SBool False
- ST :: SBool True
-deriving instance Show (SBool b)
data Injection sp a b where
-- | 'Inj' is purposefully also allowed when @sp@ is @False@ so that
-- 'sparsePlusS' can provide injections even if the caller doesn't require
- -- them. This eliminates pointless checks.
+ -- them. This simplifies the sparsePlusS code.
Inj :: (forall e. Ex e a -> Ex e b) -> Injection sp a b
Noinj :: Injection False a b
@@ -137,8 +85,11 @@ withInj2 (Inj f) (Inj g) k = Inj (k f g)
withInj2 Noinj _ _ = Noinj
withInj2 _ Noinj _ = Noinj
+use :: Ex env a -> Ex env b -> Ex env b
+use a b = elet a $ weakenExpr WSink b
+
-- | This function produces quadratically-sized code in the presence of nested
--- dynamic sparsity. しょうがない。
+-- dynamic sparsity. TODO can this be improved?
sparsePlusS
:: SBool inj1 -> SBool inj2
-> SMTy t -> Sparse t t1 -> Sparse t t2
@@ -148,16 +99,17 @@ sparsePlusS
-> (forall e. Ex e t1 -> Ex e t2 -> Ex e t3)
-> r)
-> r
--- nil override
-sparsePlusS _ _ SMTNil _ _ k = k SpAbsent (Inj $ \_ -> ENil ext) (Inj $ \_ -> ENil ext) (\_ _ -> ENil ext)
+-- nil override (but don't destroy effects!)
+sparsePlusS _ _ SMTNil _ _ k =
+ k SpAbsent (Inj $ \a -> use a $ ENil ext) (Inj $ \b -> use b $ ENil ext) (\a b -> use a $ use b $ ENil ext)
-- simplifications
sparsePlusS req1 req2 t (SpSparse SpAbsent) sp2 k =
sparsePlusS req1 req2 t SpAbsent sp2 $ \sp3 minj1 minj2 plus ->
- k sp3 (withInj minj1 $ \inj1 -> \_ -> inj1 (ENil ext)) minj2 (\_ b -> plus (ENil ext) b)
+ k sp3 (withInj minj1 $ \inj1 -> \a -> use a $ inj1 (ENil ext)) minj2 (\a b -> use a $ plus (ENil ext) b)
sparsePlusS req1 req2 t sp1 (SpSparse SpAbsent) k =
sparsePlusS req1 req2 t sp1 SpAbsent $ \sp3 minj1 minj2 plus ->
- k sp3 minj1 (withInj minj2 $ \inj2 -> \_ -> inj2 (ENil ext)) (\a _ -> plus a (ENil ext))
+ k sp3 minj1 (withInj minj2 $ \inj2 -> \b -> use b $ inj2 (ENil ext)) (\a b -> use b $ plus a (ENil ext))
sparsePlusS req1 req2 t (SpSparse (SpSparse sp1)) sp2 k =
let ta = applySparse sp1 (fromSMTy t) in
@@ -215,13 +167,19 @@ sparsePlusS _ _ t sp1 sp2 k
= k (spDense t) (Inj id) (Inj id) (\a b -> EPlus ext t a b)
-- handle absents
-sparsePlusS SF _ _ SpAbsent sp2 k = k sp2 Noinj (Inj id) (\_ b -> b)
-sparsePlusS ST _ t SpAbsent sp2 k =
- k (SpSparse sp2) (Inj $ \_ -> ENothing ext (applySparse sp2 (fromSMTy t))) (Inj $ EJust ext) (\_ b -> EJust ext b)
+sparsePlusS SF _ _ SpAbsent sp2 k = k sp2 Noinj (Inj id) (\a b -> use a $ b)
+sparsePlusS ST _ t SpAbsent sp2 k
+ | Just zero2 <- cheapZero (applySparse sp2 t) =
+ k sp2 (Inj $ \a -> use a $ zero2) (Inj id) (\a b -> use a $ b)
+ | otherwise =
+ k (SpSparse sp2) (Inj $ \a -> use a $ ENothing ext (applySparse sp2 (fromSMTy t))) (Inj $ EJust ext) (\a b -> use a $ EJust ext b)
-sparsePlusS _ SF _ sp1 SpAbsent k = k sp1 (Inj id) Noinj (\a _ -> a)
-sparsePlusS _ ST t sp1 SpAbsent k =
- k (SpSparse sp1) (Inj $ EJust ext) (Inj $ \_ -> ENothing ext (applySparse sp1 (fromSMTy t))) (\a _ -> EJust ext a)
+sparsePlusS _ SF _ sp1 SpAbsent k = k sp1 (Inj id) Noinj (\a b -> use b $ a)
+sparsePlusS _ ST t sp1 SpAbsent k
+ | Just zero1 <- cheapZero (applySparse sp1 t) =
+ k sp1 (Inj id) (Inj $ \b -> use b $ zero1) (\a b -> use b $ a)
+ | otherwise =
+ k (SpSparse sp1) (Inj $ EJust ext) (Inj $ \b -> use b $ ENothing ext (applySparse sp1 (fromSMTy t))) (\a b -> use b $ EJust ext a)
-- double sparse yields sparse
sparsePlusS _ _ t (SpSparse sp1) (SpSparse sp2) k =