diff options
Diffstat (limited to 'src/AST.hs')
-rw-r--r-- | src/AST.hs | 211 |
1 files changed, 211 insertions, 0 deletions
diff --git a/src/AST.hs b/src/AST.hs new file mode 100644 index 0000000..4d642ba --- /dev/null +++ b/src/AST.hs @@ -0,0 +1,211 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE DeriveFoldable #-} +module AST where + +import Data.Kind (Type) +import Data.Int + + +data Nat = Z | S Nat + deriving (Show, Eq, Ord) + +data SNat n where + SZ :: SNat Z + SS :: SNat n -> SNat (S n) +deriving instance (Show (SNat n)) + +data Vec n t where + VNil :: Vec n t + (:<) :: t -> Vec n t -> Vec (S n) t +deriving instance Show t => Show (Vec n t) +deriving instance Functor (Vec n) +deriving instance Foldable (Vec n) + +data SList f l where + SNil :: SList f '[] + SCons :: f a -> SList f l -> SList f (a : l) +deriving instance (forall a. Show (f a)) => Show (SList f l) + +data Ty + = TNil + | TPair Ty Ty + | TArr Nat Ty -- ^ rank, element type + | TScal ScalTy + | TEVM [Ty] Ty + deriving (Show, Eq, Ord) + +data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool + deriving (Show, Eq, Ord) + +type STy :: Ty -> Type +data STy t where + STNil :: STy TNil + STPair :: STy a -> STy b -> STy (TPair a b) + STArr :: SNat n -> STy t -> STy (TArr n t) + STScal :: SScalTy t -> STy (TScal t) + STEVM :: SList STy env -> STy t -> STy (TEVM env t) +deriving instance Show (STy t) + +data SScalTy t where + STI32 :: SScalTy TI32 + STI64 :: SScalTy TI64 + STF32 :: SScalTy TF32 + STF64 :: SScalTy TF64 + STBool :: SScalTy TBool +deriving instance Show (SScalTy t) + +type TIx = TScal TI64 + +type Idx :: [Ty] -> Ty -> Type +data Idx env t where + IZ :: Idx (t : env) t + IS :: Idx env t -> Idx (a : env) t +deriving instance Show (Idx env t) + +type family ScalRep t where + ScalRep TI32 = Int32 + ScalRep TI64 = Int64 + ScalRep TF32 = Float + ScalRep TF64 = Double + ScalRep TBool = Bool + +type ConsN :: Nat -> a -> [a] -> [a] +type family ConsN n x l where + ConsN Z x l = l + ConsN (S n) x l = x : ConsN n x l + +type Expr :: (Ty -> Type) -> [Ty] -> Ty -> Type +data Expr x env t where + -- lambda calculus + EVar :: x t -> STy t -> Idx env t -> Expr x env t + ELet :: x t -> Expr x env a -> Expr x (a : env) t -> Expr x env t + + -- array operations + EBuild1 :: x (TArr (S Z) t) -> Expr x env TIx -> Expr x (TIx : env) t -> Expr x env (TArr (S Z) t) + EBuild :: x (TArr n t) -> SNat n -> Vec n (Expr x env TIx) -> Expr x (ConsN n TIx env) t -> Expr x env (TArr n t) + EFold1 :: x (TArr n t) -> Expr x (t : t : env) t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t) + + -- expression operations + EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t) + EIdx1 :: x (TArr n t) -> Expr x env (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) + EIdx :: x t -> Expr x env (TArr n t) -> Vec n (Expr x env TIx) -> Expr x env t + EOp :: x t -> SOp a t -> Expr x env a -> Expr x env t + + -- EVM operations + EMOne :: Idx venv t -> Expr x env t -> Expr x env (TEVM venv TNil) +deriving instance (forall ty. Show (x ty)) => Show (Expr x env t) + +type SOp :: Ty -> Ty -> Type +data SOp a t where + OAdd :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) + OMul :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) + ONeg :: SScalTy a -> SOp (TScal a) (TScal a) + OLt :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool) + OLe :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool) + OEq :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool) + ONot :: SOp (TScal TBool) (TScal TBool) +deriving instance Show (SOp a t) + +opt2 :: SOp a t -> STy t +opt2 = \case + OAdd t -> STScal t + OMul t -> STScal t + ONeg t -> STScal t + OLt _ -> STScal STBool + OLe _ -> STScal STBool + OEq _ -> STScal STBool + ONot -> STScal STBool + +typeOf :: Expr x env t -> STy t +typeOf = \case + EVar _ t _ -> t + ELet _ _ e -> typeOf e + EBuild1 _ _ e -> STArr (SS SZ) (typeOf e) + EBuild _ n _ e -> STArr n (typeOf e) + EFold1 _ _ e | STArr (SS n) t <- typeOf e -> STArr n t + + -- expression operations + EConst _ t _ -> STScal t + EIdx1 _ e _ | STArr (SS n) t <- typeOf e -> STArr n t + EIdx _ e _ | STArr _ t <- typeOf e -> t + EOp _ op _ -> opt2 op + + EMOne _ _ -> STEVM _ STNil + +unSNat :: SNat n -> Nat +unSNat SZ = Z +unSNat (SS n) = S (unSNat n) + +unSTy :: STy t -> Ty +unSTy = \case + STNil -> TNil + STPair a b -> TPair (unSTy a) (unSTy b) + STArr n t -> TArr (unSNat n) (unSTy t) + STScal t -> TScal (unSScalTy t) + STEVM l t -> TEVM (unSList l) (unSTy t) + +unSList :: SList STy env -> [Ty] +unSList SNil = [] +unSList (SCons t l) = unSTy t : unSList l + +unSScalTy :: SScalTy t -> ScalTy +unSScalTy = \case + STI32 -> TI32 + STI64 -> TI64 + STF32 -> TF32 + STF64 -> TF64 + STBool -> TBool + +fromNat :: Nat -> Int +fromNat Z = 0 +fromNat (S n) = succ (fromNat n) + +data env :> env' where + WId :: env :> env + WSink :: env :> (t : env) + WCopy :: env :> env' -> (t : env) :> (t : env') + WThen :: env1 :> env2 -> env2 :> env3 -> env1 :> env3 +deriving instance Show (env :> env') + +(.>) :: env2 :> env3 -> env1 :> env2 -> env1 :> env3 +(.>) = flip WThen + +infixr @> +(@>) :: env :> env' -> Idx env t -> Idx env' t +WId @> i = i +WSink @> i = IS i +WCopy _ @> IZ = IZ +WCopy w @> (IS i) = IS (w @> i) +WThen w1 w2 @> i = w2 @> w1 @> i + +weakenExpr :: env :> env' -> Expr x env t -> Expr x env' t +weakenExpr w = \case + EVar x t i -> EVar x t (w @> i) + ELet x rhs body -> ELet x (weakenExpr w rhs) (weakenExpr (WCopy w) body) + EBuild1 x e1 e2 -> EBuild1 x (weakenExpr w e1) (weakenExpr (WCopy w) e2) + EBuild x n es e -> EBuild x n (weakenVec w es) (weakenExpr (wcopyN n w) e) + EFold1 x e1 e2 -> EFold1 x (weakenExpr (WCopy (WCopy w)) e1) (weakenExpr w e2) + EConst x t v -> EConst x t v + EIdx1 x e1 e2 -> EIdx1 x (weakenExpr w e1) (weakenExpr w e2) + EIdx x e1 es -> EIdx x (weakenExpr w e1) (weakenVec w es) + EOp x op e -> EOp x op (weakenExpr w e) + EMOne i e -> EMOne i (weakenExpr w e) + +wcopyN :: SNat n -> env :> env' -> ConsN n TIx env :> ConsN n TIx env' +wcopyN SZ w = w +wcopyN (SS n) w = WCopy (wcopyN n w) + +weakenVec :: (env :> env') -> Vec n (Expr x env TIx) -> Vec n (Expr x env' TIx) +weakenVec _ VNil = VNil +weakenVec w (e :< v) = weakenExpr w e :< weakenVec w v |