From 62796be35e6e768147aab70ba0beeb94c058c714 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Sun, 8 Feb 2026 15:43:02 +0100 Subject: WIP (continue in UnMonoid) --- src/CHAD/AST/Sparse/Types.hs | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) (limited to 'src/CHAD/AST/Sparse') diff --git a/src/CHAD/AST/Sparse/Types.hs b/src/CHAD/AST/Sparse/Types.hs index 930475a..9a4cf99 100644 --- a/src/CHAD/AST/Sparse/Types.hs +++ b/src/CHAD/AST/Sparse/Types.hs @@ -25,19 +25,20 @@ data Sparse t t' where SpArr :: Sparse t t' -> Sparse (TArr n t) (TArr n t') SpArrIdx :: SList (Sparse t) t's -> Sparse (TArr n t) (MultiHot n t's) SpScal :: Sparse (TScal t) (TScal t) + SpIdxPair :: Sparse t t' -> Sparse (TIdxPair n t) (TIdxPair n t') deriving instance Show (Sparse t t') type family MultiHot n t's where MultiHot n '[] = TNil - MultiHot n (t' : t's) = TPair (TPair (Tup (Replicate n TIx)) t') (MultiHot n t's) + MultiHot n (t' : t's) = TPair (TIdxPair n t') (MultiHot n t's) tMultiHot :: SNat n -> SList STy ts -> STy (MultiHot n ts) tMultiHot _ SNil = STNil -tMultiHot n (t `SCons` ts) = STPair (STPair (tTup (sreplicate n tIx)) t) (tMultiHot n ts) +tMultiHot n (t `SCons` ts) = STPair (STIdxPair n t) (tMultiHot n ts) mtMultiHot :: SNat n -> SList SMTy ts -> SMTy (MultiHot n ts) mtMultiHot _ SNil = SMTNil -mtMultiHot n (t `SCons` ts) = SMTPair (SMTPair (tTup (sreplicate n tIx)) t) (tMultiHot n ts) +mtMultiHot n (t `SCons` ts) = SMTPair (SMTIdxPair n t) (mtMultiHot n ts) class ApplySparse f where applySparse :: Sparse t t' -> f t -> f t' @@ -51,6 +52,7 @@ instance ApplySparse STy where applySparse (SpArr s) (STArr n t) = STArr n (applySparse s t) applySparse (SpArrIdx ss) (STArr n t) = tMultiHot n (slistMap (`applySparse` t) ss) applySparse SpScal t = t + applySparse (SpIdxPair s) (STIdxPair n t) = STIdxPair n (applySparse s t) instance ApplySparse SMTy where applySparse (SpSparse s) t = SMTMaybe (applySparse s t) @@ -59,8 +61,9 @@ instance ApplySparse SMTy where applySparse (SpLEither s1 s2) (SMTLEither t1 t2) = SMTLEither (applySparse s1 t1) (applySparse s2 t2) applySparse (SpMaybe s) (SMTMaybe t) = SMTMaybe (applySparse s t) applySparse (SpArr s) (SMTArr n t) = SMTArr n (applySparse s t) - applySparse (SpArrIdx s) (SMTArr n t) = SMTPair (mkTup SMTNil SMTPair (sreplicate n (knownMTy @TIx))) (applySparse s t) + applySparse (SpArrIdx l) (SMTArr n t) = mtMultiHot n (slistMap (`applySparse` t) l) applySparse SpScal t = t + applySparse (SpIdxPair s) (SMTIdxPair n t) = SMTIdxPair n (applySparse s t) class IsSubType s where @@ -85,6 +88,7 @@ spDense (SMTLEither t1 t2) = SpLEither (spDense t1) (spDense t2) spDense (SMTMaybe t) = SpMaybe (spDense t) spDense (SMTArr _ t) = SpArr (spDense t) spDense (SMTScal _) = SpScal +spDense (SMTIdxPair _ t) = SpIdxPair (spDense t) isDense :: SMTy t -> Sparse t t' -> Maybe (t :~: t') isDense SMTNil SpAbsent = Just Refl @@ -104,6 +108,9 @@ isDense (SMTArr _ t) (SpArr s) | otherwise = Nothing isDense SMTArr{} SpArrIdx{} = Nothing isDense (SMTScal _) SpScal = Just Refl +isDense (SMTIdxPair _ t) (SpIdxPair s) + | Just Refl <- isDense t s = Just Refl + | otherwise = Nothing isAbsent :: Sparse t t' -> Bool isAbsent (SpSparse s) = isAbsent s @@ -112,5 +119,6 @@ isAbsent (SpPair s1 s2) = isAbsent s1 && isAbsent s2 isAbsent (SpLEither s1 s2) = isAbsent s1 && isAbsent s2 isAbsent (SpMaybe s) = isAbsent s isAbsent (SpArr s) = isAbsent s -isAbsent (SpArrIdx s) = isAbsent s +isAbsent (SpArrIdx l) = and (unSList isAbsent l) isAbsent SpScal = False +isAbsent (SpIdxPair s) = isAbsent s -- cgit v1.2.3-70-g09d2