aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD/AST.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-11-27 21:30:17 +0100
committerTom Smeding <tom@tomsmeding.com>2025-11-27 21:30:17 +0100
commit20f7d7be13cd7869b338f98d1ab3fd33e8bbfb3e (patch)
treea21c90034a02cdeb7240563dbbab355e49622d0a /src/CHAD/AST.hs
parentae634c056b500a568b2d89b7f8e225404a2c0c62 (diff)
WIP user-specified custom typesuser-types
The big roadblock encountered is that accumulation wants addition of monoids to be elementwise float addition; this fundamentally clashes with the concept of a user type with a custom zero and plus.
Diffstat (limited to 'src/CHAD/AST.hs')
-rw-r--r--src/CHAD/AST.hs21
1 files changed, 20 insertions, 1 deletions
diff --git a/src/CHAD/AST.hs b/src/CHAD/AST.hs
index ce9eb20..eab9425 100644
--- a/src/CHAD/AST.hs
+++ b/src/CHAD/AST.hs
@@ -20,6 +20,7 @@ import Data.Functor.Const
import Data.Functor.Identity
import Data.Int (Int64)
import Data.Kind (Type)
+import Data.Proxy
import CHAD.Array
import CHAD.AST.Accum
@@ -137,6 +138,10 @@ data Expr x env t where
-- partiality
EError :: x a -> STy a -> String -> Expr x env a
+
+ -- user types
+ EUser :: x (TUser t) -> STy (TUser t) -> Expr x env (UserRep t) -> Expr x env (TUser t)
+ EUnUser :: x (UserRep t) -> Expr x env (TUser t) -> Expr x env (UserRep t)
deriving instance (forall ty. Show (x ty)) => Show (Expr x env t)
-- | A (well-typed, well-scoped) expression using De Bruijn indices. The full
@@ -271,6 +276,9 @@ typeOf = \case
EError _ t _ -> t
+ EUser _ t _ -> t
+ EUnUser _ e | STUser t <- typeOf e -> userRepTy t
+
extOf :: Expr x env t -> x t
extOf = \case
EVar x _ _ -> x
@@ -317,6 +325,8 @@ extOf = \case
EPlus x _ _ _ -> x
EOneHot x _ _ _ _ -> x
EError x _ _ -> x
+ EUser x _ _ -> x
+ EUnUser x _ -> x
mapExt :: (forall a. x a -> x' a) -> Expr x env t -> Expr x' env t
mapExt f = runIdentity . travExt (Identity . f)
@@ -368,6 +378,8 @@ travExt f = \case
EPlus x t a b -> EPlus <$> f x <*> pure t <*> travExt f a <*> travExt f b
EOneHot x t p a b -> EOneHot <$> f x <*> pure t <*> pure p <*> travExt f a <*> travExt f b
EError x t s -> EError <$> f x <*> pure t <*> pure s
+ EUser x t e -> EUser <$> f x <*> pure t <*> travExt f e
+ EUnUser x e -> EUnUser <$> f x <*> travExt f e
substInline :: Expr x env a -> Expr x (a : env) t -> Expr x env t
substInline repl =
@@ -430,8 +442,10 @@ subst' f w = \case
EZero x t e -> EZero x t (subst' f w e)
EDeepZero x t e -> EDeepZero x t (subst' f w e)
EPlus x t a b -> EPlus x t (subst' f w a) (subst' f w b)
- EOneHot x t p a b -> EOneHot x t p (subst' f w a) (subst' f w b)
+ EOneHot x t p a b -> EOneHot x t p (subst' f w a) (subst' f w b)
EError x t s -> EError x t s
+ EUser x t e -> EUser x t (subst' f w e)
+ EUnUser x e -> EUnUser x (subst' f w e)
where
sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a)
-> x t -> STy t -> ((b : env') :> env2) -> Idx (b : env) t -> Expr x env2 t
@@ -458,6 +472,7 @@ instance KnownTy t => KnownTy (TMaybe t) where knownTy = STMaybe knownTy
instance (KnownNat n, KnownTy t) => KnownTy (TArr n t) where knownTy = STArr knownNat knownTy
instance KnownScalTy t => KnownTy (TScal t) where knownTy = STScal knownScalTy
instance KnownMTy t => KnownTy (TAccum t) where knownTy = STAccum knownMTy
+instance UserPrimalType t => KnownTy (TUser t) where knownTy = STUser Proxy
class KnownMTy t where knownMTy :: SMTy t
instance KnownMTy TNil where knownMTy = SMTNil
@@ -466,6 +481,7 @@ instance KnownMTy t => KnownMTy (TMaybe t) where knownMTy = SMTMaybe knownMTy
instance (KnownMTy s, KnownMTy t) => KnownMTy (TLEither s t) where knownMTy = SMTLEither knownMTy knownMTy
instance (KnownNat n, KnownMTy t) => KnownMTy (TArr n t) where knownMTy = SMTArr knownNat knownMTy
instance (KnownScalTy t, ScalIsNumeric t ~ True) => KnownMTy (TScal t) where knownMTy = SMTScal knownScalTy
+instance UserMonoidType t => KnownMTy (TUser t) where knownMTy = SMTUser Proxy
class KnownEnv env where knownEnv :: SList STy env
instance KnownEnv '[] where knownEnv = SNil
@@ -480,6 +496,7 @@ styKnown (STMaybe t) | Dict <- styKnown t = Dict
styKnown (STArr n t) | Dict <- snatKnown n, Dict <- styKnown t = Dict
styKnown (STScal t) | Dict <- sscaltyKnown t = Dict
styKnown (STAccum t) | Dict <- smtyKnown t = Dict
+styKnown (STUser _) = Dict
smtyKnown :: SMTy t -> Dict (KnownMTy t)
smtyKnown SMTNil = Dict
@@ -488,6 +505,7 @@ smtyKnown (SMTLEither a b) | Dict <- smtyKnown a, Dict <- smtyKnown b = Dict
smtyKnown (SMTMaybe t) | Dict <- smtyKnown t = Dict
smtyKnown (SMTArr n t) | Dict <- snatKnown n, Dict <- smtyKnown t = Dict
smtyKnown (SMTScal t) | Dict <- sscaltyKnown t = Dict
+smtyKnown (SMTUser _) = Dict
sscaltyKnown :: SScalTy t -> Dict (KnownScalTy t)
sscaltyKnown STI32 = Dict
@@ -657,6 +675,7 @@ makeZeroInfo = \ty reference -> ELet ext reference $ go ty (EVar ext (fromSMTy t
go SMTMaybe{} _ = ENil ext
go (SMTArr _ t) e = emap (go t (EVar ext (fromSMTy t) IZ)) e
go SMTScal{} _ = ENil ext
+ go (SMTUser t) e = euserZeroInfo t (EUnUser ext e)
splitSparsePair
:: -- given a sparsity