aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD/AST/Sparse
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2026-02-08 15:43:02 +0100
committerTom Smeding <tom@tomsmeding.com>2026-02-12 20:44:47 +0100
commit62796be35e6e768147aab70ba0beeb94c058c714 (patch)
treedd43c8c2f37c59308b6b7d503fd25420621b0ab9 /src/CHAD/AST/Sparse
parentc2831ef0f8be71f2a72ee4eee446e2ac473fb638 (diff)
WIP (continue in UnMonoid)
Diffstat (limited to 'src/CHAD/AST/Sparse')
-rw-r--r--src/CHAD/AST/Sparse/Types.hs18
1 files changed, 13 insertions, 5 deletions
diff --git a/src/CHAD/AST/Sparse/Types.hs b/src/CHAD/AST/Sparse/Types.hs
index 930475a..9a4cf99 100644
--- a/src/CHAD/AST/Sparse/Types.hs
+++ b/src/CHAD/AST/Sparse/Types.hs
@@ -25,19 +25,20 @@ data Sparse t t' where
SpArr :: Sparse t t' -> Sparse (TArr n t) (TArr n t')
SpArrIdx :: SList (Sparse t) t's -> Sparse (TArr n t) (MultiHot n t's)
SpScal :: Sparse (TScal t) (TScal t)
+ SpIdxPair :: Sparse t t' -> Sparse (TIdxPair n t) (TIdxPair n t')
deriving instance Show (Sparse t t')
type family MultiHot n t's where
MultiHot n '[] = TNil
- MultiHot n (t' : t's) = TPair (TPair (Tup (Replicate n TIx)) t') (MultiHot n t's)
+ MultiHot n (t' : t's) = TPair (TIdxPair n t') (MultiHot n t's)
tMultiHot :: SNat n -> SList STy ts -> STy (MultiHot n ts)
tMultiHot _ SNil = STNil
-tMultiHot n (t `SCons` ts) = STPair (STPair (tTup (sreplicate n tIx)) t) (tMultiHot n ts)
+tMultiHot n (t `SCons` ts) = STPair (STIdxPair n t) (tMultiHot n ts)
mtMultiHot :: SNat n -> SList SMTy ts -> SMTy (MultiHot n ts)
mtMultiHot _ SNil = SMTNil
-mtMultiHot n (t `SCons` ts) = SMTPair (SMTPair (tTup (sreplicate n tIx)) t) (tMultiHot n ts)
+mtMultiHot n (t `SCons` ts) = SMTPair (SMTIdxPair n t) (mtMultiHot n ts)
class ApplySparse f where
applySparse :: Sparse t t' -> f t -> f t'
@@ -51,6 +52,7 @@ instance ApplySparse STy where
applySparse (SpArr s) (STArr n t) = STArr n (applySparse s t)
applySparse (SpArrIdx ss) (STArr n t) = tMultiHot n (slistMap (`applySparse` t) ss)
applySparse SpScal t = t
+ applySparse (SpIdxPair s) (STIdxPair n t) = STIdxPair n (applySparse s t)
instance ApplySparse SMTy where
applySparse (SpSparse s) t = SMTMaybe (applySparse s t)
@@ -59,8 +61,9 @@ instance ApplySparse SMTy where
applySparse (SpLEither s1 s2) (SMTLEither t1 t2) = SMTLEither (applySparse s1 t1) (applySparse s2 t2)
applySparse (SpMaybe s) (SMTMaybe t) = SMTMaybe (applySparse s t)
applySparse (SpArr s) (SMTArr n t) = SMTArr n (applySparse s t)
- applySparse (SpArrIdx s) (SMTArr n t) = SMTPair (mkTup SMTNil SMTPair (sreplicate n (knownMTy @TIx))) (applySparse s t)
+ applySparse (SpArrIdx l) (SMTArr n t) = mtMultiHot n (slistMap (`applySparse` t) l)
applySparse SpScal t = t
+ applySparse (SpIdxPair s) (SMTIdxPair n t) = SMTIdxPair n (applySparse s t)
class IsSubType s where
@@ -85,6 +88,7 @@ spDense (SMTLEither t1 t2) = SpLEither (spDense t1) (spDense t2)
spDense (SMTMaybe t) = SpMaybe (spDense t)
spDense (SMTArr _ t) = SpArr (spDense t)
spDense (SMTScal _) = SpScal
+spDense (SMTIdxPair _ t) = SpIdxPair (spDense t)
isDense :: SMTy t -> Sparse t t' -> Maybe (t :~: t')
isDense SMTNil SpAbsent = Just Refl
@@ -104,6 +108,9 @@ isDense (SMTArr _ t) (SpArr s)
| otherwise = Nothing
isDense SMTArr{} SpArrIdx{} = Nothing
isDense (SMTScal _) SpScal = Just Refl
+isDense (SMTIdxPair _ t) (SpIdxPair s)
+ | Just Refl <- isDense t s = Just Refl
+ | otherwise = Nothing
isAbsent :: Sparse t t' -> Bool
isAbsent (SpSparse s) = isAbsent s
@@ -112,5 +119,6 @@ isAbsent (SpPair s1 s2) = isAbsent s1 && isAbsent s2
isAbsent (SpLEither s1 s2) = isAbsent s1 && isAbsent s2
isAbsent (SpMaybe s) = isAbsent s
isAbsent (SpArr s) = isAbsent s
-isAbsent (SpArrIdx s) = isAbsent s
+isAbsent (SpArrIdx l) = and (unSList isAbsent l)
isAbsent SpScal = False
+isAbsent (SpIdxPair s) = isAbsent s