diff options
Diffstat (limited to 'src/AST.hs')
-rw-r--r-- | src/AST.hs | 84 |
1 files changed, 65 insertions, 19 deletions
@@ -16,6 +16,7 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE FlexibleInstances #-} module AST (module AST, module AST.Types, module AST.Accum, module AST.Weaken) where import Data.Functor.Const @@ -33,11 +34,9 @@ import Data -- inner variable / inner array dimension. In pretty printing, the inner -- variable / inner dimension is printed on the _right_. -- --- Note that the 'EZero' and 'EPlus' constructs have typing that depend on the --- type transformation of CHAD. Indeed, these constructors are created _by_ --- CHAD, and are intended to be eliminated after simplification, so that the --- input program as well as the output program do not contain these --- constructors. +-- All the monoid operations are unsupposed as the input to CHAD, and are +-- intended to be eliminated after simplification, so that the input program as +-- well as the output program do not contain these constructors. -- TODO: ensure this by a "stage" type parameter. type Expr :: (Ty -> Type) -> [Ty] -> Ty -> Type data Expr x env t where @@ -56,6 +55,10 @@ data Expr x env t where ENothing :: x (TMaybe t) -> STy t -> Expr x env (TMaybe t) EJust :: x (TMaybe t) -> Expr x env t -> Expr x env (TMaybe t) EMaybe :: x b -> Expr x env b -> Expr x (t : env) b -> Expr x env (TMaybe t) -> Expr x env b + ELNil :: x (TLEither a b) -> STy a -> STy b -> Expr x env (TLEither a b) + ELInl :: x (TLEither a b) -> STy b -> Expr x env a -> Expr x env (TLEither a b) + ELInr :: x (TLEither a b) -> STy a -> Expr x env b -> Expr x env (TLEither a b) + ELCase :: x c -> Expr x env (TLEither a b) -> Expr x env c -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c -- array operations EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr x env (TArr n (TScal t)) @@ -88,13 +91,13 @@ data Expr x env t where -> Expr x env t -- accumulation effect on monoids - EWith :: x (TPair a (D2 t)) -> STy t -> Expr x env (D2 t) -> Expr x (TAccum t : env) a -> Expr x env (TPair a (D2 t)) - EAccum :: x TNil -> STy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env (D2 a) -> Expr x env (TAccum t) -> Expr x env TNil + EWith :: x (TPair a t) -> SMTy t -> Expr x env t -> Expr x (TAccum t : env) a -> Expr x env (TPair a t) + EAccum :: x TNil -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env a -> Expr x env (TAccum t) -> Expr x env TNil -- monoidal operations (to be desugared to regular operations after simplification) - EZero :: x (D2 t) -> STy t -> Expr x env (D2 t) - EPlus :: x (D2 t) -> STy t -> Expr x env (D2 t) -> Expr x env (D2 t) -> Expr x env (D2 t) - EOneHot :: x (D2 t) -> STy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env (D2 a) -> Expr x env (D2 t) + EZero :: x t -> SMTy t -> Expr x env (ZeroInfo t) -> Expr x env t + EPlus :: x t -> SMTy t -> Expr x env t -> Expr x env t -> Expr x env t + EOneHot :: x t -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env a -> Expr x env t -- partiality EError :: x a -> STy a -> String -> Expr x env a @@ -184,6 +187,10 @@ typeOf = \case ENothing _ t -> STMaybe t EJust _ e -> STMaybe (typeOf e) EMaybe _ e _ _ -> typeOf e + ELNil _ t1 t2 -> STLEither t1 t2 + ELInl _ t2 e -> STLEither (typeOf e) t2 + ELInr _ t1 e -> STLEither t1 (typeOf e) + ELCase _ _ a _ _ -> typeOf a EConstArr _ n t _ -> STArr n (STScal t) EBuild _ n _ e -> STArr n (typeOf e) @@ -206,9 +213,9 @@ typeOf = \case EWith _ _ e1 e2 -> STPair (typeOf e2) (typeOf e1) EAccum _ _ _ _ _ _ -> STNil - EZero _ t -> d2 t - EPlus _ t _ _ -> d2 t - EOneHot _ t _ _ _ -> d2 t + EZero _ t _ -> fromSMTy t + EPlus _ t _ _ -> fromSMTy t + EOneHot _ t _ _ _ -> fromSMTy t EError _ t _ -> t @@ -226,6 +233,10 @@ extOf = \case ENothing x _ -> x EJust x _ -> x EMaybe x _ _ _ -> x + ELNil x _ _ -> x + ELInl x _ _ -> x + ELInr x _ _ -> x + ELCase x _ _ _ _ -> x EConstArr x _ _ _ -> x EBuild x _ _ _ -> x EFold1Inner x _ _ _ _ -> x @@ -243,7 +254,7 @@ extOf = \case ECustom x _ _ _ _ _ _ _ _ -> x EWith x _ _ _ -> x EAccum x _ _ _ _ _ -> x - EZero x _ -> x + EZero x _ _ -> x EPlus x _ _ _ -> x EOneHot x _ _ _ _ -> x EError x _ _ -> x @@ -262,6 +273,10 @@ mapExt f = \case ENothing x t -> ENothing (f x) t EJust x e -> EJust (f x) (mapExt f e) EMaybe x a b e -> EMaybe (f x) (mapExt f a) (mapExt f b) (mapExt f e) + ELNil x t1 t2 -> ELNil (f x) t1 t2 + ELInl x t e -> ELInl (f x) t (mapExt f e) + ELInr x t e -> ELInr (f x) t (mapExt f e) + ELCase x e a b c -> ELCase (f x) (mapExt f e) (mapExt f a) (mapExt f b) (mapExt f c) EConstArr x n t a -> EConstArr (f x) n t a EBuild x n a b -> EBuild (f x) n (mapExt f a) (mapExt f b) EFold1Inner x cm a b c -> EFold1Inner (f x) cm (mapExt f a) (mapExt f b) (mapExt f c) @@ -279,7 +294,7 @@ mapExt f = \case ECustom x s t p a b c e1 e2 -> ECustom (f x) s t p (mapExt f a) (mapExt f b) (mapExt f c) (mapExt f e1) (mapExt f e2) EWith x t e1 e2 -> EWith (f x) t (mapExt f e1) (mapExt f e2) EAccum x t p e1 e2 e3 -> EAccum (f x) t p (mapExt f e1) (mapExt f e2) (mapExt f e3) - EZero x t -> EZero (f x) t + EZero x t e -> EZero (f x) t (mapExt f e) EPlus x t a b -> EPlus (f x) t (mapExt f a) (mapExt f b) EOneHot x t p a b -> EOneHot (f x) t p (mapExt f a) (mapExt f b) EError x t s -> EError (f x) t s @@ -315,6 +330,10 @@ subst' f w = \case ENothing x t -> ENothing x t EJust x e -> EJust x (subst' f w e) EMaybe x a b e -> EMaybe x (subst' f w a) (subst' (sinkF f) (WCopy w) b) (subst' f w e) + ELNil x t1 t2 -> ELNil x t1 t2 + ELInl x t e -> ELInl x t (subst' f w e) + ELInr x t e -> ELInr x t (subst' f w e) + ELCase x e a b c -> ELCase x (subst' f w e) (subst' f w a) (subst' (sinkF f) (WCopy w) b) (subst' (sinkF f) (WCopy w) c) EConstArr x n t a -> EConstArr x n t a EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkF f) (WCopy w) b) EFold1Inner x cm a b c -> EFold1Inner x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c) @@ -332,9 +351,9 @@ subst' f w = \case ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (subst' f w e1) (subst' f w e2) EWith x t e1 e2 -> EWith x t (subst' f w e1) (subst' (sinkF f) (WCopy w) e2) EAccum x t p e1 e2 e3 -> EAccum x t p (subst' f w e1) (subst' f w e2) (subst' f w e3) - EZero x t -> EZero x t + EZero x t e -> EZero x t (subst' f w e) EPlus x t a b -> EPlus x t (subst' f w a) (subst' f w b) - EOneHot x t p a b -> EOneHot x t p (subst' f w a) (subst' f w b) + EOneHot x t p a b -> EOneHot x t p (subst' f w a) (subst' f w b) EError x t s -> EError x t s where sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a) @@ -360,7 +379,16 @@ instance (KnownTy s, KnownTy t) => KnownTy (TEither s t) where knownTy = STEithe instance KnownTy t => KnownTy (TMaybe t) where knownTy = STMaybe knownTy instance (KnownNat n, KnownTy t) => KnownTy (TArr n t) where knownTy = STArr knownNat knownTy instance KnownScalTy t => KnownTy (TScal t) where knownTy = STScal knownScalTy -instance KnownTy t => KnownTy (TAccum t) where knownTy = STAccum knownTy +instance KnownMTy t => KnownTy (TAccum t) where knownTy = STAccum knownMTy +instance (KnownTy s, KnownTy t) => KnownTy (TLEither s t) where knownTy = STLEither knownTy knownTy + +class KnownMTy t where knownMTy :: SMTy t +instance KnownMTy TNil where knownMTy = SMTNil +instance (KnownMTy s, KnownMTy t) => KnownMTy (TPair s t) where knownMTy = SMTPair knownMTy knownMTy +instance KnownMTy t => KnownMTy (TMaybe t) where knownMTy = SMTMaybe knownMTy +instance (KnownMTy s, KnownMTy t) => KnownMTy (TLEither s t) where knownMTy = SMTLEither knownMTy knownMTy +instance (KnownNat n, KnownMTy t) => KnownMTy (TArr n t) where knownMTy = SMTArr knownNat knownMTy +instance (KnownScalTy t, ScalIsNumeric t ~ True) => KnownMTy (TScal t) where knownMTy = SMTScal knownScalTy class KnownEnv env where knownEnv :: SList STy env instance KnownEnv '[] where knownEnv = SNil @@ -373,7 +401,16 @@ styKnown (STEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict styKnown (STMaybe t) | Dict <- styKnown t = Dict styKnown (STArr n t) | Dict <- snatKnown n, Dict <- styKnown t = Dict styKnown (STScal t) | Dict <- sscaltyKnown t = Dict -styKnown (STAccum t) | Dict <- styKnown t = Dict +styKnown (STAccum t) | Dict <- smtyKnown t = Dict +styKnown (STLEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict + +smtyKnown :: SMTy t -> Dict (KnownMTy t) +smtyKnown SMTNil = Dict +smtyKnown (SMTPair a b) | Dict <- smtyKnown a, Dict <- smtyKnown b = Dict +smtyKnown (SMTLEither a b) | Dict <- smtyKnown a, Dict <- smtyKnown b = Dict +smtyKnown (SMTMaybe t) | Dict <- smtyKnown t = Dict +smtyKnown (SMTArr n t) | Dict <- snatKnown n, Dict <- smtyKnown t = Dict +smtyKnown (SMTScal t) | Dict <- sscaltyKnown t = Dict sscaltyKnown :: SScalTy t -> Dict (KnownScalTy t) sscaltyKnown STI32 = Dict @@ -451,3 +488,12 @@ eshapeEmpty (SS n) e = (EOp ext (OEq STI64) (EPair ext (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ)) (EConst ext STI64 0))) (eshapeEmpty n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ)))) + +ezeroD2 :: STy t -> Ex env (D2 t) +ezeroD2 t | Refl <- lemZeroInfoD2 t = EZero ext (d2M t) (ENil ext) + +-- eaccumD2 :: STy t -> SAcPrj p (D2 t) a -> Ex env (AcIdx p (D2 t)) -> Ex env a -> Ex env (TAccum (D2 t)) -> Ex env TNil +-- eaccumD2 t p ei ev ea | Refl <- lemZeroInfoD2 t = EAccum ext (d2M t) (ENil ext) p ei ev ea + +-- eonehotD2 :: STy t -> SAcPrj p (D2 t) a -> Ex env (AcIdx p (D2 t)) -> Ex env a -> Ex env (D2 t) +-- eonehotD2 t p ei ev | Refl <- lemZeroInfoD2 t = EOneHot ext (d2M t) (ENil ext) p ei ev |