diff options
Diffstat (limited to 'src/Language/AST.hs')
-rw-r--r-- | src/Language/AST.hs | 196 |
1 files changed, 101 insertions, 95 deletions
diff --git a/src/Language/AST.hs b/src/Language/AST.hs index f31f249..511723a 100644 --- a/src/Language/AST.hs +++ b/src/Language/AST.hs @@ -1,143 +1,149 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE MultiParamTypeClasses #-} -{-# LANGUAGE TypeApplications #-} module Language.AST where -import Data.Proxy +import Data.Kind (Type) import Data.Type.Equality import GHC.OverloadedLabels -import GHC.TypeLits (symbolVal, KnownSymbol) +import GHC.TypeLits (Symbol, SSymbol, symbolSing, KnownSymbol, TypeError, ErrorMessage(Text)) import AST import Data -data SExpr t where +type NExpr :: [(Symbol, Ty)] -> Ty -> Type +data NExpr env t where -- lambda calculus - SEVar :: Var t -> SExpr t - SELet :: SExpr a -> Lambda a (SExpr t) -> SExpr t + NEVar :: Lookup name env ~ t => Var name t -> NExpr env t + NELet :: Var name a -> NExpr env a -> NExpr ('(name, a) : env) t -> NExpr env t -- base types - SEPair :: SExpr a -> SExpr b -> SExpr (TPair a b) - SEFst :: SExpr (TPair a b) -> SExpr a - SESnd :: SExpr (TPair a b) -> SExpr b - SENil :: SExpr TNil - SEInl :: STy b -> SExpr a -> SExpr (TEither a b) - SEInr :: STy a -> SExpr b -> SExpr (TEither a b) - SECase :: SExpr (TEither a b) -> Lambda a (SExpr c) -> Lambda b (SExpr c) -> SExpr c + NEPair :: NExpr env a -> NExpr env b -> NExpr env (TPair a b) + NEFst :: NExpr env (TPair a b) -> NExpr env a + NESnd :: NExpr env (TPair a b) -> NExpr env b + NENil :: NExpr env TNil + NEInl :: STy b -> NExpr env a -> NExpr env (TEither a b) + NEInr :: STy a -> NExpr env b -> NExpr env (TEither a b) + NECase :: NExpr env (TEither a b) -> Var name1 a -> NExpr ('(name1, a) : env) c -> Var name2 b -> NExpr ('(name2, b) : env) c -> NExpr env c -- array operations - SEBuild1 :: SExpr TIx -> Lambda TIx (SExpr t) -> SExpr (TArr (S Z) t) - SEBuild :: SNat n -> SExpr (Tup (Replicate n TIx)) -> Lambda (Tup (Replicate n TIx)) (SExpr t) -> SExpr (TArr n t) - SEFold1 :: Lambda t (Lambda t (SExpr t)) -> SExpr (TArr (S n) t) -> SExpr (TArr n t) - SEUnit :: SExpr t -> SExpr (TArr Z t) + NEBuild1 :: NExpr env TIx -> Var name TIx -> NExpr ('(name, TIx) : env) t -> NExpr env (TArr (S Z) t) + NEBuild :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> Var name (Tup (Replicate n TIx)) -> NExpr ('(name, Tup (Replicate n TIx)) : env) t -> NExpr env (TArr n t) + NEFold1 :: Var name1 t -> Var name2 t -> NExpr ('(name2, t) : '(name1, t) : env) t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) + NEUnit :: NExpr env t -> NExpr env (TArr Z t) -- expression operations - SEConst :: Show (ScalRep t) => SScalTy t -> ScalRep t -> SExpr (TScal t) - SEIdx0 :: SExpr (TArr Z t) -> SExpr t - SEIdx1 :: SExpr (TArr (S n) t) -> SExpr TIx -> SExpr (TArr n t) - SEIdx :: SNat n -> SExpr (TArr n t) -> SExpr (Tup (Replicate n TIx)) -> SExpr t - SEShape :: SExpr (TArr n t) -> SExpr (Tup (Replicate n TIx)) - SEOp :: SOp a t -> SExpr a -> SExpr t + NEConst :: Show (ScalRep t) => SScalTy t -> ScalRep t -> NExpr env (TScal t) + NEIdx0 :: NExpr env (TArr Z t) -> NExpr env t + NEIdx1 :: NExpr env (TArr (S n) t) -> NExpr env TIx -> NExpr env (TArr n t) + NEIdx :: SNat n -> NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) -> NExpr env t + NEShape :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) + NEOp :: SOp a t -> NExpr env a -> NExpr env t -- partiality - SEError :: STy a -> String -> SExpr a -deriving instance Show (SExpr t) + NEError :: STy a -> String -> NExpr env a +deriving instance Show (NExpr env t) -data Var a = Var (STy a) String - deriving (Show) +type family Lookup name env where + Lookup "_" _ = TypeError (Text "Attempt to use variable with name '_'") + Lookup name ('(name, t) : env) = t + Lookup name (_ : env) = Lookup name env -data Lambda a b = Lambda (Var a) b +data Var name t = Var (SSymbol name) (STy t) deriving (Show) -mkLambda :: KnownTy a => String -> (SExpr a -> f t) -> Lambda a (f t) -mkLambda name f = mkLambda' (Var knownTy name) f - -mkLambda' :: Var a -> (SExpr a -> f t) -> Lambda a (f t) -mkLambda' var f = Lambda var (f (SEVar var)) - -mkLambda2 :: (KnownTy a, KnownTy b) - => String -> String -> (SExpr a -> SExpr b -> f t) -> Lambda a (Lambda b (f t)) -mkLambda2 name1 name2 f = mkLambda2' (Var knownTy name1) (Var knownTy name2) f - -mkLambda2' :: Var a -> Var b -> (SExpr a -> SExpr b -> f t) -> Lambda a (Lambda b (f t)) -mkLambda2' var1 var2 f = Lambda var1 (Lambda var2 (f (SEVar var1) (SEVar var2))) - -instance (t ~ TScal st, KnownScalTy st, Num (ScalRep st)) => Num (SExpr t) where - a + b = SEOp (OAdd knownScalTy) (SEPair a b) - a * b = SEOp (OMul knownScalTy) (SEPair a b) - negate e = SEOp (ONeg knownScalTy) e - abs = error "abs undefined for SExpr" - signum = error "signum undefined for SExpr" +instance (t ~ TScal st, KnownScalTy st, Num (ScalRep st)) => Num (NExpr env t) where + a + b = NEOp (OAdd knownScalTy) (NEPair a b) + a * b = NEOp (OMul knownScalTy) (NEPair a b) + negate e = NEOp (ONeg knownScalTy) e + abs = error "abs undefined for NExpr" + signum = error "signum undefined for NExpr" fromInteger = let ty = knownScalTy in case scalRepIsShow ty of - Dict -> SEConst ty . fromInteger + Dict -> NEConst ty . fromInteger + +instance (KnownTy t, KnownSymbol name, name ~ n') => IsLabel name (Var n' t) where + fromLabel = Var symbolSing knownTy + +instance (KnownTy t, KnownSymbol name, Lookup name env ~ t) => IsLabel name (NExpr env t) where + fromLabel = NEVar (fromLabel @name) + +data NEnv env where + NTop :: NEnv '[] + NPush :: NEnv env -> Var name t -> NEnv ('(name, t) : env) -instance (KnownTy t, KnownSymbol name) => IsLabel name (Var t) where - fromLabel = Var knownTy (symbolVal (Proxy @name)) +data NFun env env' t where + NLam :: Var name a -> NFun ('(name, a) : env) env' t -> NFun env env' t + NBody :: NExpr env t -> NFun env env t -instance (KnownTy t, KnownSymbol name) => IsLabel name (SExpr t) where - fromLabel = SEVar (fromLabel @name) +type family UnName env where + UnName '[] = '[] + UnName ('(name, t) : env) = t : UnName env -data SFun args t = SFun (SList Var args) (SExpr t) +fromNamed :: NFun '[] env t -> Ex (UnName env) t +fromNamed = fromNamedFun NTop -scopeCheck :: SFun env t -> Ex env t -scopeCheck (SFun args e) = scopeCheckExpr args e +fromNamedFun :: NEnv env -> NFun env env' t -> Ex (UnName env') t +fromNamedFun env (NLam var fun) = fromNamedFun (env `NPush` var) fun +fromNamedFun env (NBody e) = fromNamedExpr env e -scopeCheckExpr :: forall env t. SList Var env -> SExpr t -> Ex env t -scopeCheckExpr val = \case - SEVar tag@(Var ty _) - | Just idx <- find tag val -> EVar ext ty idx +fromNamedExpr :: forall env t. NEnv env -> NExpr env t -> Ex (UnName env) t +fromNamedExpr val = \case + NEVar var@(Var _ ty) + | Just idx <- find var val -> EVar ext ty idx | otherwise -> error "Variable out of scope in conversion from surface \ \expression to De Bruijn expression" - SELet a b -> ELet ext (go a) (lambda val b) - - SEPair a b -> EPair ext (go a) (go b) - SEFst e -> EFst ext (go e) - SESnd e -> ESnd ext (go e) - SENil -> ENil ext - SEInl t e -> EInl ext t (go e) - SEInr t e -> EInr ext t (go e) - SECase e a b -> ECase ext (go e) (lambda val a) (lambda val b) - - SEBuild1 a b -> EBuild1 ext (go a) (lambda val b) - SEBuild n a b -> EBuild ext n (go a) (lambda val b) - SEFold1 a b -> EFold1 ext (lambda2 val a) (go b) - SEUnit e -> EUnit ext (go e) - - SEConst t x -> EConst ext t x - SEIdx0 e -> EIdx0 ext (go e) - SEIdx1 a b -> EIdx1 ext (go a) (go b) - SEIdx n a b -> EIdx ext n (go a) (go b) - SEShape e -> EShape ext (go e) - SEOp op e -> EOp ext op (go e) - - SEError t s -> EError t s + NELet n a b -> ELet ext (go a) (lambda val n b) + + NEPair a b -> EPair ext (go a) (go b) + NEFst e -> EFst ext (go e) + NESnd e -> ESnd ext (go e) + NENil -> ENil ext + NEInl t e -> EInl ext t (go e) + NEInr t e -> EInr ext t (go e) + NECase e n1 a n2 b -> ECase ext (go e) (lambda val n1 a) (lambda val n2 b) + + NEBuild1 a n b -> EBuild1 ext (go a) (lambda val n b) + NEBuild k a n b -> EBuild ext k (go a) (lambda val n b) + NEFold1 n1 n2 a b -> EFold1 ext (lambda2 val n1 n2 a) (go b) + NEUnit e -> EUnit ext (go e) + + NEConst t x -> EConst ext t x + NEIdx0 e -> EIdx0 ext (go e) + NEIdx1 a b -> EIdx1 ext (go a) (go b) + NEIdx n a b -> EIdx ext n (go a) (go b) + NEShape e -> EShape ext (go e) + NEOp op e -> EOp ext op (go e) + + NEError t s -> EError t s where - go :: SExpr t' -> Ex env t' - go = scopeCheckExpr val + go :: NExpr env t' -> Ex (UnName env) t' + go = fromNamedExpr val - find :: Var t' -> SList Var env' -> Maybe (Idx env' t') - find _ SNil = Nothing - find tag@(Var ty s) (Var ty' s' `SCons` val') - | s == s' + find :: Var name t' -> NEnv env' -> Maybe (Idx (UnName env') t') + find _ NTop = Nothing + find var@(Var s ty) (val' `NPush` Var s' ty') + | Just Refl <- testEquality s s' , Just Refl <- testEquality ty ty' = Just IZ | otherwise - = IS <$> find tag val' + = IS <$> find var val' - lambda :: SList Var env' -> Lambda a (SExpr b) -> Ex (a : env') b - lambda val' (Lambda tag e) = scopeCheckExpr (tag `SCons` val') e + lambda :: NEnv env' -> Var name a -> NExpr ('(name, a) : env') b -> Ex (a : UnName env') b + lambda val' var e = fromNamedExpr (val' `NPush` var) e - lambda2 :: SList Var env' -> Lambda a (Lambda b (SExpr c)) -> Ex (a : b : env') c - lambda2 val' (Lambda tag (Lambda tag' e)) = scopeCheckExpr (tag `SCons` tag' `SCons` val') e + lambda2 :: NEnv env' -> Var name1 a -> Var name2 b -> NExpr ('(name2, b) : '(name1, a) : env') c -> Ex (b : a : UnName env') c + lambda2 val' var1 var2 e = fromNamedExpr (val' `NPush` var1 `NPush` var2) e |