aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2026-01-26 23:37:55 +0100
committerTom Smeding <tom@tomsmeding.com>2026-01-26 23:37:55 +0100
commitb4f988cb1490ed31ab225323b33448667b8578c0 (patch)
treed048e70e33f2e2787aae68a9b671b78094c05c43 /src
parenta9e6c72eff3bee8d45e0d906e8cd027066e04793 (diff)
Multihot cotangents WIP (doesn't work)multihot-cotangents
The idea is sound but for a smaller source language. Notes also in Obsidian, but the theory so far is that dropping support for nested arrays makes this possible, although making the result type-safe (i.e. not have partial functions in a bunch of places) would require making the lack of nested array support explicit in the embedded type system, i.e. have Accelerate-like stratification. The point is that multihots can be added heterogeneously using plusSparseS but not homogeneously with EPlus or plusSparse, because the indices might differ between the summands. Thus as long as we never need to homogeneously sum multihot cotangents, we're golden. Now the crucial observation is that we only need plus to be homogeneous on array elements. So if array elements cannot themselves be arrays, i.e. we drop support for nested arrays, no homogeneous plus of multihot array cotangents is needed, and we can have static multihots.
Diffstat (limited to 'src')
-rw-r--r--src/CHAD/AST.hs58
-rw-r--r--src/CHAD/AST/Env.hs4
-rw-r--r--src/CHAD/AST/Sparse/Types.hs33
-rw-r--r--src/CHAD/AST/Types.hs59
-rw-r--r--src/CHAD/AST/UnMonoid.hs5
-rw-r--r--src/CHAD/Drev.hs3
6 files changed, 88 insertions, 74 deletions
diff --git a/src/CHAD/AST.hs b/src/CHAD/AST.hs
index b795070..3f6dfc4 100644
--- a/src/CHAD/AST.hs
+++ b/src/CHAD/AST.hs
@@ -442,64 +442,6 @@ subst' f w = \case
weakenExpr :: env :> env' -> Expr x env t -> Expr x env' t
weakenExpr = subst' (\x t w' i -> EVar x t (w' @> i))
-class KnownScalTy t where knownScalTy :: SScalTy t
-instance KnownScalTy TI32 where knownScalTy = STI32
-instance KnownScalTy TI64 where knownScalTy = STI64
-instance KnownScalTy TF32 where knownScalTy = STF32
-instance KnownScalTy TF64 where knownScalTy = STF64
-instance KnownScalTy TBool where knownScalTy = STBool
-
-class KnownTy t where knownTy :: STy t
-instance KnownTy TNil where knownTy = STNil
-instance (KnownTy s, KnownTy t) => KnownTy (TPair s t) where knownTy = STPair knownTy knownTy
-instance (KnownTy s, KnownTy t) => KnownTy (TEither s t) where knownTy = STEither knownTy knownTy
-instance (KnownTy s, KnownTy t) => KnownTy (TLEither s t) where knownTy = STLEither knownTy knownTy
-instance KnownTy t => KnownTy (TMaybe t) where knownTy = STMaybe knownTy
-instance (KnownNat n, KnownTy t) => KnownTy (TArr n t) where knownTy = STArr knownNat knownTy
-instance KnownScalTy t => KnownTy (TScal t) where knownTy = STScal knownScalTy
-instance KnownMTy t => KnownTy (TAccum t) where knownTy = STAccum knownMTy
-
-class KnownMTy t where knownMTy :: SMTy t
-instance KnownMTy TNil where knownMTy = SMTNil
-instance (KnownMTy s, KnownMTy t) => KnownMTy (TPair s t) where knownMTy = SMTPair knownMTy knownMTy
-instance KnownMTy t => KnownMTy (TMaybe t) where knownMTy = SMTMaybe knownMTy
-instance (KnownMTy s, KnownMTy t) => KnownMTy (TLEither s t) where knownMTy = SMTLEither knownMTy knownMTy
-instance (KnownNat n, KnownMTy t) => KnownMTy (TArr n t) where knownMTy = SMTArr knownNat knownMTy
-instance (KnownScalTy t, ScalIsNumeric t ~ True) => KnownMTy (TScal t) where knownMTy = SMTScal knownScalTy
-
-class KnownEnv env where knownEnv :: SList STy env
-instance KnownEnv '[] where knownEnv = SNil
-instance (KnownTy t, KnownEnv env) => KnownEnv (t : env) where knownEnv = SCons knownTy knownEnv
-
-styKnown :: STy t -> Dict (KnownTy t)
-styKnown STNil = Dict
-styKnown (STPair a b) | Dict <- styKnown a, Dict <- styKnown b = Dict
-styKnown (STEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict
-styKnown (STLEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict
-styKnown (STMaybe t) | Dict <- styKnown t = Dict
-styKnown (STArr n t) | Dict <- snatKnown n, Dict <- styKnown t = Dict
-styKnown (STScal t) | Dict <- sscaltyKnown t = Dict
-styKnown (STAccum t) | Dict <- smtyKnown t = Dict
-
-smtyKnown :: SMTy t -> Dict (KnownMTy t)
-smtyKnown SMTNil = Dict
-smtyKnown (SMTPair a b) | Dict <- smtyKnown a, Dict <- smtyKnown b = Dict
-smtyKnown (SMTLEither a b) | Dict <- smtyKnown a, Dict <- smtyKnown b = Dict
-smtyKnown (SMTMaybe t) | Dict <- smtyKnown t = Dict
-smtyKnown (SMTArr n t) | Dict <- snatKnown n, Dict <- smtyKnown t = Dict
-smtyKnown (SMTScal t) | Dict <- sscaltyKnown t = Dict
-
-sscaltyKnown :: SScalTy t -> Dict (KnownScalTy t)
-sscaltyKnown STI32 = Dict
-sscaltyKnown STI64 = Dict
-sscaltyKnown STF32 = Dict
-sscaltyKnown STF64 = Dict
-sscaltyKnown STBool = Dict
-
-envKnown :: SList STy env -> Dict (KnownEnv env)
-envKnown SNil = Dict
-envKnown (t `SCons` env) | Dict <- styKnown t, Dict <- envKnown env = Dict
-
cheapExpr :: Expr x env t -> Bool
cheapExpr = \case
EVar{} -> True
diff --git a/src/CHAD/AST/Env.hs b/src/CHAD/AST/Env.hs
index 8e6b745..40b6ca2 100644
--- a/src/CHAD/AST/Env.hs
+++ b/src/CHAD/AST/Env.hs
@@ -53,9 +53,9 @@ subenvOnehot (SCons _ env) IZ sp = SEYes sp (subenvNone env)
subenvOnehot (SCons _ env) (IS i) sp = SENo (subenvOnehot env i sp)
subenvOnehot SNil i _ = case i of {}
-subenvCompose :: IsSubType s => Subenv' s env1 env2 -> Subenv' s env2 env3 -> Subenv' s env1 env3
+subenvCompose :: IsSubType s => Subenv' (:~:) env1 env2 -> Subenv' s env2 env3 -> Subenv' s env1 env3
subenvCompose SETop SETop = SETop
-subenvCompose (SEYes s1 sub1) (SEYes s2 sub2) = SEYes (subtTrans s1 s2) (subenvCompose sub1 sub2)
+subenvCompose (SEYes Refl sub1) (SEYes s2 sub2) = SEYes s2 (subenvCompose sub1 sub2)
subenvCompose (SEYes _ sub1) (SENo sub2) = SENo (subenvCompose sub1 sub2)
subenvCompose (SENo sub1) sub2 = SENo (subenvCompose sub1 sub2)
diff --git a/src/CHAD/AST/Sparse/Types.hs b/src/CHAD/AST/Sparse/Types.hs
index 8f41ba4..930475a 100644
--- a/src/CHAD/AST/Sparse/Types.hs
+++ b/src/CHAD/AST/Sparse/Types.hs
@@ -1,14 +1,18 @@
+{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
module CHAD.AST.Sparse.Types where
import Data.Kind (Type, Constraint)
import Data.Type.Equality
import CHAD.AST.Types
+import CHAD.Data
data Sparse t t' where
@@ -19,9 +23,22 @@ data Sparse t t' where
SpLEither :: Sparse a a' -> Sparse b b' -> Sparse (TLEither a b) (TLEither a' b')
SpMaybe :: Sparse t t' -> Sparse (TMaybe t) (TMaybe t')
SpArr :: Sparse t t' -> Sparse (TArr n t) (TArr n t')
+ SpArrIdx :: SList (Sparse t) t's -> Sparse (TArr n t) (MultiHot n t's)
SpScal :: Sparse (TScal t) (TScal t)
deriving instance Show (Sparse t t')
+type family MultiHot n t's where
+ MultiHot n '[] = TNil
+ MultiHot n (t' : t's) = TPair (TPair (Tup (Replicate n TIx)) t') (MultiHot n t's)
+
+tMultiHot :: SNat n -> SList STy ts -> STy (MultiHot n ts)
+tMultiHot _ SNil = STNil
+tMultiHot n (t `SCons` ts) = STPair (STPair (tTup (sreplicate n tIx)) t) (tMultiHot n ts)
+
+mtMultiHot :: SNat n -> SList SMTy ts -> SMTy (MultiHot n ts)
+mtMultiHot _ SNil = SMTNil
+mtMultiHot n (t `SCons` ts) = SMTPair (SMTPair (tTup (sreplicate n tIx)) t) (tMultiHot n ts)
+
class ApplySparse f where
applySparse :: Sparse t t' -> f t -> f t'
@@ -32,6 +49,7 @@ instance ApplySparse STy where
applySparse (SpLEither s1 s2) (STLEither t1 t2) = STLEither (applySparse s1 t1) (applySparse s2 t2)
applySparse (SpMaybe s) (STMaybe t) = STMaybe (applySparse s t)
applySparse (SpArr s) (STArr n t) = STArr n (applySparse s t)
+ applySparse (SpArrIdx ss) (STArr n t) = tMultiHot n (slistMap (`applySparse` t) ss)
applySparse SpScal t = t
instance ApplySparse SMTy where
@@ -41,34 +59,23 @@ instance ApplySparse SMTy where
applySparse (SpLEither s1 s2) (SMTLEither t1 t2) = SMTLEither (applySparse s1 t1) (applySparse s2 t2)
applySparse (SpMaybe s) (SMTMaybe t) = SMTMaybe (applySparse s t)
applySparse (SpArr s) (SMTArr n t) = SMTArr n (applySparse s t)
+ applySparse (SpArrIdx s) (SMTArr n t) = SMTPair (mkTup SMTNil SMTPair (sreplicate n (knownMTy @TIx))) (applySparse s t)
applySparse SpScal t = t
class IsSubType s where
type IsSubTypeSubject (s :: k -> k -> Type) (f :: k -> Type) :: Constraint
subtApply :: IsSubTypeSubject s f => s t t' -> f t -> f t'
- subtTrans :: s a b -> s b c -> s a c
subtFull :: IsSubTypeSubject s f => f t -> s t t
instance IsSubType (:~:) where
type IsSubTypeSubject (:~:) f = ()
subtApply = gcastWith
- subtTrans = trans
subtFull _ = Refl
instance IsSubType Sparse where
type IsSubTypeSubject Sparse f = f ~ SMTy
subtApply = applySparse
-
- subtTrans s1 (SpSparse s2) = SpSparse (subtTrans s1 s2)
- subtTrans _ SpAbsent = SpAbsent
- subtTrans (SpPair s1a s1b) (SpPair s2a s2b) = SpPair (subtTrans s1a s2a) (subtTrans s1b s2b)
- subtTrans (SpLEither s1a s1b) (SpLEither s2a s2b) = SpLEither (subtTrans s1a s2a) (subtTrans s1b s2b)
- subtTrans (SpSparse s1) (SpMaybe s2) = SpSparse (subtTrans s1 s2)
- subtTrans (SpMaybe s1) (SpMaybe s2) = SpMaybe (subtTrans s1 s2)
- subtTrans (SpArr s1) (SpArr s2) = SpArr (subtTrans s1 s2)
- subtTrans SpScal SpScal = SpScal
-
subtFull = spDense
spDense :: SMTy t -> Sparse t t
@@ -95,6 +102,7 @@ isDense (SMTMaybe t) (SpMaybe s)
isDense (SMTArr _ t) (SpArr s)
| Just Refl <- isDense t s = Just Refl
| otherwise = Nothing
+isDense SMTArr{} SpArrIdx{} = Nothing
isDense (SMTScal _) SpScal = Just Refl
isAbsent :: Sparse t t' -> Bool
@@ -104,4 +112,5 @@ isAbsent (SpPair s1 s2) = isAbsent s1 && isAbsent s2
isAbsent (SpLEither s1 s2) = isAbsent s1 && isAbsent s2
isAbsent (SpMaybe s) = isAbsent s
isAbsent (SpArr s) = isAbsent s
+isAbsent (SpArrIdx s) = isAbsent s
isAbsent SpScal = False
diff --git a/src/CHAD/AST/Types.hs b/src/CHAD/AST/Types.hs
index f0feb55..bec2201 100644
--- a/src/CHAD/AST/Types.hs
+++ b/src/CHAD/AST/Types.hs
@@ -77,6 +77,7 @@ data SMTy t where
SMTMaybe :: SMTy a -> SMTy (TMaybe a)
SMTArr :: SNat n -> SMTy t -> SMTy (TArr n t)
SMTScal :: ScalIsNumeric t ~ True => SScalTy t -> SMTy (TScal t)
+ SMTData :: STy a -> SMTy a -- ^ inclusion of non-monoidal information
deriving instance Show (SMTy t)
instance GCompare SMTy where
@@ -215,3 +216,61 @@ unTup unpack (_ `SCons` list) tup =
type family InvTup core env where
InvTup core '[] = core
InvTup core (t : ts) = InvTup (TPair core t) ts
+
+class KnownScalTy t where knownScalTy :: SScalTy t
+instance KnownScalTy TI32 where knownScalTy = STI32
+instance KnownScalTy TI64 where knownScalTy = STI64
+instance KnownScalTy TF32 where knownScalTy = STF32
+instance KnownScalTy TF64 where knownScalTy = STF64
+instance KnownScalTy TBool where knownScalTy = STBool
+
+class KnownTy t where knownTy :: STy t
+instance KnownTy TNil where knownTy = STNil
+instance (KnownTy s, KnownTy t) => KnownTy (TPair s t) where knownTy = STPair knownTy knownTy
+instance (KnownTy s, KnownTy t) => KnownTy (TEither s t) where knownTy = STEither knownTy knownTy
+instance (KnownTy s, KnownTy t) => KnownTy (TLEither s t) where knownTy = STLEither knownTy knownTy
+instance KnownTy t => KnownTy (TMaybe t) where knownTy = STMaybe knownTy
+instance (KnownNat n, KnownTy t) => KnownTy (TArr n t) where knownTy = STArr knownNat knownTy
+instance KnownScalTy t => KnownTy (TScal t) where knownTy = STScal knownScalTy
+instance KnownMTy t => KnownTy (TAccum t) where knownTy = STAccum knownMTy
+
+class KnownMTy t where knownMTy :: SMTy t
+instance KnownMTy TNil where knownMTy = SMTNil
+instance (KnownMTy s, KnownMTy t) => KnownMTy (TPair s t) where knownMTy = SMTPair knownMTy knownMTy
+instance KnownMTy t => KnownMTy (TMaybe t) where knownMTy = SMTMaybe knownMTy
+instance (KnownMTy s, KnownMTy t) => KnownMTy (TLEither s t) where knownMTy = SMTLEither knownMTy knownMTy
+instance (KnownNat n, KnownMTy t) => KnownMTy (TArr n t) where knownMTy = SMTArr knownNat knownMTy
+instance (KnownScalTy t, ScalIsNumeric t ~ True) => KnownMTy (TScal t) where knownMTy = SMTScal knownScalTy
+
+class KnownEnv env where knownEnv :: SList STy env
+instance KnownEnv '[] where knownEnv = SNil
+instance (KnownTy t, KnownEnv env) => KnownEnv (t : env) where knownEnv = SCons knownTy knownEnv
+
+styKnown :: STy t -> Dict (KnownTy t)
+styKnown STNil = Dict
+styKnown (STPair a b) | Dict <- styKnown a, Dict <- styKnown b = Dict
+styKnown (STEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict
+styKnown (STLEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict
+styKnown (STMaybe t) | Dict <- styKnown t = Dict
+styKnown (STArr n t) | Dict <- snatKnown n, Dict <- styKnown t = Dict
+styKnown (STScal t) | Dict <- sscaltyKnown t = Dict
+styKnown (STAccum t) | Dict <- smtyKnown t = Dict
+
+smtyKnown :: SMTy t -> Dict (KnownMTy t)
+smtyKnown SMTNil = Dict
+smtyKnown (SMTPair a b) | Dict <- smtyKnown a, Dict <- smtyKnown b = Dict
+smtyKnown (SMTLEither a b) | Dict <- smtyKnown a, Dict <- smtyKnown b = Dict
+smtyKnown (SMTMaybe t) | Dict <- smtyKnown t = Dict
+smtyKnown (SMTArr n t) | Dict <- snatKnown n, Dict <- smtyKnown t = Dict
+smtyKnown (SMTScal t) | Dict <- sscaltyKnown t = Dict
+
+sscaltyKnown :: SScalTy t -> Dict (KnownScalTy t)
+sscaltyKnown STI32 = Dict
+sscaltyKnown STI64 = Dict
+sscaltyKnown STF32 = Dict
+sscaltyKnown STF64 = Dict
+sscaltyKnown STBool = Dict
+
+envKnown :: SList STy env -> Dict (KnownEnv env)
+envKnown SNil = Dict
+envKnown (t `SCons` env) | Dict <- styKnown t, Dict <- envKnown env = Dict
diff --git a/src/CHAD/AST/UnMonoid.hs b/src/CHAD/AST/UnMonoid.hs
index d3cad25..1a66cdf 100644
--- a/src/CHAD/AST/UnMonoid.hs
+++ b/src/CHAD/AST/UnMonoid.hs
@@ -212,6 +212,11 @@ accumulateSparse topty topsp arg accum = case (topty, topsp) of
(EIdx ext (evar (IS IZ)) (EVar ext tn IZ))
(\w prj idx val -> accum (WPop (WPop w)) (SAPArrIdx prj) (EPair ext (EVar ext tn (w @> IZ)) idx) val)) $
ENil ext
+ (SMTArr _ t, SpArrIdx s) ->
+ eunPair arg $ \w1 e1 e2 ->
+ elet (accumulateSparse t s e2
+ (\w prj idx val -> accum (w .> w1) (SAPArrIdx prj) (EPair ext (weakenExpr w e1) idx) val)) $
+ ENil ext
acPrjCompose
:: SAIDense dense
diff --git a/src/CHAD/Drev.hs b/src/CHAD/Drev.hs
index bfa964b..ee92782 100644
--- a/src/CHAD/Drev.hs
+++ b/src/CHAD/Drev.hs
@@ -1009,11 +1009,10 @@ drev des accumMap sd = \case
let smallE = unsafeWeakenWithSubenv usedSub e in
subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed ->
case drev usedDes (VarMap.subMap subAccumUsed accumMap) sd smallE of { Ret e0 subtape _ sub e2 ->
- let subMergeUsed' = subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E subMergeUsed) in
Ret (collectBindings (desD1E des) subD1eUsed)
(subenvAll (desD1E usedDes))
(weakenExpr (wSinks (desD1E usedDes)) $ drevPrimal des e)
- (subenvCompose subMergeUsed' sub)
+ (subenvCompose (subenvD2E subMergeUsed) sub)
(letBinds (fst (weakenBindingsE (WSink .> wRaiseAbove (desD1E usedDes) (d2ace (select SAccum des))) e0)) $
weakenExpr
(autoWeak (#d (auto1 @sd)