aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD/AST/Sparse
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2026-01-26 23:37:55 +0100
committerTom Smeding <tom@tomsmeding.com>2026-01-26 23:37:55 +0100
commitb4f988cb1490ed31ab225323b33448667b8578c0 (patch)
treed048e70e33f2e2787aae68a9b671b78094c05c43 /src/CHAD/AST/Sparse
parenta9e6c72eff3bee8d45e0d906e8cd027066e04793 (diff)
Multihot cotangents WIP (doesn't work)multihot-cotangents
The idea is sound but for a smaller source language. Notes also in Obsidian, but the theory so far is that dropping support for nested arrays makes this possible, although making the result type-safe (i.e. not have partial functions in a bunch of places) would require making the lack of nested array support explicit in the embedded type system, i.e. have Accelerate-like stratification. The point is that multihots can be added heterogeneously using plusSparseS but not homogeneously with EPlus or plusSparse, because the indices might differ between the summands. Thus as long as we never need to homogeneously sum multihot cotangents, we're golden. Now the crucial observation is that we only need plus to be homogeneous on array elements. So if array elements cannot themselves be arrays, i.e. we drop support for nested arrays, no homogeneous plus of multihot array cotangents is needed, and we can have static multihots.
Diffstat (limited to 'src/CHAD/AST/Sparse')
-rw-r--r--src/CHAD/AST/Sparse/Types.hs33
1 files changed, 21 insertions, 12 deletions
diff --git a/src/CHAD/AST/Sparse/Types.hs b/src/CHAD/AST/Sparse/Types.hs
index 8f41ba4..930475a 100644
--- a/src/CHAD/AST/Sparse/Types.hs
+++ b/src/CHAD/AST/Sparse/Types.hs
@@ -1,14 +1,18 @@
+{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
module CHAD.AST.Sparse.Types where
import Data.Kind (Type, Constraint)
import Data.Type.Equality
import CHAD.AST.Types
+import CHAD.Data
data Sparse t t' where
@@ -19,9 +23,22 @@ data Sparse t t' where
SpLEither :: Sparse a a' -> Sparse b b' -> Sparse (TLEither a b) (TLEither a' b')
SpMaybe :: Sparse t t' -> Sparse (TMaybe t) (TMaybe t')
SpArr :: Sparse t t' -> Sparse (TArr n t) (TArr n t')
+ SpArrIdx :: SList (Sparse t) t's -> Sparse (TArr n t) (MultiHot n t's)
SpScal :: Sparse (TScal t) (TScal t)
deriving instance Show (Sparse t t')
+type family MultiHot n t's where
+ MultiHot n '[] = TNil
+ MultiHot n (t' : t's) = TPair (TPair (Tup (Replicate n TIx)) t') (MultiHot n t's)
+
+tMultiHot :: SNat n -> SList STy ts -> STy (MultiHot n ts)
+tMultiHot _ SNil = STNil
+tMultiHot n (t `SCons` ts) = STPair (STPair (tTup (sreplicate n tIx)) t) (tMultiHot n ts)
+
+mtMultiHot :: SNat n -> SList SMTy ts -> SMTy (MultiHot n ts)
+mtMultiHot _ SNil = SMTNil
+mtMultiHot n (t `SCons` ts) = SMTPair (SMTPair (tTup (sreplicate n tIx)) t) (tMultiHot n ts)
+
class ApplySparse f where
applySparse :: Sparse t t' -> f t -> f t'
@@ -32,6 +49,7 @@ instance ApplySparse STy where
applySparse (SpLEither s1 s2) (STLEither t1 t2) = STLEither (applySparse s1 t1) (applySparse s2 t2)
applySparse (SpMaybe s) (STMaybe t) = STMaybe (applySparse s t)
applySparse (SpArr s) (STArr n t) = STArr n (applySparse s t)
+ applySparse (SpArrIdx ss) (STArr n t) = tMultiHot n (slistMap (`applySparse` t) ss)
applySparse SpScal t = t
instance ApplySparse SMTy where
@@ -41,34 +59,23 @@ instance ApplySparse SMTy where
applySparse (SpLEither s1 s2) (SMTLEither t1 t2) = SMTLEither (applySparse s1 t1) (applySparse s2 t2)
applySparse (SpMaybe s) (SMTMaybe t) = SMTMaybe (applySparse s t)
applySparse (SpArr s) (SMTArr n t) = SMTArr n (applySparse s t)
+ applySparse (SpArrIdx s) (SMTArr n t) = SMTPair (mkTup SMTNil SMTPair (sreplicate n (knownMTy @TIx))) (applySparse s t)
applySparse SpScal t = t
class IsSubType s where
type IsSubTypeSubject (s :: k -> k -> Type) (f :: k -> Type) :: Constraint
subtApply :: IsSubTypeSubject s f => s t t' -> f t -> f t'
- subtTrans :: s a b -> s b c -> s a c
subtFull :: IsSubTypeSubject s f => f t -> s t t
instance IsSubType (:~:) where
type IsSubTypeSubject (:~:) f = ()
subtApply = gcastWith
- subtTrans = trans
subtFull _ = Refl
instance IsSubType Sparse where
type IsSubTypeSubject Sparse f = f ~ SMTy
subtApply = applySparse
-
- subtTrans s1 (SpSparse s2) = SpSparse (subtTrans s1 s2)
- subtTrans _ SpAbsent = SpAbsent
- subtTrans (SpPair s1a s1b) (SpPair s2a s2b) = SpPair (subtTrans s1a s2a) (subtTrans s1b s2b)
- subtTrans (SpLEither s1a s1b) (SpLEither s2a s2b) = SpLEither (subtTrans s1a s2a) (subtTrans s1b s2b)
- subtTrans (SpSparse s1) (SpMaybe s2) = SpSparse (subtTrans s1 s2)
- subtTrans (SpMaybe s1) (SpMaybe s2) = SpMaybe (subtTrans s1 s2)
- subtTrans (SpArr s1) (SpArr s2) = SpArr (subtTrans s1 s2)
- subtTrans SpScal SpScal = SpScal
-
subtFull = spDense
spDense :: SMTy t -> Sparse t t
@@ -95,6 +102,7 @@ isDense (SMTMaybe t) (SpMaybe s)
isDense (SMTArr _ t) (SpArr s)
| Just Refl <- isDense t s = Just Refl
| otherwise = Nothing
+isDense SMTArr{} SpArrIdx{} = Nothing
isDense (SMTScal _) SpScal = Just Refl
isAbsent :: Sparse t t' -> Bool
@@ -104,4 +112,5 @@ isAbsent (SpPair s1 s2) = isAbsent s1 && isAbsent s2
isAbsent (SpLEither s1 s2) = isAbsent s1 && isAbsent s2
isAbsent (SpMaybe s) = isAbsent s
isAbsent (SpArr s) = isAbsent s
+isAbsent (SpArrIdx s) = isAbsent s
isAbsent SpScal = False