diff options
Diffstat (limited to 'src/CHAD/AST.hs')
| -rw-r--r-- | src/CHAD/AST.hs | 21 |
1 files changed, 20 insertions, 1 deletions
diff --git a/src/CHAD/AST.hs b/src/CHAD/AST.hs index ce9eb20..eab9425 100644 --- a/src/CHAD/AST.hs +++ b/src/CHAD/AST.hs @@ -20,6 +20,7 @@ import Data.Functor.Const import Data.Functor.Identity import Data.Int (Int64) import Data.Kind (Type) +import Data.Proxy import CHAD.Array import CHAD.AST.Accum @@ -137,6 +138,10 @@ data Expr x env t where -- partiality EError :: x a -> STy a -> String -> Expr x env a + + -- user types + EUser :: x (TUser t) -> STy (TUser t) -> Expr x env (UserRep t) -> Expr x env (TUser t) + EUnUser :: x (UserRep t) -> Expr x env (TUser t) -> Expr x env (UserRep t) deriving instance (forall ty. Show (x ty)) => Show (Expr x env t) -- | A (well-typed, well-scoped) expression using De Bruijn indices. The full @@ -271,6 +276,9 @@ typeOf = \case EError _ t _ -> t + EUser _ t _ -> t + EUnUser _ e | STUser t <- typeOf e -> userRepTy t + extOf :: Expr x env t -> x t extOf = \case EVar x _ _ -> x @@ -317,6 +325,8 @@ extOf = \case EPlus x _ _ _ -> x EOneHot x _ _ _ _ -> x EError x _ _ -> x + EUser x _ _ -> x + EUnUser x _ -> x mapExt :: (forall a. x a -> x' a) -> Expr x env t -> Expr x' env t mapExt f = runIdentity . travExt (Identity . f) @@ -368,6 +378,8 @@ travExt f = \case EPlus x t a b -> EPlus <$> f x <*> pure t <*> travExt f a <*> travExt f b EOneHot x t p a b -> EOneHot <$> f x <*> pure t <*> pure p <*> travExt f a <*> travExt f b EError x t s -> EError <$> f x <*> pure t <*> pure s + EUser x t e -> EUser <$> f x <*> pure t <*> travExt f e + EUnUser x e -> EUnUser <$> f x <*> travExt f e substInline :: Expr x env a -> Expr x (a : env) t -> Expr x env t substInline repl = @@ -430,8 +442,10 @@ subst' f w = \case EZero x t e -> EZero x t (subst' f w e) EDeepZero x t e -> EDeepZero x t (subst' f w e) EPlus x t a b -> EPlus x t (subst' f w a) (subst' f w b) - EOneHot x t p a b -> EOneHot x t p (subst' f w a) (subst' f w b) + EOneHot x t p a b -> EOneHot x t p (subst' f w a) (subst' f w b) EError x t s -> EError x t s + EUser x t e -> EUser x t (subst' f w e) + EUnUser x e -> EUnUser x (subst' f w e) where sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a) -> x t -> STy t -> ((b : env') :> env2) -> Idx (b : env) t -> Expr x env2 t @@ -458,6 +472,7 @@ instance KnownTy t => KnownTy (TMaybe t) where knownTy = STMaybe knownTy instance (KnownNat n, KnownTy t) => KnownTy (TArr n t) where knownTy = STArr knownNat knownTy instance KnownScalTy t => KnownTy (TScal t) where knownTy = STScal knownScalTy instance KnownMTy t => KnownTy (TAccum t) where knownTy = STAccum knownMTy +instance UserPrimalType t => KnownTy (TUser t) where knownTy = STUser Proxy class KnownMTy t where knownMTy :: SMTy t instance KnownMTy TNil where knownMTy = SMTNil @@ -466,6 +481,7 @@ instance KnownMTy t => KnownMTy (TMaybe t) where knownMTy = SMTMaybe knownMTy instance (KnownMTy s, KnownMTy t) => KnownMTy (TLEither s t) where knownMTy = SMTLEither knownMTy knownMTy instance (KnownNat n, KnownMTy t) => KnownMTy (TArr n t) where knownMTy = SMTArr knownNat knownMTy instance (KnownScalTy t, ScalIsNumeric t ~ True) => KnownMTy (TScal t) where knownMTy = SMTScal knownScalTy +instance UserMonoidType t => KnownMTy (TUser t) where knownMTy = SMTUser Proxy class KnownEnv env where knownEnv :: SList STy env instance KnownEnv '[] where knownEnv = SNil @@ -480,6 +496,7 @@ styKnown (STMaybe t) | Dict <- styKnown t = Dict styKnown (STArr n t) | Dict <- snatKnown n, Dict <- styKnown t = Dict styKnown (STScal t) | Dict <- sscaltyKnown t = Dict styKnown (STAccum t) | Dict <- smtyKnown t = Dict +styKnown (STUser _) = Dict smtyKnown :: SMTy t -> Dict (KnownMTy t) smtyKnown SMTNil = Dict @@ -488,6 +505,7 @@ smtyKnown (SMTLEither a b) | Dict <- smtyKnown a, Dict <- smtyKnown b = Dict smtyKnown (SMTMaybe t) | Dict <- smtyKnown t = Dict smtyKnown (SMTArr n t) | Dict <- snatKnown n, Dict <- smtyKnown t = Dict smtyKnown (SMTScal t) | Dict <- sscaltyKnown t = Dict +smtyKnown (SMTUser _) = Dict sscaltyKnown :: SScalTy t -> Dict (KnownScalTy t) sscaltyKnown STI32 = Dict @@ -657,6 +675,7 @@ makeZeroInfo = \ty reference -> ELet ext reference $ go ty (EVar ext (fromSMTy t go SMTMaybe{} _ = ENil ext go (SMTArr _ t) e = emap (go t (EVar ext (fromSMTy t) IZ)) e go SMTScal{} _ = ENil ext + go (SMTUser t) e = euserZeroInfo t (EUnUser ext e) splitSparsePair :: -- given a sparsity |
