diff options
Diffstat (limited to 'src/AST.hs')
-rw-r--r-- | src/AST.hs | 109 |
1 files changed, 33 insertions, 76 deletions
@@ -13,92 +13,19 @@ {-# LANGUAGE StandaloneKindSignatures #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -module AST (module AST, module AST.Weaken) where +module AST (module AST, module AST.Types, module AST.Weaken) where import Data.Functor.Const import Data.Kind (Type) -import Data.Int -import Data.Type.Equality import Array import AST.Env +import AST.Types import AST.Weaken +import CHAD.Types import Data -data Ty - = TNil - | TPair Ty Ty - | TEither Ty Ty - | TArr Nat Ty -- ^ rank, element type - | TScal ScalTy - | TAccum Ty - deriving (Show, Eq, Ord) - -data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool - deriving (Show, Eq, Ord) - -type STy :: Ty -> Type -data STy t where - STNil :: STy TNil - STPair :: STy a -> STy b -> STy (TPair a b) - STEither :: STy a -> STy b -> STy (TEither a b) - STArr :: SNat n -> STy t -> STy (TArr n t) - STScal :: SScalTy t -> STy (TScal t) - STAccum :: STy t -> STy (TAccum t) -deriving instance Show (STy t) - -instance TestEquality STy where - testEquality STNil STNil = Just Refl - testEquality (STPair a b) (STPair a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl - testEquality (STEither a b) (STEither a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl - testEquality (STArr a b) (STArr a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl - testEquality (STScal a) (STScal a') | Just Refl <- testEquality a a' = Just Refl - testEquality (STAccum a) (STAccum a') | Just Refl <- testEquality a a' = Just Refl - testEquality _ _ = Nothing - -data SScalTy t where - STI32 :: SScalTy TI32 - STI64 :: SScalTy TI64 - STF32 :: SScalTy TF32 - STF64 :: SScalTy TF64 - STBool :: SScalTy TBool -deriving instance Show (SScalTy t) - -instance TestEquality SScalTy where - testEquality STI32 STI32 = Just Refl - testEquality STI64 STI64 = Just Refl - testEquality STF32 STF32 = Just Refl - testEquality STF64 STF64 = Just Refl - testEquality STBool STBool = Just Refl - testEquality _ _ = Nothing - -scalRepIsShow :: SScalTy t -> Dict (Show (ScalRep t)) -scalRepIsShow STI32 = Dict -scalRepIsShow STI64 = Dict -scalRepIsShow STF32 = Dict -scalRepIsShow STF64 = Dict -scalRepIsShow STBool = Dict - -type TIx = TScal TI64 - -tIx :: STy TIx -tIx = STScal STI64 - -type family ScalRep t where - ScalRep TI32 = Int32 - ScalRep TI64 = Int64 - ScalRep TF32 = Float - ScalRep TF64 = Double - ScalRep TBool = Bool - -type family ScalIsNumeric t where - ScalIsNumeric TI32 = True - ScalIsNumeric TI64 = True - ScalIsNumeric TF32 = True - ScalIsNumeric TF64 = True - ScalIsNumeric TBool = False - -- | This index is flipped around from the usual direction: the smallest index -- is at the _heart_ of the nesting, not at the outside. The outermost layer -- indexes into the _outer_ dimension of the type @t@. This makes indices into @@ -107,6 +34,7 @@ type family AcIdx t i where AcIdx t Z = TNil AcIdx (TPair a b) (S i) = TEither (AcIdx a i) (AcIdx b i) AcIdx (TEither a b) (S i) = TEither (AcIdx a i) (AcIdx b i) + AcIdx (TMaybe t) (S i) = AcIdx t i AcIdx (TArr Z t) (S i) = AcIdx t i AcIdx (TArr (S n) t) (S i) = TPair TIx (AcIdx (TArr n t) i) @@ -114,12 +42,20 @@ type family AcVal t i where AcVal t Z = t AcVal (TPair a b) (S i) = TEither (AcVal a i) (AcVal b i) AcVal (TEither a b) (S i) = TEither (AcVal a i) (AcVal b i) + AcVal (TMaybe t) (S i) = AcVal t i AcVal (TArr Z t) (S i) = AcVal t i AcVal (TArr (S n) t) (S i) = AcVal (TArr n t) i -- General assumption: head of the list (whatever way it is associated) is the -- inner variable / inner array dimension. In pretty printing, the inner -- variable / inner dimension is printed on the _right_. +-- +-- Note that the 'EZero' and 'EPlus' constructs have typing that depend on the +-- type transformation of CHAD. Indeed, these constructors are created _by_ +-- CHAD, and are intended to be eliminated after simplification, so that the +-- input program as well as the output program do not contain these +-- constructors. +-- TODO: ensure this by a "stage" type parameter. type Expr :: (Ty -> Type) -> [Ty] -> Ty -> Type data Expr x env t where -- lambda calculus @@ -134,6 +70,9 @@ data Expr x env t where EInl :: x (TEither a b) -> STy b -> Expr x env a -> Expr x env (TEither a b) EInr :: x (TEither a b) -> STy a -> Expr x env b -> Expr x env (TEither a b) ECase :: x c -> Expr x env (TEither a b) -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c + ENothing :: x (TMaybe t) -> STy t -> Expr x env (TMaybe t) + EJust :: x (TMaybe t) -> Expr x env t -> Expr x env (TMaybe t) + EMaybe :: x b -> Expr x env b -> Expr x (t : env) b -> Expr x env (TMaybe t) -> Expr x env b -- array operations EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr x env (TArr n (TScal t)) @@ -157,6 +96,10 @@ data Expr x env t where EAccum :: SNat i -> Expr x env (AcIdx t i) -> Expr x env (AcVal t i) -> Expr x env (TAccum t) -> Expr x env TNil -- EAccum1 :: Expr x env TIx -> Expr x env t -> Expr x env (TAccum (S Z) t) -> Expr x env TNil + -- monoidal operations (to be desugared to regular operations after simplification) + EZero :: STy t -> Expr x env (D2 t) + EPlus :: STy t -> Expr x env (D2 t) -> Expr x env (D2 t) -> Expr x env (D2 t) + -- partiality EError :: STy a -> String -> Expr x env a deriving instance (forall ty. Show (x ty)) => Show (Expr x env t) @@ -220,6 +163,9 @@ typeOf = \case EInl _ t2 e -> STEither (typeOf e) t2 EInr _ t1 e -> STEither t1 (typeOf e) ECase _ _ a _ -> typeOf a + ENothing _ t -> STMaybe t + EJust _ e -> STMaybe (typeOf e) + EMaybe _ e _ _ -> typeOf e EConstArr _ n t _ -> STArr n (STScal t) EBuild1 _ _ e -> STArr (SS SZ) (typeOf e) @@ -239,6 +185,9 @@ typeOf = \case EWith e1 e2 -> STPair (typeOf e2) (typeOf e1) EAccum _ _ _ _ -> STNil + EZero t -> d2 t + EPlus t _ _ -> d2 t + EError t _ -> t unSNat :: SNat n -> Nat @@ -250,6 +199,7 @@ unSTy = \case STNil -> TNil STPair a b -> TPair (unSTy a) (unSTy b) STEither a b -> TEither (unSTy a) (unSTy b) + STMaybe t -> TMaybe (unSTy t) STArr n t -> TArr (unSNat n) (unSTy t) STScal t -> TScal (unSScalTy t) STAccum t -> TAccum (unSTy t) @@ -288,6 +238,9 @@ subst' f w = \case EInl x t e -> EInl x t (subst' f w e) EInr x t e -> EInr x t (subst' f w e) ECase x e a b -> ECase x (subst' f w e) (subst' (sinkF f) (WCopy w) a) (subst' (sinkF f) (WCopy w) b) + ENothing x t -> ENothing x t + EJust x e -> EJust x (subst' f w e) + EMaybe x a b e -> EMaybe x (subst' f w a) (subst' (sinkF f) (WCopy w) b) (subst' f w e) EConstArr x n t a -> EConstArr x n t a EBuild1 x a b -> EBuild1 x (subst' f w a) (subst' (sinkF f) (WCopy w) b) EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkF f) (WCopy w) b) @@ -303,6 +256,8 @@ subst' f w = \case EOp x op e -> EOp x op (subst' f w e) EWith e1 e2 -> EWith (subst' f w e1) (subst' (sinkF f) (WCopy w) e2) EAccum i e1 e2 e3 -> EAccum i (subst' f w e1) (subst' f w e2) (subst' f w e3) + EZero t -> EZero t + EPlus t a b -> EPlus t (subst' f w a) (subst' f w b) EError t s -> EError t s where sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a) @@ -339,6 +294,7 @@ class KnownTy t where knownTy :: STy t instance KnownTy TNil where knownTy = STNil instance (KnownTy s, KnownTy t) => KnownTy (TPair s t) where knownTy = STPair knownTy knownTy instance (KnownTy s, KnownTy t) => KnownTy (TEither s t) where knownTy = STEither knownTy knownTy +instance KnownTy t => KnownTy (TMaybe t) where knownTy = STMaybe knownTy instance (KnownNat n, KnownTy t) => KnownTy (TArr n t) where knownTy = STArr knownNat knownTy instance KnownScalTy t => KnownTy (TScal t) where knownTy = STScal knownScalTy instance KnownTy t => KnownTy (TAccum t) where knownTy = STAccum knownTy @@ -351,6 +307,7 @@ styKnown :: STy t -> Dict (KnownTy t) styKnown STNil = Dict styKnown (STPair a b) | Dict <- styKnown a, Dict <- styKnown b = Dict styKnown (STEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict +styKnown (STMaybe t) | Dict <- styKnown t = Dict styKnown (STArr n t) | Dict <- snatKnown n, Dict <- styKnown t = Dict styKnown (STScal t) | Dict <- sscaltyKnown t = Dict styKnown (STAccum t) | Dict <- styKnown t = Dict |