summaryrefslogtreecommitdiff
path: root/src/AST/Sparse.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/AST/Sparse.hs')
-rw-r--r--src/AST/Sparse.hs244
1 files changed, 71 insertions, 173 deletions
diff --git a/src/AST/Sparse.hs b/src/AST/Sparse.hs
index 09dbc70..ddae7fe 100644
--- a/src/AST/Sparse.hs
+++ b/src/AST/Sparse.hs
@@ -7,7 +7,7 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
-{-# OPTIONS_GHC -fmax-pmcheck-models=60 #-}
+{-# OPTIONS_GHC -fmax-pmcheck-models=80 #-}
module AST.Sparse where
import Data.Kind (Constraint, Type)
@@ -17,66 +17,99 @@ import AST
data Sparse t t' where
- SpDense :: Sparse t t
SpSparse :: Sparse t t' -> Sparse t (TMaybe t')
SpAbsent :: Sparse t TNil
SpPair :: Sparse a a' -> Sparse b b' -> Sparse (TPair a b) (TPair a' b')
SpLEither :: Sparse a a' -> Sparse b b' -> Sparse (TLEither a b) (TLEither a' b')
- SpLeft :: Sparse a a' -> Sparse (TLEither a b) a'
- SpRight :: Sparse b b' -> Sparse (TLEither a b) b'
SpMaybe :: Sparse t t' -> Sparse (TMaybe t) (TMaybe t')
- SpJust :: Sparse t t' -> Sparse (TMaybe t) t'
SpArr :: Sparse t t' -> Sparse (TArr n t) (TArr n t')
+ SpScal :: Sparse (TScal t) (TScal t)
deriving instance Show (Sparse t t')
-applySparse :: Sparse t t' -> STy t -> STy t'
-applySparse SpDense t = t
-applySparse (SpSparse s) t = STMaybe (applySparse s t)
-applySparse SpAbsent _ = STNil
-applySparse (SpPair s1 s2) (STPair t1 t2) = STPair (applySparse s1 t1) (applySparse s2 t2)
-applySparse (SpLEither s1 s2) (STLEither t1 t2) = STLEither (applySparse s1 t1) (applySparse s2 t2)
-applySparse (SpLeft s) (STLEither t1 _) = applySparse s t1
-applySparse (SpRight s) (STLEither _ t2) = applySparse s t2
-applySparse (SpMaybe s) (STMaybe t) = STMaybe (applySparse s t)
-applySparse (SpJust s) (STMaybe t) = applySparse s t
-applySparse (SpArr s) (STArr n t) = STArr n (applySparse s t)
+class ApplySparse f where
+ applySparse :: Sparse t t' -> f t -> f t'
+
+instance ApplySparse STy where
+ applySparse (SpSparse s) t = STMaybe (applySparse s t)
+ applySparse SpAbsent _ = STNil
+ applySparse (SpPair s1 s2) (STPair t1 t2) = STPair (applySparse s1 t1) (applySparse s2 t2)
+ applySparse (SpLEither s1 s2) (STLEither t1 t2) = STLEither (applySparse s1 t1) (applySparse s2 t2)
+ applySparse (SpMaybe s) (STMaybe t) = STMaybe (applySparse s t)
+ applySparse (SpArr s) (STArr n t) = STArr n (applySparse s t)
+ applySparse SpScal t = t
+
+instance ApplySparse SMTy where
+ applySparse (SpSparse s) t = SMTMaybe (applySparse s t)
+ applySparse SpAbsent _ = SMTNil
+ applySparse (SpPair s1 s2) (SMTPair t1 t2) = SMTPair (applySparse s1 t1) (applySparse s2 t2)
+ applySparse (SpLEither s1 s2) (SMTLEither t1 t2) = SMTLEither (applySparse s1 t1) (applySparse s2 t2)
+ applySparse (SpMaybe s) (SMTMaybe t) = SMTMaybe (applySparse s t)
+ applySparse (SpArr s) (SMTArr n t) = SMTArr n (applySparse s t)
+ applySparse SpScal t = t
class IsSubType s where
- type IsSubTypeSubject s (f :: k -> Type) :: Constraint
+ type IsSubTypeSubject (s :: k -> k -> Type) (f :: k -> Type) :: Constraint
subtApply :: IsSubTypeSubject s f => s t t' -> f t -> f t'
subtTrans :: s a b -> s b c -> s a c
- subtFull :: s a a
+ subtFull :: IsSubTypeSubject s f => f t -> s t t
instance IsSubType (:~:) where
type IsSubTypeSubject (:~:) f = ()
subtApply = gcastWith
subtTrans = trans
- subtFull = Refl
+ subtFull _ = Refl
instance IsSubType Sparse where
- type IsSubTypeSubject Sparse f = f ~ STy
+ type IsSubTypeSubject Sparse f = f ~ SMTy
subtApply = applySparse
- subtTrans SpDense s = s
- subtTrans s SpDense = s
subtTrans s1 (SpSparse s2) = SpSparse (subtTrans s1 s2)
subtTrans _ SpAbsent = SpAbsent
subtTrans (SpPair s1a s1b) (SpPair s2a s2b) = SpPair (subtTrans s1a s2a) (subtTrans s1b s2b)
subtTrans (SpLEither s1a s1b) (SpLEither s2a s2b) = SpLEither (subtTrans s1a s2a) (subtTrans s1b s2b)
- subtTrans (SpLEither s1 _) (SpLeft s2) = SpLeft (subtTrans s1 s2)
- subtTrans (SpLEither _ s1) (SpRight s2) = SpRight (subtTrans s1 s2)
- subtTrans (SpLeft s1) s2 = SpLeft (subtTrans s1 s2)
- subtTrans (SpRight s1) s2 = SpRight (subtTrans s1 s2)
subtTrans (SpSparse s1) (SpMaybe s2) = SpSparse (subtTrans s1 s2)
- subtTrans (SpSparse s1) (SpJust s2) = subtTrans s1 s2
subtTrans (SpMaybe s1) (SpMaybe s2) = SpMaybe (subtTrans s1 s2)
- subtTrans (SpMaybe s1) (SpJust s2) = SpJust (subtTrans s1 s2)
- subtTrans (SpJust s1) s2 = SpJust (subtTrans s1 s2)
subtTrans (SpArr s1) (SpArr s2) = SpArr (subtTrans s1 s2)
-
- subtFull = SpDense
+ subtTrans SpScal SpScal = SpScal
+
+ subtFull = spDense
+
+spDense :: SMTy t -> Sparse t t
+spDense SMTNil = SpAbsent
+spDense (SMTPair t1 t2) = SpPair (spDense t1) (spDense t2)
+spDense (SMTLEither t1 t2) = SpLEither (spDense t1) (spDense t2)
+spDense (SMTMaybe t) = SpMaybe (spDense t)
+spDense (SMTArr _ t) = SpArr (spDense t)
+spDense (SMTScal _) = SpScal
+
+isDense :: SMTy t -> Sparse t t' -> Maybe (t :~: t')
+isDense SMTNil SpAbsent = Just Refl
+isDense _ SpSparse{} = Nothing
+isDense _ SpAbsent = Nothing
+isDense (SMTPair t1 t2) (SpPair s1 s2)
+ | Just Refl <- isDense t1 s1, Just Refl <- isDense t2 s2 = Just Refl
+ | otherwise = Nothing
+isDense (SMTLEither t1 t2) (SpLEither s1 s2)
+ | Just Refl <- isDense t1 s1, Just Refl <- isDense t2 s2 = Just Refl
+ | otherwise = Nothing
+isDense (SMTMaybe t) (SpMaybe s)
+ | Just Refl <- isDense t s = Just Refl
+ | otherwise = Nothing
+isDense (SMTArr _ t) (SpArr s)
+ | Just Refl <- isDense t s = Just Refl
+ | otherwise = Nothing
+isDense (SMTScal _) SpScal = Just Refl
+
+isAbsent :: Sparse t t' -> Bool
+isAbsent (SpSparse s) = isAbsent s
+isAbsent SpAbsent = True
+isAbsent (SpPair s1 s2) = isAbsent s1 && isAbsent s2
+isAbsent (SpLEither s1 s2) = isAbsent s1 && isAbsent s2
+isAbsent (SpMaybe s) = isAbsent s
+isAbsent (SpArr s) = isAbsent s
+isAbsent SpScal = False
data SBool b where
@@ -176,7 +209,10 @@ sparsePlusS req1 req2 t sp1 (SpMaybe (SpSparse sp2)) k = sparsePlusS req1 req2 t
-- TODO: sparse of Just is just Maybe
-- dense plus
-sparsePlusS _ _ t SpDense SpDense k = k SpDense (Inj id) (Inj id) (\a b -> EPlus ext t a b)
+sparsePlusS _ _ t sp1 sp2 k
+ | Just Refl <- isDense t sp1
+ , Just Refl <- isDense t sp2
+ = k (spDense t) (Inj id) (Inj id) (\a b -> EPlus ext t a b)
-- handle absents
sparsePlusS SF _ _ SpAbsent sp2 k = k sp2 Noinj (Inj id) (\_ b -> b)
@@ -239,8 +275,6 @@ sparsePlusS req1 req2 (SMTPair ta tb) (SpPair sp1a sp1b) (SpPair sp2a sp2b) k =
eunPair x1 $ \w1 x1a x1b ->
eunPair (weakenExpr w1 x2) $ \w2 x2a x2b ->
EPair ext (plusa (weakenExpr w2 x1a) x2a) (plusb (weakenExpr w2 x1b) x2b))
-sparsePlusS req1 req2 t sp1@SpPair{} SpDense k = sparsePlusS req1 req2 t sp1 (SpPair SpDense SpDense) k
-sparsePlusS req1 req2 t SpDense sp2@SpPair{} k = sparsePlusS req1 req2 t (SpPair SpDense SpDense) sp2 k
-- coproducts
sparsePlusS _ _ (SMTLEither ta tb) (SpLEither sp1a sp1b) (SpLEither sp2a sp2b) k =
@@ -268,107 +302,6 @@ sparsePlusS _ _ (SMTLEither ta tb) (SpLEither sp1a sp1b) (SpLEither sp2a sp2b) k
(inr (inj13b (evar IZ)))
(EError ext (applySparse (SpLEither sp3a sp3b) (fromSMTy (SMTLEither ta tb))) "plusS lr+ll")
(inr (plusb (evar (IS IZ)) (evar IZ)))))
-sparsePlusS req1 req2 t sp1@SpLEither{} SpDense k = sparsePlusS req1 req2 t sp1 (SpLEither SpDense SpDense) k
-sparsePlusS req1 req2 t SpDense sp2@SpLEither{} k = sparsePlusS req1 req2 t (SpLEither SpDense SpDense) sp2 k
-
--- coproducts with partially known arguments: if we have a non-nil
--- always-present coproduct argument, the result is dense, otherwise we
--- introduce sparsity
-sparsePlusS _ SF (SMTLEither ta _) (SpLeft sp1a) (SpLEither sp2a _) k =
- sparsePlusS ST SF ta sp1a sp2a $ \sp3a (Inj inj13a) _ plusa ->
- k (SpLeft sp3a)
- (Inj inj13a)
- Noinj
- (\x1 x2 ->
- elet x1 $
- elcase (weakenExpr WSink x2)
- (inj13a (evar IZ))
- (plusa (evar (IS IZ)) (evar IZ))
- (EError ext (applySparse sp3a (fromSMTy ta)) "plusS !ll+lr"))
-
-sparsePlusS _ ST (SMTLEither ta _) (SpLeft sp1a) (SpLEither sp2a _) k =
- sparsePlusS ST ST ta sp1a sp2a $ \sp3a (Inj inj13a) (Inj inj23a) plusa ->
- k (SpSparse (SpLeft sp3a))
- (Inj $ \x1 -> EJust ext (inj13a x1))
- (Inj $ \x2 ->
- elcase x2
- (ENothing ext (applySparse sp3a (fromSMTy ta)))
- (EJust ext (inj23a (evar IZ)))
- (EError ext (STMaybe (applySparse sp3a (fromSMTy ta))) "plusSi2 !ll+lr"))
- (\x1 x2 ->
- elet x1 $
- EJust ext $
- elcase (weakenExpr WSink x2)
- (inj13a (evar IZ))
- (plusa (evar (IS IZ)) (evar IZ))
- (EError ext (applySparse sp3a (fromSMTy ta)) "plusS !ll+lr"))
-
-sparsePlusS req1 req2 t sp1@SpLEither{} sp2@SpLeft{} k =
- sparsePlusS req2 req1 t sp2 sp1 $ \sp3a inj13a inj23a plusa -> k sp3a inj23a inj13a (flip plusa)
-sparsePlusS req1 req2 t sp1@SpLeft{} SpDense k = sparsePlusS req1 req2 t sp1 (SpLEither SpDense SpDense) k
-sparsePlusS req1 req2 t SpDense sp2@SpLeft{} k = sparsePlusS req1 req2 t (SpLEither SpDense SpDense) sp2 k
-
-sparsePlusS _ SF (SMTLEither _ tb) (SpRight sp1b) (SpLEither _ sp2b) k =
- sparsePlusS ST SF tb sp1b sp2b $ \sp3b (Inj inj13b) _ plusb ->
- k (SpRight sp3b)
- (Inj inj13b)
- Noinj
- (\x1 x2 ->
- elet x1 $
- elcase (weakenExpr WSink x2)
- (inj13b (evar IZ))
- (EError ext (applySparse sp3b (fromSMTy tb)) "plusS !lr+ll")
- (plusb (evar (IS IZ)) (evar IZ)))
-
-sparsePlusS _ ST (SMTLEither _ tb) (SpRight sp1b) (SpLEither _ sp2b) k =
- sparsePlusS ST ST tb sp1b sp2b $ \sp3b (Inj inj13b) (Inj inj23b) plusb ->
- k (SpSparse (SpRight sp3b))
- (Inj $ \x1 -> EJust ext (inj13b x1))
- (Inj $ \x2 ->
- elcase x2
- (ENothing ext (applySparse sp3b (fromSMTy tb)))
- (EError ext (STMaybe (applySparse sp3b (fromSMTy tb))) "plusSi2 !lr+ll")
- (EJust ext (inj23b (evar IZ))))
- (\x1 x2 ->
- elet x1 $
- EJust ext $
- elcase (weakenExpr WSink x2)
- (inj13b (evar IZ))
- (EError ext (applySparse sp3b (fromSMTy tb)) "plusS !lr+ll")
- (plusb (evar (IS IZ)) (evar IZ)))
-
-sparsePlusS req1 req2 t sp1@SpLEither{} sp2@SpRight{} k =
- sparsePlusS req2 req1 t sp2 sp1 $ \sp3b inj13b inj23b plusb -> k sp3b inj23b inj13b (flip plusb)
-sparsePlusS req1 req2 t sp1@SpRight{} SpDense k = sparsePlusS req1 req2 t sp1 (SpLEither SpDense SpDense) k
-sparsePlusS req1 req2 t SpDense sp2@SpRight{} k = sparsePlusS req1 req2 t (SpLEither SpDense SpDense) sp2 k
-
--- dense same-branch coproducts simply recurse
-sparsePlusS req1 req2 (SMTLEither ta _) (SpLeft sp1) (SpLeft sp2) k =
- sparsePlusS req1 req2 ta sp1 sp2 $ \sp3 inj1 inj2 plus ->
- k (SpLeft sp3) inj1 inj2 plus
-sparsePlusS req1 req2 (SMTLEither _ tb) (SpRight sp1) (SpRight sp2) k =
- sparsePlusS req1 req2 tb sp1 sp2 $ \sp3 inj1 inj2 plus ->
- k (SpRight sp3) inj1 inj2 plus
-
--- dense, mismatched coproducts are valid as long as we don't actually invoke
--- plus at runtime (injections are fine)
-sparsePlusS SF SF _ SpLeft{} SpRight{} k =
- k SpAbsent Noinj Noinj (\_ _ -> EError ext STNil "plusS !ll+!lr")
-sparsePlusS SF ST (SMTLEither _ tb) SpLeft{} (SpRight sp2) k =
- k (SpRight sp2) Noinj (Inj id)
- (\_ _ -> EError ext (applySparse sp2 (fromSMTy tb)) "plusS !ll+?lr")
-sparsePlusS ST SF (SMTLEither ta _) (SpLeft sp1) SpRight{} k =
- k (SpLeft sp1) (Inj id) Noinj
- (\_ _ -> EError ext (applySparse sp1 (fromSMTy ta)) "plusS !lr+?ll")
-sparsePlusS ST ST (SMTLEither ta tb) (SpLeft sp1) (SpRight sp2) k =
- -- note: we know that this cannot be ELNil, but the returned 'Sparse' unfortunately claims to allow it.
- k (SpLEither sp1 sp2)
- (Inj $ \a -> ELInl ext (applySparse sp2 (fromSMTy tb)) a)
- (Inj $ \b -> ELInr ext (applySparse sp1 (fromSMTy ta)) b)
- (\_ _ -> EError ext (STLEither (applySparse sp1 (fromSMTy ta)) (applySparse sp2 (fromSMTy tb))) "plusS ?ll+?lr")
-
-sparsePlusS req1 req2 t sp1@SpRight{} sp2@SpLeft{} k = -- the errors are not flipped, but eh
- sparsePlusS req2 req1 t sp2 sp1 $ \sp3 inj1 inj2 plus -> k sp3 inj2 inj1 (flip plus)
-- maybe
sparsePlusS _ _ (SMTMaybe t) (SpMaybe sp1) (SpMaybe sp2) k =
@@ -385,42 +318,6 @@ sparsePlusS _ _ (SMTMaybe t) (SpMaybe sp1) (SpMaybe sp2) k =
(emaybe (evar (IS IZ))
(EJust ext (inj1 (evar IZ)))
(EJust ext (plus (evar (IS IZ)) (evar IZ)))))
-sparsePlusS req1 req2 t sp1@SpMaybe{} SpDense k = sparsePlusS req1 req2 t sp1 (SpMaybe SpDense) k
-sparsePlusS req1 req2 t SpDense sp2@SpMaybe{} k = sparsePlusS req1 req2 t (SpMaybe SpDense) sp2 k
-
--- maybe with partially known arguments: if we have an always-present Just
--- argument, the result is dense, otherwise we introduce sparsity by weakening
--- to SpMaybe
-sparsePlusS _ SF (SMTMaybe t) (SpJust sp1) (SpMaybe sp2) k =
- sparsePlusS ST SF t sp1 sp2 $ \sp3 (Inj inj1) _ plus ->
- k (SpJust sp3)
- (Inj inj1)
- Noinj
- (\a b ->
- elet a $
- emaybe (weakenExpr WSink b)
- (inj1 (evar IZ))
- (plus (evar (IS IZ)) (evar IZ)))
-sparsePlusS _ ST (SMTMaybe t) (SpJust sp1) (SpMaybe sp2) k =
- sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus ->
- k (SpMaybe sp3)
- (Inj $ \a -> EJust ext (inj1 a))
- (Inj $ \b -> emaybe b (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj2 (evar IZ))))
- (\a b ->
- elet a $
- emaybe (weakenExpr WSink b)
- (EJust ext (inj1 (evar IZ)))
- (EJust ext (plus (evar (IS IZ)) (evar IZ))))
-
-sparsePlusS req1 req2 t sp1@SpMaybe{} sp2@SpJust{} k =
- sparsePlusS req2 req1 t sp2 sp1 $ \sp3 inj2 inj1 plus -> k sp3 inj1 inj2 (flip plus)
-sparsePlusS req1 req2 t sp1@SpJust{} SpDense k = sparsePlusS req1 req2 t sp1 (SpMaybe SpDense) k
-sparsePlusS req1 req2 t SpDense sp2@SpJust{} k = sparsePlusS req1 req2 t (SpMaybe SpDense) sp2 k
-
--- dense same-branch maybes simply recurse
-sparsePlusS req1 req2 (SMTMaybe t) (SpJust sp1) (SpJust sp2) k =
- sparsePlusS req1 req2 t sp1 sp2 $ \sp3 inj1 inj2 plus ->
- k (SpJust sp3) inj1 inj2 plus
-- dense array cotangents simply recurse
sparsePlusS req1 req2 (SMTArr _ t) (SpArr sp1) (SpArr sp2) k =
@@ -430,5 +327,6 @@ sparsePlusS req1 req2 (SMTArr _ t) (SpArr sp1) (SpArr sp2) k =
(withInj minj2 $ \inj2 -> emap (inj2 (EVar ext (applySparse sp2 (fromSMTy t)) IZ)))
(ezipWith (plus (EVar ext (applySparse sp1 (fromSMTy t)) (IS IZ))
(EVar ext (applySparse sp2 (fromSMTy t)) IZ)))
-sparsePlusS req1 req2 t (SpArr sp1) SpDense k = sparsePlusS req1 req2 t (SpArr sp1) (SpArr SpDense) k
-sparsePlusS req1 req2 t SpDense (SpArr sp2) k = sparsePlusS req1 req2 t (SpArr SpDense) (SpArr sp2) k
+
+-- scalars
+sparsePlusS _ _ (SMTScal t) SpScal SpScal k = k SpScal (Inj id) (Inj id) (EPlus ext (SMTScal t))