diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2026-01-26 23:37:55 +0100 |
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2026-01-26 23:37:55 +0100 |
| commit | b4f988cb1490ed31ab225323b33448667b8578c0 (patch) | |
| tree | d048e70e33f2e2787aae68a9b671b78094c05c43 /src | |
| parent | a9e6c72eff3bee8d45e0d906e8cd027066e04793 (diff) | |
Multihot cotangents WIP (doesn't work)multihot-cotangents
The idea is sound but for a smaller source language. Notes also in
Obsidian, but the theory so far is that dropping support for nested
arrays makes this possible, although making the result type-safe (i.e.
not have partial functions in a bunch of places) would require making
the lack of nested array support explicit in the embedded type system,
i.e. have Accelerate-like stratification.
The point is that multihots can be added heterogeneously using
plusSparseS but not homogeneously with EPlus or plusSparse, because the
indices might differ between the summands. Thus as long as we never need
to homogeneously sum multihot cotangents, we're golden.
Now the crucial observation is that we only need plus to be homogeneous
on array elements. So if array elements cannot themselves be arrays,
i.e. we drop support for nested arrays, no homogeneous plus of multihot
array cotangents is needed, and we can have static multihots.
Diffstat (limited to 'src')
| -rw-r--r-- | src/CHAD/AST.hs | 58 | ||||
| -rw-r--r-- | src/CHAD/AST/Env.hs | 4 | ||||
| -rw-r--r-- | src/CHAD/AST/Sparse/Types.hs | 33 | ||||
| -rw-r--r-- | src/CHAD/AST/Types.hs | 59 | ||||
| -rw-r--r-- | src/CHAD/AST/UnMonoid.hs | 5 | ||||
| -rw-r--r-- | src/CHAD/Drev.hs | 3 |
6 files changed, 88 insertions, 74 deletions
diff --git a/src/CHAD/AST.hs b/src/CHAD/AST.hs index b795070..3f6dfc4 100644 --- a/src/CHAD/AST.hs +++ b/src/CHAD/AST.hs @@ -442,64 +442,6 @@ subst' f w = \case weakenExpr :: env :> env' -> Expr x env t -> Expr x env' t weakenExpr = subst' (\x t w' i -> EVar x t (w' @> i)) -class KnownScalTy t where knownScalTy :: SScalTy t -instance KnownScalTy TI32 where knownScalTy = STI32 -instance KnownScalTy TI64 where knownScalTy = STI64 -instance KnownScalTy TF32 where knownScalTy = STF32 -instance KnownScalTy TF64 where knownScalTy = STF64 -instance KnownScalTy TBool where knownScalTy = STBool - -class KnownTy t where knownTy :: STy t -instance KnownTy TNil where knownTy = STNil -instance (KnownTy s, KnownTy t) => KnownTy (TPair s t) where knownTy = STPair knownTy knownTy -instance (KnownTy s, KnownTy t) => KnownTy (TEither s t) where knownTy = STEither knownTy knownTy -instance (KnownTy s, KnownTy t) => KnownTy (TLEither s t) where knownTy = STLEither knownTy knownTy -instance KnownTy t => KnownTy (TMaybe t) where knownTy = STMaybe knownTy -instance (KnownNat n, KnownTy t) => KnownTy (TArr n t) where knownTy = STArr knownNat knownTy -instance KnownScalTy t => KnownTy (TScal t) where knownTy = STScal knownScalTy -instance KnownMTy t => KnownTy (TAccum t) where knownTy = STAccum knownMTy - -class KnownMTy t where knownMTy :: SMTy t -instance KnownMTy TNil where knownMTy = SMTNil -instance (KnownMTy s, KnownMTy t) => KnownMTy (TPair s t) where knownMTy = SMTPair knownMTy knownMTy -instance KnownMTy t => KnownMTy (TMaybe t) where knownMTy = SMTMaybe knownMTy -instance (KnownMTy s, KnownMTy t) => KnownMTy (TLEither s t) where knownMTy = SMTLEither knownMTy knownMTy -instance (KnownNat n, KnownMTy t) => KnownMTy (TArr n t) where knownMTy = SMTArr knownNat knownMTy -instance (KnownScalTy t, ScalIsNumeric t ~ True) => KnownMTy (TScal t) where knownMTy = SMTScal knownScalTy - -class KnownEnv env where knownEnv :: SList STy env -instance KnownEnv '[] where knownEnv = SNil -instance (KnownTy t, KnownEnv env) => KnownEnv (t : env) where knownEnv = SCons knownTy knownEnv - -styKnown :: STy t -> Dict (KnownTy t) -styKnown STNil = Dict -styKnown (STPair a b) | Dict <- styKnown a, Dict <- styKnown b = Dict -styKnown (STEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict -styKnown (STLEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict -styKnown (STMaybe t) | Dict <- styKnown t = Dict -styKnown (STArr n t) | Dict <- snatKnown n, Dict <- styKnown t = Dict -styKnown (STScal t) | Dict <- sscaltyKnown t = Dict -styKnown (STAccum t) | Dict <- smtyKnown t = Dict - -smtyKnown :: SMTy t -> Dict (KnownMTy t) -smtyKnown SMTNil = Dict -smtyKnown (SMTPair a b) | Dict <- smtyKnown a, Dict <- smtyKnown b = Dict -smtyKnown (SMTLEither a b) | Dict <- smtyKnown a, Dict <- smtyKnown b = Dict -smtyKnown (SMTMaybe t) | Dict <- smtyKnown t = Dict -smtyKnown (SMTArr n t) | Dict <- snatKnown n, Dict <- smtyKnown t = Dict -smtyKnown (SMTScal t) | Dict <- sscaltyKnown t = Dict - -sscaltyKnown :: SScalTy t -> Dict (KnownScalTy t) -sscaltyKnown STI32 = Dict -sscaltyKnown STI64 = Dict -sscaltyKnown STF32 = Dict -sscaltyKnown STF64 = Dict -sscaltyKnown STBool = Dict - -envKnown :: SList STy env -> Dict (KnownEnv env) -envKnown SNil = Dict -envKnown (t `SCons` env) | Dict <- styKnown t, Dict <- envKnown env = Dict - cheapExpr :: Expr x env t -> Bool cheapExpr = \case EVar{} -> True diff --git a/src/CHAD/AST/Env.hs b/src/CHAD/AST/Env.hs index 8e6b745..40b6ca2 100644 --- a/src/CHAD/AST/Env.hs +++ b/src/CHAD/AST/Env.hs @@ -53,9 +53,9 @@ subenvOnehot (SCons _ env) IZ sp = SEYes sp (subenvNone env) subenvOnehot (SCons _ env) (IS i) sp = SENo (subenvOnehot env i sp) subenvOnehot SNil i _ = case i of {} -subenvCompose :: IsSubType s => Subenv' s env1 env2 -> Subenv' s env2 env3 -> Subenv' s env1 env3 +subenvCompose :: IsSubType s => Subenv' (:~:) env1 env2 -> Subenv' s env2 env3 -> Subenv' s env1 env3 subenvCompose SETop SETop = SETop -subenvCompose (SEYes s1 sub1) (SEYes s2 sub2) = SEYes (subtTrans s1 s2) (subenvCompose sub1 sub2) +subenvCompose (SEYes Refl sub1) (SEYes s2 sub2) = SEYes s2 (subenvCompose sub1 sub2) subenvCompose (SEYes _ sub1) (SENo sub2) = SENo (subenvCompose sub1 sub2) subenvCompose (SENo sub1) sub2 = SENo (subenvCompose sub1 sub2) diff --git a/src/CHAD/AST/Sparse/Types.hs b/src/CHAD/AST/Sparse/Types.hs index 8f41ba4..930475a 100644 --- a/src/CHAD/AST/Sparse/Types.hs +++ b/src/CHAD/AST/Sparse/Types.hs @@ -1,14 +1,18 @@ +{-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} module CHAD.AST.Sparse.Types where import Data.Kind (Type, Constraint) import Data.Type.Equality import CHAD.AST.Types +import CHAD.Data data Sparse t t' where @@ -19,9 +23,22 @@ data Sparse t t' where SpLEither :: Sparse a a' -> Sparse b b' -> Sparse (TLEither a b) (TLEither a' b') SpMaybe :: Sparse t t' -> Sparse (TMaybe t) (TMaybe t') SpArr :: Sparse t t' -> Sparse (TArr n t) (TArr n t') + SpArrIdx :: SList (Sparse t) t's -> Sparse (TArr n t) (MultiHot n t's) SpScal :: Sparse (TScal t) (TScal t) deriving instance Show (Sparse t t') +type family MultiHot n t's where + MultiHot n '[] = TNil + MultiHot n (t' : t's) = TPair (TPair (Tup (Replicate n TIx)) t') (MultiHot n t's) + +tMultiHot :: SNat n -> SList STy ts -> STy (MultiHot n ts) +tMultiHot _ SNil = STNil +tMultiHot n (t `SCons` ts) = STPair (STPair (tTup (sreplicate n tIx)) t) (tMultiHot n ts) + +mtMultiHot :: SNat n -> SList SMTy ts -> SMTy (MultiHot n ts) +mtMultiHot _ SNil = SMTNil +mtMultiHot n (t `SCons` ts) = SMTPair (SMTPair (tTup (sreplicate n tIx)) t) (tMultiHot n ts) + class ApplySparse f where applySparse :: Sparse t t' -> f t -> f t' @@ -32,6 +49,7 @@ instance ApplySparse STy where applySparse (SpLEither s1 s2) (STLEither t1 t2) = STLEither (applySparse s1 t1) (applySparse s2 t2) applySparse (SpMaybe s) (STMaybe t) = STMaybe (applySparse s t) applySparse (SpArr s) (STArr n t) = STArr n (applySparse s t) + applySparse (SpArrIdx ss) (STArr n t) = tMultiHot n (slistMap (`applySparse` t) ss) applySparse SpScal t = t instance ApplySparse SMTy where @@ -41,34 +59,23 @@ instance ApplySparse SMTy where applySparse (SpLEither s1 s2) (SMTLEither t1 t2) = SMTLEither (applySparse s1 t1) (applySparse s2 t2) applySparse (SpMaybe s) (SMTMaybe t) = SMTMaybe (applySparse s t) applySparse (SpArr s) (SMTArr n t) = SMTArr n (applySparse s t) + applySparse (SpArrIdx s) (SMTArr n t) = SMTPair (mkTup SMTNil SMTPair (sreplicate n (knownMTy @TIx))) (applySparse s t) applySparse SpScal t = t class IsSubType s where type IsSubTypeSubject (s :: k -> k -> Type) (f :: k -> Type) :: Constraint subtApply :: IsSubTypeSubject s f => s t t' -> f t -> f t' - subtTrans :: s a b -> s b c -> s a c subtFull :: IsSubTypeSubject s f => f t -> s t t instance IsSubType (:~:) where type IsSubTypeSubject (:~:) f = () subtApply = gcastWith - subtTrans = trans subtFull _ = Refl instance IsSubType Sparse where type IsSubTypeSubject Sparse f = f ~ SMTy subtApply = applySparse - - subtTrans s1 (SpSparse s2) = SpSparse (subtTrans s1 s2) - subtTrans _ SpAbsent = SpAbsent - subtTrans (SpPair s1a s1b) (SpPair s2a s2b) = SpPair (subtTrans s1a s2a) (subtTrans s1b s2b) - subtTrans (SpLEither s1a s1b) (SpLEither s2a s2b) = SpLEither (subtTrans s1a s2a) (subtTrans s1b s2b) - subtTrans (SpSparse s1) (SpMaybe s2) = SpSparse (subtTrans s1 s2) - subtTrans (SpMaybe s1) (SpMaybe s2) = SpMaybe (subtTrans s1 s2) - subtTrans (SpArr s1) (SpArr s2) = SpArr (subtTrans s1 s2) - subtTrans SpScal SpScal = SpScal - subtFull = spDense spDense :: SMTy t -> Sparse t t @@ -95,6 +102,7 @@ isDense (SMTMaybe t) (SpMaybe s) isDense (SMTArr _ t) (SpArr s) | Just Refl <- isDense t s = Just Refl | otherwise = Nothing +isDense SMTArr{} SpArrIdx{} = Nothing isDense (SMTScal _) SpScal = Just Refl isAbsent :: Sparse t t' -> Bool @@ -104,4 +112,5 @@ isAbsent (SpPair s1 s2) = isAbsent s1 && isAbsent s2 isAbsent (SpLEither s1 s2) = isAbsent s1 && isAbsent s2 isAbsent (SpMaybe s) = isAbsent s isAbsent (SpArr s) = isAbsent s +isAbsent (SpArrIdx s) = isAbsent s isAbsent SpScal = False diff --git a/src/CHAD/AST/Types.hs b/src/CHAD/AST/Types.hs index f0feb55..bec2201 100644 --- a/src/CHAD/AST/Types.hs +++ b/src/CHAD/AST/Types.hs @@ -77,6 +77,7 @@ data SMTy t where SMTMaybe :: SMTy a -> SMTy (TMaybe a) SMTArr :: SNat n -> SMTy t -> SMTy (TArr n t) SMTScal :: ScalIsNumeric t ~ True => SScalTy t -> SMTy (TScal t) + SMTData :: STy a -> SMTy a -- ^ inclusion of non-monoidal information deriving instance Show (SMTy t) instance GCompare SMTy where @@ -215,3 +216,61 @@ unTup unpack (_ `SCons` list) tup = type family InvTup core env where InvTup core '[] = core InvTup core (t : ts) = InvTup (TPair core t) ts + +class KnownScalTy t where knownScalTy :: SScalTy t +instance KnownScalTy TI32 where knownScalTy = STI32 +instance KnownScalTy TI64 where knownScalTy = STI64 +instance KnownScalTy TF32 where knownScalTy = STF32 +instance KnownScalTy TF64 where knownScalTy = STF64 +instance KnownScalTy TBool where knownScalTy = STBool + +class KnownTy t where knownTy :: STy t +instance KnownTy TNil where knownTy = STNil +instance (KnownTy s, KnownTy t) => KnownTy (TPair s t) where knownTy = STPair knownTy knownTy +instance (KnownTy s, KnownTy t) => KnownTy (TEither s t) where knownTy = STEither knownTy knownTy +instance (KnownTy s, KnownTy t) => KnownTy (TLEither s t) where knownTy = STLEither knownTy knownTy +instance KnownTy t => KnownTy (TMaybe t) where knownTy = STMaybe knownTy +instance (KnownNat n, KnownTy t) => KnownTy (TArr n t) where knownTy = STArr knownNat knownTy +instance KnownScalTy t => KnownTy (TScal t) where knownTy = STScal knownScalTy +instance KnownMTy t => KnownTy (TAccum t) where knownTy = STAccum knownMTy + +class KnownMTy t where knownMTy :: SMTy t +instance KnownMTy TNil where knownMTy = SMTNil +instance (KnownMTy s, KnownMTy t) => KnownMTy (TPair s t) where knownMTy = SMTPair knownMTy knownMTy +instance KnownMTy t => KnownMTy (TMaybe t) where knownMTy = SMTMaybe knownMTy +instance (KnownMTy s, KnownMTy t) => KnownMTy (TLEither s t) where knownMTy = SMTLEither knownMTy knownMTy +instance (KnownNat n, KnownMTy t) => KnownMTy (TArr n t) where knownMTy = SMTArr knownNat knownMTy +instance (KnownScalTy t, ScalIsNumeric t ~ True) => KnownMTy (TScal t) where knownMTy = SMTScal knownScalTy + +class KnownEnv env where knownEnv :: SList STy env +instance KnownEnv '[] where knownEnv = SNil +instance (KnownTy t, KnownEnv env) => KnownEnv (t : env) where knownEnv = SCons knownTy knownEnv + +styKnown :: STy t -> Dict (KnownTy t) +styKnown STNil = Dict +styKnown (STPair a b) | Dict <- styKnown a, Dict <- styKnown b = Dict +styKnown (STEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict +styKnown (STLEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict +styKnown (STMaybe t) | Dict <- styKnown t = Dict +styKnown (STArr n t) | Dict <- snatKnown n, Dict <- styKnown t = Dict +styKnown (STScal t) | Dict <- sscaltyKnown t = Dict +styKnown (STAccum t) | Dict <- smtyKnown t = Dict + +smtyKnown :: SMTy t -> Dict (KnownMTy t) +smtyKnown SMTNil = Dict +smtyKnown (SMTPair a b) | Dict <- smtyKnown a, Dict <- smtyKnown b = Dict +smtyKnown (SMTLEither a b) | Dict <- smtyKnown a, Dict <- smtyKnown b = Dict +smtyKnown (SMTMaybe t) | Dict <- smtyKnown t = Dict +smtyKnown (SMTArr n t) | Dict <- snatKnown n, Dict <- smtyKnown t = Dict +smtyKnown (SMTScal t) | Dict <- sscaltyKnown t = Dict + +sscaltyKnown :: SScalTy t -> Dict (KnownScalTy t) +sscaltyKnown STI32 = Dict +sscaltyKnown STI64 = Dict +sscaltyKnown STF32 = Dict +sscaltyKnown STF64 = Dict +sscaltyKnown STBool = Dict + +envKnown :: SList STy env -> Dict (KnownEnv env) +envKnown SNil = Dict +envKnown (t `SCons` env) | Dict <- styKnown t, Dict <- envKnown env = Dict diff --git a/src/CHAD/AST/UnMonoid.hs b/src/CHAD/AST/UnMonoid.hs index d3cad25..1a66cdf 100644 --- a/src/CHAD/AST/UnMonoid.hs +++ b/src/CHAD/AST/UnMonoid.hs @@ -212,6 +212,11 @@ accumulateSparse topty topsp arg accum = case (topty, topsp) of (EIdx ext (evar (IS IZ)) (EVar ext tn IZ)) (\w prj idx val -> accum (WPop (WPop w)) (SAPArrIdx prj) (EPair ext (EVar ext tn (w @> IZ)) idx) val)) $ ENil ext + (SMTArr _ t, SpArrIdx s) -> + eunPair arg $ \w1 e1 e2 -> + elet (accumulateSparse t s e2 + (\w prj idx val -> accum (w .> w1) (SAPArrIdx prj) (EPair ext (weakenExpr w e1) idx) val)) $ + ENil ext acPrjCompose :: SAIDense dense diff --git a/src/CHAD/Drev.hs b/src/CHAD/Drev.hs index bfa964b..ee92782 100644 --- a/src/CHAD/Drev.hs +++ b/src/CHAD/Drev.hs @@ -1009,11 +1009,10 @@ drev des accumMap sd = \case let smallE = unsafeWeakenWithSubenv usedSub e in subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed -> case drev usedDes (VarMap.subMap subAccumUsed accumMap) sd smallE of { Ret e0 subtape _ sub e2 -> - let subMergeUsed' = subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E subMergeUsed) in Ret (collectBindings (desD1E des) subD1eUsed) (subenvAll (desD1E usedDes)) (weakenExpr (wSinks (desD1E usedDes)) $ drevPrimal des e) - (subenvCompose subMergeUsed' sub) + (subenvCompose (subenvD2E subMergeUsed) sub) (letBinds (fst (weakenBindingsE (WSink .> wRaiseAbove (desD1E usedDes) (d2ace (select SAccum des))) e0)) $ weakenExpr (autoWeak (#d (auto1 @sd) |
