diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2026-01-26 23:37:55 +0100 |
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2026-01-26 23:37:55 +0100 |
| commit | b4f988cb1490ed31ab225323b33448667b8578c0 (patch) | |
| tree | d048e70e33f2e2787aae68a9b671b78094c05c43 /src/CHAD/AST.hs | |
| parent | a9e6c72eff3bee8d45e0d906e8cd027066e04793 (diff) | |
Multihot cotangents WIP (doesn't work)multihot-cotangents
The idea is sound but for a smaller source language. Notes also in
Obsidian, but the theory so far is that dropping support for nested
arrays makes this possible, although making the result type-safe (i.e.
not have partial functions in a bunch of places) would require making
the lack of nested array support explicit in the embedded type system,
i.e. have Accelerate-like stratification.
The point is that multihots can be added heterogeneously using
plusSparseS but not homogeneously with EPlus or plusSparse, because the
indices might differ between the summands. Thus as long as we never need
to homogeneously sum multihot cotangents, we're golden.
Now the crucial observation is that we only need plus to be homogeneous
on array elements. So if array elements cannot themselves be arrays,
i.e. we drop support for nested arrays, no homogeneous plus of multihot
array cotangents is needed, and we can have static multihots.
Diffstat (limited to 'src/CHAD/AST.hs')
| -rw-r--r-- | src/CHAD/AST.hs | 58 |
1 files changed, 0 insertions, 58 deletions
diff --git a/src/CHAD/AST.hs b/src/CHAD/AST.hs index b795070..3f6dfc4 100644 --- a/src/CHAD/AST.hs +++ b/src/CHAD/AST.hs @@ -442,64 +442,6 @@ subst' f w = \case weakenExpr :: env :> env' -> Expr x env t -> Expr x env' t weakenExpr = subst' (\x t w' i -> EVar x t (w' @> i)) -class KnownScalTy t where knownScalTy :: SScalTy t -instance KnownScalTy TI32 where knownScalTy = STI32 -instance KnownScalTy TI64 where knownScalTy = STI64 -instance KnownScalTy TF32 where knownScalTy = STF32 -instance KnownScalTy TF64 where knownScalTy = STF64 -instance KnownScalTy TBool where knownScalTy = STBool - -class KnownTy t where knownTy :: STy t -instance KnownTy TNil where knownTy = STNil -instance (KnownTy s, KnownTy t) => KnownTy (TPair s t) where knownTy = STPair knownTy knownTy -instance (KnownTy s, KnownTy t) => KnownTy (TEither s t) where knownTy = STEither knownTy knownTy -instance (KnownTy s, KnownTy t) => KnownTy (TLEither s t) where knownTy = STLEither knownTy knownTy -instance KnownTy t => KnownTy (TMaybe t) where knownTy = STMaybe knownTy -instance (KnownNat n, KnownTy t) => KnownTy (TArr n t) where knownTy = STArr knownNat knownTy -instance KnownScalTy t => KnownTy (TScal t) where knownTy = STScal knownScalTy -instance KnownMTy t => KnownTy (TAccum t) where knownTy = STAccum knownMTy - -class KnownMTy t where knownMTy :: SMTy t -instance KnownMTy TNil where knownMTy = SMTNil -instance (KnownMTy s, KnownMTy t) => KnownMTy (TPair s t) where knownMTy = SMTPair knownMTy knownMTy -instance KnownMTy t => KnownMTy (TMaybe t) where knownMTy = SMTMaybe knownMTy -instance (KnownMTy s, KnownMTy t) => KnownMTy (TLEither s t) where knownMTy = SMTLEither knownMTy knownMTy -instance (KnownNat n, KnownMTy t) => KnownMTy (TArr n t) where knownMTy = SMTArr knownNat knownMTy -instance (KnownScalTy t, ScalIsNumeric t ~ True) => KnownMTy (TScal t) where knownMTy = SMTScal knownScalTy - -class KnownEnv env where knownEnv :: SList STy env -instance KnownEnv '[] where knownEnv = SNil -instance (KnownTy t, KnownEnv env) => KnownEnv (t : env) where knownEnv = SCons knownTy knownEnv - -styKnown :: STy t -> Dict (KnownTy t) -styKnown STNil = Dict -styKnown (STPair a b) | Dict <- styKnown a, Dict <- styKnown b = Dict -styKnown (STEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict -styKnown (STLEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict -styKnown (STMaybe t) | Dict <- styKnown t = Dict -styKnown (STArr n t) | Dict <- snatKnown n, Dict <- styKnown t = Dict -styKnown (STScal t) | Dict <- sscaltyKnown t = Dict -styKnown (STAccum t) | Dict <- smtyKnown t = Dict - -smtyKnown :: SMTy t -> Dict (KnownMTy t) -smtyKnown SMTNil = Dict -smtyKnown (SMTPair a b) | Dict <- smtyKnown a, Dict <- smtyKnown b = Dict -smtyKnown (SMTLEither a b) | Dict <- smtyKnown a, Dict <- smtyKnown b = Dict -smtyKnown (SMTMaybe t) | Dict <- smtyKnown t = Dict -smtyKnown (SMTArr n t) | Dict <- snatKnown n, Dict <- smtyKnown t = Dict -smtyKnown (SMTScal t) | Dict <- sscaltyKnown t = Dict - -sscaltyKnown :: SScalTy t -> Dict (KnownScalTy t) -sscaltyKnown STI32 = Dict -sscaltyKnown STI64 = Dict -sscaltyKnown STF32 = Dict -sscaltyKnown STF64 = Dict -sscaltyKnown STBool = Dict - -envKnown :: SList STy env -> Dict (KnownEnv env) -envKnown SNil = Dict -envKnown (t `SCons` env) | Dict <- styKnown t, Dict <- envKnown env = Dict - cheapExpr :: Expr x env t -> Bool cheapExpr = \case EVar{} -> True |
