diff options
Diffstat (limited to 'src/CHAD/AST')
| -rw-r--r-- | src/CHAD/AST/Env.hs | 4 | ||||
| -rw-r--r-- | src/CHAD/AST/Sparse/Types.hs | 33 | ||||
| -rw-r--r-- | src/CHAD/AST/Types.hs | 59 | ||||
| -rw-r--r-- | src/CHAD/AST/UnMonoid.hs | 5 |
4 files changed, 87 insertions, 14 deletions
diff --git a/src/CHAD/AST/Env.hs b/src/CHAD/AST/Env.hs index 8e6b745..40b6ca2 100644 --- a/src/CHAD/AST/Env.hs +++ b/src/CHAD/AST/Env.hs @@ -53,9 +53,9 @@ subenvOnehot (SCons _ env) IZ sp = SEYes sp (subenvNone env) subenvOnehot (SCons _ env) (IS i) sp = SENo (subenvOnehot env i sp) subenvOnehot SNil i _ = case i of {} -subenvCompose :: IsSubType s => Subenv' s env1 env2 -> Subenv' s env2 env3 -> Subenv' s env1 env3 +subenvCompose :: IsSubType s => Subenv' (:~:) env1 env2 -> Subenv' s env2 env3 -> Subenv' s env1 env3 subenvCompose SETop SETop = SETop -subenvCompose (SEYes s1 sub1) (SEYes s2 sub2) = SEYes (subtTrans s1 s2) (subenvCompose sub1 sub2) +subenvCompose (SEYes Refl sub1) (SEYes s2 sub2) = SEYes s2 (subenvCompose sub1 sub2) subenvCompose (SEYes _ sub1) (SENo sub2) = SENo (subenvCompose sub1 sub2) subenvCompose (SENo sub1) sub2 = SENo (subenvCompose sub1 sub2) diff --git a/src/CHAD/AST/Sparse/Types.hs b/src/CHAD/AST/Sparse/Types.hs index 8f41ba4..930475a 100644 --- a/src/CHAD/AST/Sparse/Types.hs +++ b/src/CHAD/AST/Sparse/Types.hs @@ -1,14 +1,18 @@ +{-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} module CHAD.AST.Sparse.Types where import Data.Kind (Type, Constraint) import Data.Type.Equality import CHAD.AST.Types +import CHAD.Data data Sparse t t' where @@ -19,9 +23,22 @@ data Sparse t t' where SpLEither :: Sparse a a' -> Sparse b b' -> Sparse (TLEither a b) (TLEither a' b') SpMaybe :: Sparse t t' -> Sparse (TMaybe t) (TMaybe t') SpArr :: Sparse t t' -> Sparse (TArr n t) (TArr n t') + SpArrIdx :: SList (Sparse t) t's -> Sparse (TArr n t) (MultiHot n t's) SpScal :: Sparse (TScal t) (TScal t) deriving instance Show (Sparse t t') +type family MultiHot n t's where + MultiHot n '[] = TNil + MultiHot n (t' : t's) = TPair (TPair (Tup (Replicate n TIx)) t') (MultiHot n t's) + +tMultiHot :: SNat n -> SList STy ts -> STy (MultiHot n ts) +tMultiHot _ SNil = STNil +tMultiHot n (t `SCons` ts) = STPair (STPair (tTup (sreplicate n tIx)) t) (tMultiHot n ts) + +mtMultiHot :: SNat n -> SList SMTy ts -> SMTy (MultiHot n ts) +mtMultiHot _ SNil = SMTNil +mtMultiHot n (t `SCons` ts) = SMTPair (SMTPair (tTup (sreplicate n tIx)) t) (tMultiHot n ts) + class ApplySparse f where applySparse :: Sparse t t' -> f t -> f t' @@ -32,6 +49,7 @@ instance ApplySparse STy where applySparse (SpLEither s1 s2) (STLEither t1 t2) = STLEither (applySparse s1 t1) (applySparse s2 t2) applySparse (SpMaybe s) (STMaybe t) = STMaybe (applySparse s t) applySparse (SpArr s) (STArr n t) = STArr n (applySparse s t) + applySparse (SpArrIdx ss) (STArr n t) = tMultiHot n (slistMap (`applySparse` t) ss) applySparse SpScal t = t instance ApplySparse SMTy where @@ -41,34 +59,23 @@ instance ApplySparse SMTy where applySparse (SpLEither s1 s2) (SMTLEither t1 t2) = SMTLEither (applySparse s1 t1) (applySparse s2 t2) applySparse (SpMaybe s) (SMTMaybe t) = SMTMaybe (applySparse s t) applySparse (SpArr s) (SMTArr n t) = SMTArr n (applySparse s t) + applySparse (SpArrIdx s) (SMTArr n t) = SMTPair (mkTup SMTNil SMTPair (sreplicate n (knownMTy @TIx))) (applySparse s t) applySparse SpScal t = t class IsSubType s where type IsSubTypeSubject (s :: k -> k -> Type) (f :: k -> Type) :: Constraint subtApply :: IsSubTypeSubject s f => s t t' -> f t -> f t' - subtTrans :: s a b -> s b c -> s a c subtFull :: IsSubTypeSubject s f => f t -> s t t instance IsSubType (:~:) where type IsSubTypeSubject (:~:) f = () subtApply = gcastWith - subtTrans = trans subtFull _ = Refl instance IsSubType Sparse where type IsSubTypeSubject Sparse f = f ~ SMTy subtApply = applySparse - - subtTrans s1 (SpSparse s2) = SpSparse (subtTrans s1 s2) - subtTrans _ SpAbsent = SpAbsent - subtTrans (SpPair s1a s1b) (SpPair s2a s2b) = SpPair (subtTrans s1a s2a) (subtTrans s1b s2b) - subtTrans (SpLEither s1a s1b) (SpLEither s2a s2b) = SpLEither (subtTrans s1a s2a) (subtTrans s1b s2b) - subtTrans (SpSparse s1) (SpMaybe s2) = SpSparse (subtTrans s1 s2) - subtTrans (SpMaybe s1) (SpMaybe s2) = SpMaybe (subtTrans s1 s2) - subtTrans (SpArr s1) (SpArr s2) = SpArr (subtTrans s1 s2) - subtTrans SpScal SpScal = SpScal - subtFull = spDense spDense :: SMTy t -> Sparse t t @@ -95,6 +102,7 @@ isDense (SMTMaybe t) (SpMaybe s) isDense (SMTArr _ t) (SpArr s) | Just Refl <- isDense t s = Just Refl | otherwise = Nothing +isDense SMTArr{} SpArrIdx{} = Nothing isDense (SMTScal _) SpScal = Just Refl isAbsent :: Sparse t t' -> Bool @@ -104,4 +112,5 @@ isAbsent (SpPair s1 s2) = isAbsent s1 && isAbsent s2 isAbsent (SpLEither s1 s2) = isAbsent s1 && isAbsent s2 isAbsent (SpMaybe s) = isAbsent s isAbsent (SpArr s) = isAbsent s +isAbsent (SpArrIdx s) = isAbsent s isAbsent SpScal = False diff --git a/src/CHAD/AST/Types.hs b/src/CHAD/AST/Types.hs index f0feb55..bec2201 100644 --- a/src/CHAD/AST/Types.hs +++ b/src/CHAD/AST/Types.hs @@ -77,6 +77,7 @@ data SMTy t where SMTMaybe :: SMTy a -> SMTy (TMaybe a) SMTArr :: SNat n -> SMTy t -> SMTy (TArr n t) SMTScal :: ScalIsNumeric t ~ True => SScalTy t -> SMTy (TScal t) + SMTData :: STy a -> SMTy a -- ^ inclusion of non-monoidal information deriving instance Show (SMTy t) instance GCompare SMTy where @@ -215,3 +216,61 @@ unTup unpack (_ `SCons` list) tup = type family InvTup core env where InvTup core '[] = core InvTup core (t : ts) = InvTup (TPair core t) ts + +class KnownScalTy t where knownScalTy :: SScalTy t +instance KnownScalTy TI32 where knownScalTy = STI32 +instance KnownScalTy TI64 where knownScalTy = STI64 +instance KnownScalTy TF32 where knownScalTy = STF32 +instance KnownScalTy TF64 where knownScalTy = STF64 +instance KnownScalTy TBool where knownScalTy = STBool + +class KnownTy t where knownTy :: STy t +instance KnownTy TNil where knownTy = STNil +instance (KnownTy s, KnownTy t) => KnownTy (TPair s t) where knownTy = STPair knownTy knownTy +instance (KnownTy s, KnownTy t) => KnownTy (TEither s t) where knownTy = STEither knownTy knownTy +instance (KnownTy s, KnownTy t) => KnownTy (TLEither s t) where knownTy = STLEither knownTy knownTy +instance KnownTy t => KnownTy (TMaybe t) where knownTy = STMaybe knownTy +instance (KnownNat n, KnownTy t) => KnownTy (TArr n t) where knownTy = STArr knownNat knownTy +instance KnownScalTy t => KnownTy (TScal t) where knownTy = STScal knownScalTy +instance KnownMTy t => KnownTy (TAccum t) where knownTy = STAccum knownMTy + +class KnownMTy t where knownMTy :: SMTy t +instance KnownMTy TNil where knownMTy = SMTNil +instance (KnownMTy s, KnownMTy t) => KnownMTy (TPair s t) where knownMTy = SMTPair knownMTy knownMTy +instance KnownMTy t => KnownMTy (TMaybe t) where knownMTy = SMTMaybe knownMTy +instance (KnownMTy s, KnownMTy t) => KnownMTy (TLEither s t) where knownMTy = SMTLEither knownMTy knownMTy +instance (KnownNat n, KnownMTy t) => KnownMTy (TArr n t) where knownMTy = SMTArr knownNat knownMTy +instance (KnownScalTy t, ScalIsNumeric t ~ True) => KnownMTy (TScal t) where knownMTy = SMTScal knownScalTy + +class KnownEnv env where knownEnv :: SList STy env +instance KnownEnv '[] where knownEnv = SNil +instance (KnownTy t, KnownEnv env) => KnownEnv (t : env) where knownEnv = SCons knownTy knownEnv + +styKnown :: STy t -> Dict (KnownTy t) +styKnown STNil = Dict +styKnown (STPair a b) | Dict <- styKnown a, Dict <- styKnown b = Dict +styKnown (STEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict +styKnown (STLEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict +styKnown (STMaybe t) | Dict <- styKnown t = Dict +styKnown (STArr n t) | Dict <- snatKnown n, Dict <- styKnown t = Dict +styKnown (STScal t) | Dict <- sscaltyKnown t = Dict +styKnown (STAccum t) | Dict <- smtyKnown t = Dict + +smtyKnown :: SMTy t -> Dict (KnownMTy t) +smtyKnown SMTNil = Dict +smtyKnown (SMTPair a b) | Dict <- smtyKnown a, Dict <- smtyKnown b = Dict +smtyKnown (SMTLEither a b) | Dict <- smtyKnown a, Dict <- smtyKnown b = Dict +smtyKnown (SMTMaybe t) | Dict <- smtyKnown t = Dict +smtyKnown (SMTArr n t) | Dict <- snatKnown n, Dict <- smtyKnown t = Dict +smtyKnown (SMTScal t) | Dict <- sscaltyKnown t = Dict + +sscaltyKnown :: SScalTy t -> Dict (KnownScalTy t) +sscaltyKnown STI32 = Dict +sscaltyKnown STI64 = Dict +sscaltyKnown STF32 = Dict +sscaltyKnown STF64 = Dict +sscaltyKnown STBool = Dict + +envKnown :: SList STy env -> Dict (KnownEnv env) +envKnown SNil = Dict +envKnown (t `SCons` env) | Dict <- styKnown t, Dict <- envKnown env = Dict diff --git a/src/CHAD/AST/UnMonoid.hs b/src/CHAD/AST/UnMonoid.hs index d3cad25..1a66cdf 100644 --- a/src/CHAD/AST/UnMonoid.hs +++ b/src/CHAD/AST/UnMonoid.hs @@ -212,6 +212,11 @@ accumulateSparse topty topsp arg accum = case (topty, topsp) of (EIdx ext (evar (IS IZ)) (EVar ext tn IZ)) (\w prj idx val -> accum (WPop (WPop w)) (SAPArrIdx prj) (EPair ext (EVar ext tn (w @> IZ)) idx) val)) $ ENil ext + (SMTArr _ t, SpArrIdx s) -> + eunPair arg $ \w1 e1 e2 -> + elet (accumulateSparse t s e2 + (\w prj idx val -> accum (w .> w1) (SAPArrIdx prj) (EPair ext (weakenExpr w e1) idx) val)) $ + ENil ext acPrjCompose :: SAIDense dense |
