From b4f988cb1490ed31ab225323b33448667b8578c0 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Mon, 26 Jan 2026 23:37:55 +0100 Subject: Multihot cotangents WIP (doesn't work) The idea is sound but for a smaller source language. Notes also in Obsidian, but the theory so far is that dropping support for nested arrays makes this possible, although making the result type-safe (i.e. not have partial functions in a bunch of places) would require making the lack of nested array support explicit in the embedded type system, i.e. have Accelerate-like stratification. The point is that multihots can be added heterogeneously using plusSparseS but not homogeneously with EPlus or plusSparse, because the indices might differ between the summands. Thus as long as we never need to homogeneously sum multihot cotangents, we're golden. Now the crucial observation is that we only need plus to be homogeneous on array elements. So if array elements cannot themselves be arrays, i.e. we drop support for nested arrays, no homogeneous plus of multihot array cotangents is needed, and we can have static multihots. --- src/CHAD/AST/Sparse/Types.hs | 33 +++++++++++++++++++++------------ 1 file changed, 21 insertions(+), 12 deletions(-) (limited to 'src/CHAD/AST/Sparse') diff --git a/src/CHAD/AST/Sparse/Types.hs b/src/CHAD/AST/Sparse/Types.hs index 8f41ba4..930475a 100644 --- a/src/CHAD/AST/Sparse/Types.hs +++ b/src/CHAD/AST/Sparse/Types.hs @@ -1,14 +1,18 @@ +{-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} module CHAD.AST.Sparse.Types where import Data.Kind (Type, Constraint) import Data.Type.Equality import CHAD.AST.Types +import CHAD.Data data Sparse t t' where @@ -19,9 +23,22 @@ data Sparse t t' where SpLEither :: Sparse a a' -> Sparse b b' -> Sparse (TLEither a b) (TLEither a' b') SpMaybe :: Sparse t t' -> Sparse (TMaybe t) (TMaybe t') SpArr :: Sparse t t' -> Sparse (TArr n t) (TArr n t') + SpArrIdx :: SList (Sparse t) t's -> Sparse (TArr n t) (MultiHot n t's) SpScal :: Sparse (TScal t) (TScal t) deriving instance Show (Sparse t t') +type family MultiHot n t's where + MultiHot n '[] = TNil + MultiHot n (t' : t's) = TPair (TPair (Tup (Replicate n TIx)) t') (MultiHot n t's) + +tMultiHot :: SNat n -> SList STy ts -> STy (MultiHot n ts) +tMultiHot _ SNil = STNil +tMultiHot n (t `SCons` ts) = STPair (STPair (tTup (sreplicate n tIx)) t) (tMultiHot n ts) + +mtMultiHot :: SNat n -> SList SMTy ts -> SMTy (MultiHot n ts) +mtMultiHot _ SNil = SMTNil +mtMultiHot n (t `SCons` ts) = SMTPair (SMTPair (tTup (sreplicate n tIx)) t) (tMultiHot n ts) + class ApplySparse f where applySparse :: Sparse t t' -> f t -> f t' @@ -32,6 +49,7 @@ instance ApplySparse STy where applySparse (SpLEither s1 s2) (STLEither t1 t2) = STLEither (applySparse s1 t1) (applySparse s2 t2) applySparse (SpMaybe s) (STMaybe t) = STMaybe (applySparse s t) applySparse (SpArr s) (STArr n t) = STArr n (applySparse s t) + applySparse (SpArrIdx ss) (STArr n t) = tMultiHot n (slistMap (`applySparse` t) ss) applySparse SpScal t = t instance ApplySparse SMTy where @@ -41,34 +59,23 @@ instance ApplySparse SMTy where applySparse (SpLEither s1 s2) (SMTLEither t1 t2) = SMTLEither (applySparse s1 t1) (applySparse s2 t2) applySparse (SpMaybe s) (SMTMaybe t) = SMTMaybe (applySparse s t) applySparse (SpArr s) (SMTArr n t) = SMTArr n (applySparse s t) + applySparse (SpArrIdx s) (SMTArr n t) = SMTPair (mkTup SMTNil SMTPair (sreplicate n (knownMTy @TIx))) (applySparse s t) applySparse SpScal t = t class IsSubType s where type IsSubTypeSubject (s :: k -> k -> Type) (f :: k -> Type) :: Constraint subtApply :: IsSubTypeSubject s f => s t t' -> f t -> f t' - subtTrans :: s a b -> s b c -> s a c subtFull :: IsSubTypeSubject s f => f t -> s t t instance IsSubType (:~:) where type IsSubTypeSubject (:~:) f = () subtApply = gcastWith - subtTrans = trans subtFull _ = Refl instance IsSubType Sparse where type IsSubTypeSubject Sparse f = f ~ SMTy subtApply = applySparse - - subtTrans s1 (SpSparse s2) = SpSparse (subtTrans s1 s2) - subtTrans _ SpAbsent = SpAbsent - subtTrans (SpPair s1a s1b) (SpPair s2a s2b) = SpPair (subtTrans s1a s2a) (subtTrans s1b s2b) - subtTrans (SpLEither s1a s1b) (SpLEither s2a s2b) = SpLEither (subtTrans s1a s2a) (subtTrans s1b s2b) - subtTrans (SpSparse s1) (SpMaybe s2) = SpSparse (subtTrans s1 s2) - subtTrans (SpMaybe s1) (SpMaybe s2) = SpMaybe (subtTrans s1 s2) - subtTrans (SpArr s1) (SpArr s2) = SpArr (subtTrans s1 s2) - subtTrans SpScal SpScal = SpScal - subtFull = spDense spDense :: SMTy t -> Sparse t t @@ -95,6 +102,7 @@ isDense (SMTMaybe t) (SpMaybe s) isDense (SMTArr _ t) (SpArr s) | Just Refl <- isDense t s = Just Refl | otherwise = Nothing +isDense SMTArr{} SpArrIdx{} = Nothing isDense (SMTScal _) SpScal = Just Refl isAbsent :: Sparse t t' -> Bool @@ -104,4 +112,5 @@ isAbsent (SpPair s1 s2) = isAbsent s1 && isAbsent s2 isAbsent (SpLEither s1 s2) = isAbsent s1 && isAbsent s2 isAbsent (SpMaybe s) = isAbsent s isAbsent (SpArr s) = isAbsent s +isAbsent (SpArrIdx s) = isAbsent s isAbsent SpScal = False -- cgit v1.2.3-70-g09d2