aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD/AST/Sparse
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD/AST/Sparse')
-rw-r--r--src/CHAD/AST/Sparse/Types.hs33
1 files changed, 21 insertions, 12 deletions
diff --git a/src/CHAD/AST/Sparse/Types.hs b/src/CHAD/AST/Sparse/Types.hs
index 8f41ba4..930475a 100644
--- a/src/CHAD/AST/Sparse/Types.hs
+++ b/src/CHAD/AST/Sparse/Types.hs
@@ -1,14 +1,18 @@
+{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
module CHAD.AST.Sparse.Types where
import Data.Kind (Type, Constraint)
import Data.Type.Equality
import CHAD.AST.Types
+import CHAD.Data
data Sparse t t' where
@@ -19,9 +23,22 @@ data Sparse t t' where
SpLEither :: Sparse a a' -> Sparse b b' -> Sparse (TLEither a b) (TLEither a' b')
SpMaybe :: Sparse t t' -> Sparse (TMaybe t) (TMaybe t')
SpArr :: Sparse t t' -> Sparse (TArr n t) (TArr n t')
+ SpArrIdx :: SList (Sparse t) t's -> Sparse (TArr n t) (MultiHot n t's)
SpScal :: Sparse (TScal t) (TScal t)
deriving instance Show (Sparse t t')
+type family MultiHot n t's where
+ MultiHot n '[] = TNil
+ MultiHot n (t' : t's) = TPair (TPair (Tup (Replicate n TIx)) t') (MultiHot n t's)
+
+tMultiHot :: SNat n -> SList STy ts -> STy (MultiHot n ts)
+tMultiHot _ SNil = STNil
+tMultiHot n (t `SCons` ts) = STPair (STPair (tTup (sreplicate n tIx)) t) (tMultiHot n ts)
+
+mtMultiHot :: SNat n -> SList SMTy ts -> SMTy (MultiHot n ts)
+mtMultiHot _ SNil = SMTNil
+mtMultiHot n (t `SCons` ts) = SMTPair (SMTPair (tTup (sreplicate n tIx)) t) (tMultiHot n ts)
+
class ApplySparse f where
applySparse :: Sparse t t' -> f t -> f t'
@@ -32,6 +49,7 @@ instance ApplySparse STy where
applySparse (SpLEither s1 s2) (STLEither t1 t2) = STLEither (applySparse s1 t1) (applySparse s2 t2)
applySparse (SpMaybe s) (STMaybe t) = STMaybe (applySparse s t)
applySparse (SpArr s) (STArr n t) = STArr n (applySparse s t)
+ applySparse (SpArrIdx ss) (STArr n t) = tMultiHot n (slistMap (`applySparse` t) ss)
applySparse SpScal t = t
instance ApplySparse SMTy where
@@ -41,34 +59,23 @@ instance ApplySparse SMTy where
applySparse (SpLEither s1 s2) (SMTLEither t1 t2) = SMTLEither (applySparse s1 t1) (applySparse s2 t2)
applySparse (SpMaybe s) (SMTMaybe t) = SMTMaybe (applySparse s t)
applySparse (SpArr s) (SMTArr n t) = SMTArr n (applySparse s t)
+ applySparse (SpArrIdx s) (SMTArr n t) = SMTPair (mkTup SMTNil SMTPair (sreplicate n (knownMTy @TIx))) (applySparse s t)
applySparse SpScal t = t
class IsSubType s where
type IsSubTypeSubject (s :: k -> k -> Type) (f :: k -> Type) :: Constraint
subtApply :: IsSubTypeSubject s f => s t t' -> f t -> f t'
- subtTrans :: s a b -> s b c -> s a c
subtFull :: IsSubTypeSubject s f => f t -> s t t
instance IsSubType (:~:) where
type IsSubTypeSubject (:~:) f = ()
subtApply = gcastWith
- subtTrans = trans
subtFull _ = Refl
instance IsSubType Sparse where
type IsSubTypeSubject Sparse f = f ~ SMTy
subtApply = applySparse
-
- subtTrans s1 (SpSparse s2) = SpSparse (subtTrans s1 s2)
- subtTrans _ SpAbsent = SpAbsent
- subtTrans (SpPair s1a s1b) (SpPair s2a s2b) = SpPair (subtTrans s1a s2a) (subtTrans s1b s2b)
- subtTrans (SpLEither s1a s1b) (SpLEither s2a s2b) = SpLEither (subtTrans s1a s2a) (subtTrans s1b s2b)
- subtTrans (SpSparse s1) (SpMaybe s2) = SpSparse (subtTrans s1 s2)
- subtTrans (SpMaybe s1) (SpMaybe s2) = SpMaybe (subtTrans s1 s2)
- subtTrans (SpArr s1) (SpArr s2) = SpArr (subtTrans s1 s2)
- subtTrans SpScal SpScal = SpScal
-
subtFull = spDense
spDense :: SMTy t -> Sparse t t
@@ -95,6 +102,7 @@ isDense (SMTMaybe t) (SpMaybe s)
isDense (SMTArr _ t) (SpArr s)
| Just Refl <- isDense t s = Just Refl
| otherwise = Nothing
+isDense SMTArr{} SpArrIdx{} = Nothing
isDense (SMTScal _) SpScal = Just Refl
isAbsent :: Sparse t t' -> Bool
@@ -104,4 +112,5 @@ isAbsent (SpPair s1 s2) = isAbsent s1 && isAbsent s2
isAbsent (SpLEither s1 s2) = isAbsent s1 && isAbsent s2
isAbsent (SpMaybe s) = isAbsent s
isAbsent (SpArr s) = isAbsent s
+isAbsent (SpArrIdx s) = isAbsent s
isAbsent SpScal = False