From 342c213f3caddd64db0eac5ae146912e00378371 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Sun, 26 Jul 2020 23:02:09 +0200 Subject: WIP refactor and union types, type variables --- ast/CC/AST/Source.hs | 37 ++++++-- ast/CC/AST/Typed.hs | 48 ++++++++-- ast/CC/Context.hs | 2 +- backend/CC/Backend/Dumb.hs | 24 +++-- compcomp.cabal | 3 +- parser/CC/Parser.hs | 97 +++++++++++++++----- typecheck/CC/Typecheck.hs | 182 +++++++++++++++++++++++-------------- typecheck/CC/Typecheck/Typedefs.hs | 50 ++++++++++ typecheck/CC/Typecheck/Types.hs | 103 +++++++++++++++++++++ 9 files changed, 435 insertions(+), 111 deletions(-) create mode 100644 typecheck/CC/Typecheck/Typedefs.hs create mode 100644 typecheck/CC/Typecheck/Types.hs diff --git a/ast/CC/AST/Source.hs b/ast/CC/AST/Source.hs index e648759..e64e058 100644 --- a/ast/CC/AST/Source.hs +++ b/ast/CC/AST/Source.hs @@ -3,6 +3,8 @@ module CC.AST.Source( module CC.Types ) where +import qualified Data.Set as Set +import Data.Set (Set) import Data.List import CC.Pretty @@ -12,19 +14,32 @@ import CC.Types data Program = Program [Decl] deriving (Show, Read) -data Decl = Def Def -- import? +data Decl = DeclFunc FuncDef + | DeclType TypeDef + | DeclAlias AliasDef deriving (Show, Read) -data Def = Function (Maybe Type) - (Name, SourceRange) - [(Name, SourceRange)] - Expr +data FuncDef = + FuncDef (Maybe Type) + (Name, SourceRange) + [(Name, SourceRange)] + Expr + deriving (Show, Read) + +-- Named type with named arguments +data TypeDef = TypeDef (Name, SourceRange) [(Name, SourceRange)] Type + deriving (Show, Read) + +data AliasDef = AliasDef (Name, SourceRange) [(Name, SourceRange)] Type deriving (Show, Read) data Type = TFun Type Type | TInt | TTup [Type] - deriving (Show, Read) + | TNamed Name [Type] -- named type with type arguments + | TUnion (Set Type) + | TyVar Name + deriving (Eq, Ord, Show, Read) data Expr = Lam SourceRange [(Name, SourceRange)] Expr | Let SourceRange (Name, SourceRange) Expr Expr @@ -32,15 +47,24 @@ data Expr = Lam SourceRange [(Name, SourceRange)] Expr | Int SourceRange Int | Tup SourceRange [Expr] | Var SourceRange Name + | Constr SourceRange Name -- type constructor | Annot SourceRange Expr Type deriving (Show, Read) +instance Semigroup Program where Program p1 <> Program p2 = Program (p1 <> p2) +instance Monoid Program where mempty = Program mempty + instance Pretty Type where prettyPrec _ TInt = "Int" prettyPrec p (TFun a b) = precParens p 2 (prettyPrec 3 a ++ " -> " ++ prettyPrec 2 b) prettyPrec _ (TTup ts) = "(" ++ intercalate ", " (map pretty ts) ++ ")" + prettyPrec _ (TNamed n ts) = + n ++ "[" ++ intercalate ", " (map pretty ts) ++ "]" + prettyPrec _ (TUnion ts) = + "{" ++ intercalate " | " (map pretty (Set.toList ts)) ++ "}" + prettyPrec _ (TyVar n) = "<" ++ n ++ ">" instance HasRange Expr where range (Lam sr _ _) = sr @@ -49,4 +73,5 @@ instance HasRange Expr where range (Int sr _) = sr range (Tup sr _) = sr range (Var sr _) = sr + range (Constr sr _) = sr range (Annot sr _ _) = sr diff --git a/ast/CC/AST/Typed.hs b/ast/CC/AST/Typed.hs index b12b30a..cf67575 100644 --- a/ast/CC/AST/Typed.hs +++ b/ast/CC/AST/Typed.hs @@ -3,23 +3,36 @@ module CC.AST.Typed( module CC.Types ) where +import qualified Data.Map.Strict as Map +import Data.Map.Strict (Map) +import qualified Data.Set as Set +import Data.Set (Set) import Data.List import CC.Pretty import CC.Types -data Program = Program [Def] +data Program = Program [Def] (Map Name TypeDef) deriving (Show, Read) data Def = Def Name Expr deriving (Show, Read) +-- Named type with type parameters +data TypeDef = TypeDef Name [Int] Type + deriving (Show, Read) + data Type = TFun Type Type | TInt | TTup [Type] - | TyVar Int - deriving (Show, Read) + | TNamed Name [Type] -- named type with type arguments + | TUnion (Set Type) + | TyVar Rigidity Int + deriving (Eq, Ord, Show, Read) + +data Rigidity = Rigid | Instantiable + deriving (Eq, Ord, Show, Read) data TypeScheme = TypeScheme [Int] Type deriving (Show, Read) @@ -30,6 +43,7 @@ data Expr = Lam Type Occ Expr | Int Int | Tup [Expr] | Var Occ + | Constr Type Name -- Type is 'argument -> TNamed' deriving (Show, Read) data Occ = Occ Name Type @@ -42,6 +56,7 @@ exprType (Call typ _ _) = typ exprType (Int _) = TInt exprType (Tup es) = TTup (map exprType es) exprType (Var (Occ _ typ)) = typ +exprType (Constr typ _) = typ instance Pretty Type where prettyPrec _ TInt = "Int" @@ -49,13 +64,18 @@ instance Pretty Type where precParens p 2 (prettyPrec 3 a ++ " -> " ++ prettyPrec 2 b) prettyPrec _ (TTup ts) = "(" ++ intercalate ", " (map pretty ts) ++ ")" - prettyPrec _ (TyVar i) = 't' : show i + prettyPrec _ (TNamed n ts) = + n ++ "[" ++ intercalate ", " (map pretty ts) ++ "]" + prettyPrec _ (TUnion ts) = + "{ " ++ intercalate " | " (map pretty (Set.toList ts)) ++ " }" + prettyPrec _ (TyVar Rigid i) = 't' : show i ++ "R" + prettyPrec _ (TyVar Instantiable i) = 't' : show i instance Pretty TypeScheme where prettyPrec p (TypeScheme bnds ty) = precParens p 2 - ("forall " ++ intercalate " " (map (pretty . TyVar) bnds) ++ ". " ++ - prettyPrec 2 ty) + ("forall " ++ intercalate " " (map (pretty . TyVar Instantiable) bnds) ++ + ". " ++ prettyPrec 2 ty) instance Pretty Expr where prettyPrec p (Lam ty (Occ n t) e) = @@ -75,10 +95,22 @@ instance Pretty Expr where prettyPrec _ (Tup es) = "(" ++ intercalate ", " (map pretty es) ++ ")" prettyPrec p (Var (Occ n t)) = precParens p 2 $ - show n ++ " :: " ++ pretty t + n ++ " :: " ++ pretty t + prettyPrec p (Constr t n) = + precParens p 2 $ + n ++ " :: " ++ pretty t instance Pretty Def where pretty (Def n e) = n ++ " = " ++ pretty e +instance Pretty TypeDef where + pretty (TypeDef n [] t) = + "type " ++ n ++ " = " ++ pretty t + pretty (TypeDef n vs t) = + "type " ++ n ++ " " ++ intercalate " " (map (pretty . TyVar Instantiable) vs) ++ + " = " ++ pretty t + instance Pretty Program where - pretty (Program defs) = concatMap ((++ "\n") . pretty) defs + pretty (Program defs tdefs) = + concatMap (++ "\n") $ + map pretty (Map.elems tdefs) ++ map pretty defs diff --git a/ast/CC/Context.hs b/ast/CC/Context.hs index 68378d7..acfb614 100644 --- a/ast/CC/Context.hs +++ b/ast/CC/Context.hs @@ -9,4 +9,4 @@ import CC.AST.Typed data Context = Context FilePath Builtins -- | Information about builtins supported by the enabled backend -data Builtins = Builtins (Map Name Type) +data Builtins = Builtins (Map Name Type) String diff --git a/backend/CC/Backend/Dumb.hs b/backend/CC/Backend/Dumb.hs index 3210dab..822fb84 100644 --- a/backend/CC/Backend/Dumb.hs +++ b/backend/CC/Backend/Dumb.hs @@ -7,10 +7,20 @@ import CC.Context builtins :: Builtins -builtins = Builtins . Map.fromList $ - [ ("print", TFun TInt (TTup [])) - , ("fst", TFun (TTup [TyVar 1, TyVar 2]) (TyVar 1)) - , ("snd", TFun (TTup [TyVar 1, TyVar 2]) (TyVar 2)) - , ("_add", TFun TInt (TFun TInt TInt)) - , ("_sub", TFun TInt (TFun TInt TInt)) - , ("_mul", TFun TInt (TFun TInt TInt)) ] +builtins = + let values = Map.fromList + [ ("print", TInt ==> TTup []) + , ("fst", TTup [t1, t2] ==> t1) + , ("snd", TTup [t1, t2] ==> t2) + , ("_add", TInt ==> TInt ==> TInt) + , ("_sub", TInt ==> TInt ==> TInt) + , ("_mul", TInt ==> TInt ==> TInt) ] + prelude = "type Nil = ()\n\ + \type Cons a = (a, List a)\n\ + \alias List a = { Nil | Cons a }\n" + in Builtins values prelude + where + t1 = TyVar Instantiable 1 + t2 = TyVar Instantiable 2 + infixr ==> + (==>) = TFun diff --git a/compcomp.cabal b/compcomp.cabal index 9e941aa..4b06f4b 100644 --- a/compcomp.cabal +++ b/compcomp.cabal @@ -20,7 +20,7 @@ executable compcomp library cc-parser import: deps hs-source-dirs: parser - build-depends: parsec, cc-ast, cc-utils + build-depends: containers, parsec, cc-ast, cc-utils exposed-modules: CC.Parser library cc-typecheck @@ -28,6 +28,7 @@ library cc-typecheck hs-source-dirs: typecheck build-depends: containers, mtl, cc-ast, cc-utils exposed-modules: CC.Typecheck + other-modules: CC.Typecheck.Typedefs, CC.Typecheck.Types library cc-backend-dumb import: deps diff --git a/parser/CC/Parser.hs b/parser/CC/Parser.hs index 2d2c4b7..66bc6cf 100644 --- a/parser/CC/Parser.hs +++ b/parser/CC/Parser.hs @@ -1,6 +1,7 @@ module CC.Parser(runPass, parseProgram) where import Control.Monad +import qualified Data.Set as Set import Text.Parsec hiding (SourcePos, getPosition, token) import qualified Text.Parsec @@ -12,7 +13,10 @@ import CC.Pretty type Parser a = Parsec String () a runPass :: Context -> RawString -> Either (PrettyShow ParseError) Program -runPass (Context path _) (RawString src) = fmapLeft PrettyShow (parseProgram path src) +runPass (Context path (Builtins _ prelude)) (RawString src) = do + prog1 <- fmapLeft PrettyShow (parseProgram "" prelude) + prog2 <- fmapLeft PrettyShow (parseProgram path src) + return (prog1 <> prog2) where fmapLeft f (Left x) = Left (f x) fmapLeft _ (Right x) = Right x @@ -27,32 +31,63 @@ pProgram = do return prog pDecl :: Parser Decl -pDecl = Def <$> pDef +pDecl = choice + [ DeclType <$> pDeclType + , DeclAlias <$> pDeclAlias + , DeclFunc <$> pDeclFunc ] -pDef :: Parser Def -pDef = do +pDeclFunc :: Parser FuncDef +pDeclFunc = do func <- try $ do emptyLines - name <- pName0 "declaration head name" + name <- pName0 LowerCase "declaration head name" return name mtyp <- optionMaybe $ do symbol "::" typ <- pType whitespace >> void newline emptyLines - func' <- fst <$> pName0 + (func', _) <- pName0 LowerCase guard (fst func == func') return typ - args <- many pName + args <- many (pName LowerCase) symbol "=" expr <- pExpr - return (Function mtyp func args expr) + return (FuncDef mtyp func args expr) + +pDeclType :: Parser TypeDef +pDeclType = (\(n, a, t) -> TypeDef n a t) <$> pTypedefLike "type" + +pDeclAlias :: Parser AliasDef +pDeclAlias = (\(n, a, t) -> AliasDef n a t) <$> pTypedefLike "alias" + +pTypedefLike :: String -> Parser ((Name, SourceRange), [(Name, SourceRange)], Type) +pTypedefLike keyword = do + try (emptyLines >> string keyword >> whitespace1) + name <- pName0 UpperCase + args <- many (pName LowerCase) + symbol "=" + ty <- pType + return (name, args, ty) pType :: Parser Type -pType = chainr1 pTypeAtom (symbol "->" >> return TFun) +pType = chainr1 pTypeTerm (symbol "->" >> return TFun) + +pTypeTerm :: Parser Type +pTypeTerm = pTypeAtom <|> pTypeCall pTypeAtom :: Parser Type -pTypeAtom = (wordToken "Int" >> return TInt) <|> pParenType +pTypeAtom = choice + [ wordToken "Int" >> return TInt + , TyVar . fst <$> pName LowerCase + , pParenType + , pUnionType ] + +pTypeCall :: Parser Type +pTypeCall = do + (constr, _) <- pName UpperCase + args <- many pTypeAtom + return (TNamed constr args) pParenType :: Parser Type pParenType = do @@ -63,6 +98,19 @@ pParenType = do [ty] -> return ty _ -> return (TTup tys) +pUnionType :: Parser Type +pUnionType = do + token "{" + tys <- pType `sepBy` token "|" + token "}" + case tys of + [] -> unexpected "empty union type" + [ty] -> return ty + _ -> let tyset = Set.fromList tys + in if Set.size tyset == length tys + then return (TUnion tyset) + else unexpected "duplicate types in union" + pExpr :: Parser Expr pExpr = label (pLam <|> pLet <|> pCall) "expression" where @@ -84,7 +132,7 @@ pExpr = label (pLam <|> pLet <|> pCall) "expression" p <- getPosition void (char '\\') return p - names <- many1 pName + names <- many1 (pName LowerCase) symbol "->" body <- pExpr p2 <- getPosition @@ -100,7 +148,7 @@ pExpr = label (pLam <|> pLet <|> pCall) "expression" where afterKeyword p1 = do whitespace1 - lhs <- pName0 + lhs <- pName0 LowerCase symbol "=" rhs <- pExpr let fullRange rest = mergeRange (SourceRange p1 p1) (range rest) @@ -118,7 +166,8 @@ pExpr = label (pLam <|> pLet <|> pCall) "expression" pExprAtom :: Parser Expr pExprAtom = choice [ uncurry (flip Int) <$> pInt - , uncurry (flip Var) <$> pName + , uncurry (flip Var) <$> pName LowerCase + , uncurry (flip Constr) <$> pName UpperCase , pParenExpr ] pParenExpr :: Parser Expr @@ -141,11 +190,13 @@ pInt = try (whitespace >> pInt0) p2 <- getPosition return (num, SourceRange p1 p2) -pName0 :: Parser (Name, SourceRange) -pName0 = do +data Case = LowerCase | UpperCase + +pName0 :: Case -> Parser (Name, SourceRange) +pName0 vcase = do p1 <- getPosition s <- try $ do - c <- pWordFirstChar + c <- pWordFirstChar vcase cs <- many pWordMidChar let s = c : cs guard (s `notElem` ["let", "in"]) @@ -154,14 +205,18 @@ pName0 = do notFollowedBy pWordMidChar return (s, SourceRange p1 p2) -pWordFirstChar :: Parser Char -pWordFirstChar = letter <|> oneOf "_$#!" +pWordFirstChar :: Case -> Parser Char +pWordFirstChar LowerCase = lower <|> oneOf wordSymbols +pWordFirstChar UpperCase = upper <|> oneOf wordSymbols pWordMidChar :: Parser Char -pWordMidChar = alphaNum <|> oneOf "_$#!" +pWordMidChar = alphaNum <|> oneOf wordSymbols + +wordSymbols :: [Char] +wordSymbols = "_$#!" -pName :: Parser (Name, SourceRange) -pName = try (whitespace >> pName0) +pName :: Case -> Parser (Name, SourceRange) +pName vcase = try (whitespace >> pName0 vcase) symbol :: String -> Parser () symbol s = token s >> (eof <|> void space <|> void (oneOf "(){}[]")) diff --git a/typecheck/CC/Typecheck.hs b/typecheck/CC/Typecheck.hs index f61103e..824a714 100644 --- a/typecheck/CC/Typecheck.hs +++ b/typecheck/CC/Typecheck.hs @@ -1,4 +1,5 @@ {-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeSynonymInstances #-} module CC.Typecheck(runPass) where @@ -10,59 +11,47 @@ import Data.Maybe import qualified Data.Set as Set import Data.Set (Set) +import Debug.Trace + import qualified CC.AST.Source as S import qualified CC.AST.Typed as T import CC.Context -import CC.Pretty import CC.Types +import CC.Typecheck.Typedefs +import CC.Typecheck.Types -- Inspiration: https://github.com/kritzcreek/fby19 -data TCError = TypeError SourceRange T.Type T.Type - | RefError SourceRange Name +data Env = + Env (Map Name T.TypeScheme) -- Definitions in scope + (Map Name T.TypeDef) -- Type definitions + (Map Name S.AliasDef) -- Type aliases deriving (Show) -instance Pretty TCError where - pretty (TypeError sr real expect) = - "Type error: Expression at " ++ pretty sr ++ - " has type " ++ pretty real ++ - ", but should have type " ++ pretty expect - pretty (RefError sr name) = - "Reference error: Variable '" ++ name ++ "' out of scope at " ++ pretty sr - -type TM a = ExceptT TCError (State Int) a - -genId :: TM Int -genId = state (\idval -> (idval, idval + 1)) - -genTyVar :: TM T.Type -genTyVar = T.TyVar <$> genId - -runTM :: TM a -> Either TCError a -runTM m = evalState (runExceptT m) 1 - - -newtype Env = Env (Map Name T.TypeScheme) - newtype Subst = Subst (Map Int T.Type) class FreeTypeVars a where - freeTypeVars :: a -> Set Int + -- Free instantiable type variables + freeInstTypeVars :: a -> Set Int instance FreeTypeVars T.Type where - freeTypeVars (T.TFun t1 t2) = freeTypeVars t1 <> freeTypeVars t2 - freeTypeVars T.TInt = mempty - freeTypeVars (T.TTup ts) = Set.unions (map freeTypeVars ts) - freeTypeVars (T.TyVar var) = Set.singleton var + freeInstTypeVars (T.TFun t1 t2) = freeInstTypeVars t1 <> freeInstTypeVars t2 + freeInstTypeVars T.TInt = mempty + freeInstTypeVars (T.TTup ts) = Set.unions (map freeInstTypeVars ts) + freeInstTypeVars (T.TNamed _ ts) = Set.unions (map freeInstTypeVars ts) + freeInstTypeVars (T.TUnion ts) = Set.unions (map freeInstTypeVars (Set.toList ts)) + freeInstTypeVars (T.TyVar T.Instantiable var) = Set.singleton var + freeInstTypeVars (T.TyVar T.Rigid _) = mempty instance FreeTypeVars T.TypeScheme where - freeTypeVars (T.TypeScheme bnds ty) = foldr Set.delete (freeTypeVars ty) bnds + freeInstTypeVars (T.TypeScheme bnds ty) = + foldr Set.delete (freeInstTypeVars ty) bnds instance FreeTypeVars Env where - freeTypeVars (Env mp) = foldMap freeTypeVars (Map.elems mp) + freeInstTypeVars (Env mp _ _) = foldMap freeInstTypeVars (Map.elems mp) infixr >>! @@ -74,14 +63,20 @@ instance Substitute T.Type where T.TFun t1 t2 -> T.TFun (theta >>! t1) (theta >>! t2) T.TInt -> T.TInt T.TTup ts -> T.TTup (map (theta >>!) ts) - T.TyVar i -> fromMaybe ty (Map.lookup i mp) + T.TNamed n ts -> T.TNamed n (map (theta >>!) ts) + T.TUnion ts -> T.TUnion (Set.map (theta >>!) ts) + T.TyVar T.Instantiable i -> fromMaybe ty (Map.lookup i mp) + T.TyVar T.Rigid i + | i `Map.member` mp -> error "Attempt to substitute a rigid type variable" + | otherwise -> ty instance Substitute T.TypeScheme where Subst mp >>! T.TypeScheme bnds ty = T.TypeScheme bnds (Subst (foldr Map.delete mp bnds) >>! ty) instance Substitute Env where - theta >>! Env mp = Env (Map.map (theta >>!) mp) + theta >>! Env mp tdefs aliases = + Env (Map.map (theta >>!) mp) tdefs aliases -- TODO: make this instance unnecessary instance Substitute T.Expr where @@ -94,6 +89,7 @@ instance Substitute T.Expr where _ >>! expr@(T.Int _) = expr theta >>! T.Tup es = T.Tup (map (theta >>!) es) theta >>! T.Var (T.Occ name ty) = T.Var (T.Occ name (theta >>! ty)) + theta >>! T.Constr ty n = T.Constr (theta >>! ty) n instance Semigroup Subst where @@ -103,20 +99,43 @@ instance Monoid Subst where mempty = Subst mempty emptyEnv :: Env -emptyEnv = Env mempty +emptyEnv = Env mempty mempty mempty + +envAddDef :: Name -> T.TypeScheme -> Env -> Env +envAddDef name sty (Env mp tmp aliases) + | name `Map.member` mp = error "envAddDef on name already in environment" + | otherwise = + Env (Map.insert name sty mp) tmp aliases + +envFindDef :: Name -> Env -> Maybe T.TypeScheme +envFindDef name (Env mp _ _) = Map.lookup name mp + +envAddTypes :: Map Name T.TypeDef -> Env -> Env +envAddTypes l (Env mp tdefs aliases) = + let combined = l <> tdefs + in if Map.size combined == Map.size l + Map.size tdefs + then Env mp combined aliases + else error "envAddTypes on duplicate type names" + +envFindType :: Name -> Env -> Maybe T.TypeDef +envFindType name (Env _ tdefs _) = Map.lookup name tdefs -envAdd :: Name -> T.TypeScheme -> Env -> Env -envAdd name sty (Env mp) = Env (Map.insert name sty mp) +envAddAliases :: Map Name S.AliasDef -> Env -> Env +envAddAliases l (Env mp tdefs aliases) = + let combined = l <> aliases + in if Map.size combined == Map.size l + Map.size aliases + then Env mp tdefs combined + else error "envAddAliaes on duplicate type names" -envFind :: Name -> Env -> Maybe T.TypeScheme -envFind name (Env mp) = Map.lookup name mp +envAliases :: Env -> Map Name S.AliasDef +envAliases (Env _ _ aliases) = aliases substVar :: Int -> T.Type -> Subst substVar var ty = Subst (Map.singleton var ty) generalise :: Env -> T.Type -> T.TypeScheme generalise env ty = - T.TypeScheme (Set.toList (freeTypeVars ty Set.\\ freeTypeVars env)) ty + T.TypeScheme (Set.toList (freeInstTypeVars ty Set.\\ freeInstTypeVars env)) ty instantiate :: T.TypeScheme -> TM T.Type instantiate (T.TypeScheme bnds ty) = do @@ -124,6 +143,9 @@ instantiate (T.TypeScheme bnds ty) = do let theta = Subst (Map.fromList (zip bnds vars)) return (theta >>! ty) +freshenFrees :: Env -> T.Type -> TM T.Type +freshenFrees env = instantiate . generalise env + data UnifyContext = UnifyContext SourceRange T.Type T.Type unify :: SourceRange -> T.Type -> T.Type -> TM Subst @@ -131,25 +153,22 @@ unify sr t1 t2 = unify' (UnifyContext sr t1 t2) t1 t2 unify' :: UnifyContext -> T.Type -> T.Type -> TM Subst unify' _ T.TInt T.TInt = return mempty -unify' ctx (T.TFun t1 t2) (T.TFun u1 u2) = (<>) <$> unify' ctx t1 u1 <*> unify' ctx t2 u2 +unify' ctx (T.TFun t1 t2) (T.TFun u1 u2) = + (<>) <$> unify' ctx t1 u1 <*> unify' ctx t2 u2 unify' ctx (T.TTup ts) (T.TTup us) | length ts == length us = mconcat <$> zipWithM (unify' ctx) ts us -unify' _ (T.TyVar var) ty = return (substVar var ty) -unify' _ ty (T.TyVar var) = return (substVar var ty) +unify' _ (T.TyVar T.Instantiable var) ty = return (substVar var ty) +unify' _ ty (T.TyVar T.Instantiable var) = return (substVar var ty) +-- TODO: fix unify unify' (UnifyContext sr t1 t2) _ _ = throwError (TypeError sr t1 t2) -convertType :: S.Type -> T.Type -convertType (S.TFun t1 t2) = T.TFun (convertType t1) (convertType t2) -convertType S.TInt = T.TInt -convertType (S.TTup ts) = T.TTup (map convertType ts) - infer :: Env -> S.Expr -> TM (Subst, T.Expr) infer env expr = case expr of S.Lam _ [] body -> infer env body S.Lam sr args@(_:_:_) body -> infer env (foldr (S.Lam sr . pure) body args) S.Lam _ [(arg, _)] body -> do argVar <- genTyVar - let augEnv = envAdd arg (T.TypeScheme [] argVar) env + let augEnv = envAddDef arg (T.TypeScheme [] argVar) env (theta, body') <- infer augEnv body let argType = theta >>! argVar return (theta, T.Lam (T.TFun argType (T.exprType body')) @@ -157,7 +176,7 @@ infer env expr = case expr of S.Let _ (name, _) rhs body -> do (theta1, rhs') <- infer env rhs let varType = T.exprType rhs' - let augEnv = envAdd name (T.TypeScheme [] varType) env + let augEnv = envAddDef name (T.TypeScheme [] varType) env (theta2, body') <- infer augEnv body return (theta2 <> theta1, T.Let (T.Occ name varType) rhs' body') S.Call sr func arg -> do @@ -173,14 +192,22 @@ infer env expr = case expr of S.Int _ val -> return (mempty, T.Int val) S.Tup _ es -> fmap T.Tup <$> inferList env es S.Var sr name - | Just sty <- envFind name env -> do + | Just sty <- envFindDef name env -> do ty <- instantiate sty return (mempty, T.Var (T.Occ name ty)) | otherwise -> throwError (RefError sr name) + S.Constr sr name -> case envFindType name env of + Just (T.TypeDef typname params typ) -> do + restyp <- freshenFrees emptyEnv + (T.TNamed typname (map (T.TyVar T.Instantiable) params)) + return (mempty, T.Constr (T.TFun typ restyp) name) + _ -> + throwError (RefError sr name) S.Annot sr subex ty -> do (theta1, subex') <- infer env subex - theta2 <- unify sr (T.exprType subex') (convertType ty) + ty' <- convertType (envAliases env) sr ty + theta2 <- unify sr (T.exprType subex') ty' return (theta2 <> theta1, theta2 >>! subex') -- TODO: quadratic complexity inferList :: Env -> [S.Expr] -> TM (Subst, [T.Expr]) @@ -192,23 +219,44 @@ inferList env (expr : exprs) = do runPass :: Context -> S.Program -> Either TCError T.Program -runPass (Context _ (Builtins builtins)) prog = - let env = Env (Map.map (generalise emptyEnv) builtins) +runPass (Context _ (Builtins builtins _)) prog = + let env = Env (Map.map (generalise emptyEnv) builtins) mempty mempty in runTM (typeCheck env prog) typeCheck :: Env -> S.Program -> TM T.Program -typeCheck startEnv (S.Program decls) = - let defs = [(name, ty) - | S.Def (S.Function (Just ty) (name, _) _ _) <- decls] - env = foldl (\env' (name, ty) -> envAdd name (generalise env' (convertType ty)) env') - startEnv defs - in T.Program <$> mapM (typeCheckDef env . (\(S.Def def) -> def)) decls - -typeCheckDef :: Env -> S.Def -> TM T.Def -typeCheckDef env (S.Function mannot (name, sr) args@(_:_) body) = - typeCheckDef env (S.Function mannot (name, sr) [] (S.Lam sr args body)) -typeCheckDef env (S.Function (Just annot) (name, sr) [] body) = - typeCheckDef env (S.Function Nothing (name, sr) [] (S.Annot sr body annot)) -typeCheckDef env (S.Function Nothing (name, _) [] body) = do +typeCheck startEnv (S.Program decls) = do + traceM (show decls) + + let aliasdefs = [(n, def) + | S.DeclAlias def@(S.AliasDef (n, _) _ _) <- decls] + env1 = envAddAliases (Map.fromList aliasdefs) startEnv + + typedefs' <- checkTypedefs (envAliases env1) [def | S.DeclType def <- decls] + let typedefsMap = Map.fromList [(n, def) | def@(T.TypeDef n _ _) <- typedefs'] + + let funcdefs = [def | S.DeclFunc def <- decls] + typedfuncs <- sequence + [(name,) <$> convertType (envAliases env1) sr ty + | S.FuncDef (Just ty) (name, sr) _ _ <- funcdefs] + + let env2 = envAddTypes typedefsMap env1 + + traceM (show typedefsMap) + + let env = foldl (\env' (name, ty) -> + envAddDef name (generalise env' ty) env') + env2 typedfuncs + + traceM (show env) + + funcdefs' <- mapM (typeCheckFunc env) funcdefs + return (T.Program funcdefs' typedefsMap) + +typeCheckFunc :: Env -> S.FuncDef -> TM T.Def +typeCheckFunc env (S.FuncDef mannot (name, sr) args@(_:_) body) = + typeCheckFunc env (S.FuncDef mannot (name, sr) [] (S.Lam sr args body)) +typeCheckFunc env (S.FuncDef (Just annot) (name, sr) [] body) = + typeCheckFunc env (S.FuncDef Nothing (name, sr) [] (S.Annot sr body annot)) +typeCheckFunc env (S.FuncDef Nothing (name, _) [] body) = do (_, body') <- infer env body return (T.Def name body') diff --git a/typecheck/CC/Typecheck/Typedefs.hs b/typecheck/CC/Typecheck/Typedefs.hs new file mode 100644 index 0000000..ad9bdd8 --- /dev/null +++ b/typecheck/CC/Typecheck/Typedefs.hs @@ -0,0 +1,50 @@ +module CC.Typecheck.Typedefs(checkTypedefs) where + +import Control.Monad.Except +import Data.Foldable (traverse_) +import qualified Data.Map.Strict as Map +import Data.Map.Strict (Map) +import qualified Data.Set as Set + +import qualified CC.AST.Source as S +import qualified CC.AST.Typed as T +import CC.Typecheck.Types +import CC.Types + + +checkArity :: Map Name Int -> S.TypeDef -> TM () +checkArity typeArity (S.TypeDef (_, sr) args ty) = go ty + where + argNames = map fst args -- probably a small list + + go :: S.Type -> TM () + go (S.TFun t1 t2) = go t1 >> go t2 + go S.TInt = return () + go (S.TTup ts) = mapM_ go ts + go (S.TNamed n ts) + | Just arity <- Map.lookup n typeArity = + if length ts == arity + then mapM_ go ts + else throwError (TypeArityError sr n arity (length ts)) + | otherwise = throwError (RefError sr n) + go (S.TUnion ts) = traverse_ go ts + go (S.TyVar n) + | n `elem` argNames = return () + | otherwise = throwError (RefError sr n) + +checkTypedefs :: Map Name S.AliasDef -> [S.TypeDef] -> TM [T.TypeDef] +checkTypedefs aliases origdefs = do + let aliasArity = Map.map (\(S.AliasDef _ args _) -> length args) aliases + typeArity = Map.fromList [(n, length args) + | S.TypeDef (n, _) args _ <- origdefs] + + let dups = Set.fromList (map (\(S.TypeDef (n, _) _ _) -> n) origdefs) + Set.\\ Map.keysSet typeArity + when (not (Set.null dups)) $ + throwError (DupTypeError (Set.findMin dups)) + + let aliasdefs = [S.TypeDef name args typ + | S.AliasDef name args typ <- Map.elems aliases] + + mapM_ (checkArity (aliasArity <> typeArity)) (aliasdefs ++ origdefs) + mapM (convertTypeDef aliases) origdefs diff --git a/typecheck/CC/Typecheck/Types.hs b/typecheck/CC/Typecheck/Types.hs new file mode 100644 index 0000000..3f3c471 --- /dev/null +++ b/typecheck/CC/Typecheck/Types.hs @@ -0,0 +1,103 @@ +module CC.Typecheck.Types where + +import Control.Monad.State.Strict +import Control.Monad.Except +import qualified Data.Map.Strict as Map +import Data.Map.Strict (Map) +import Data.Maybe (fromMaybe) +import qualified Data.Set as Set +import Data.Set (Set) + +import qualified CC.AST.Source as S +import qualified CC.AST.Typed as T +import CC.Pretty +import CC.Types + + +data TCError = TypeError SourceRange T.Type T.Type + | RefError SourceRange Name + | TypeArityError SourceRange Name Int Int + | DupTypeError Name + deriving (Show) + +instance Pretty TCError where + pretty (TypeError sr real expect) = + "Type error: Expression at " ++ pretty sr ++ + " has type " ++ pretty real ++ + ", but should have type " ++ pretty expect + pretty (RefError sr name) = + "Reference error: Variable '" ++ name ++ "' out of scope at " ++ pretty sr + pretty (TypeArityError sr name wanted got) = + "Type error: Type '" ++ name ++ "' has arity " ++ show wanted ++ + " but gets " ++ show got ++ " type arguments at " ++ pretty sr + pretty (DupTypeError name) = + "Duplicate types: Type '" ++ name ++ "' defined multiple times" + +type TM a = ExceptT TCError (State Int) a + +genId :: TM Int +genId = state (\idval -> (idval, idval + 1)) + +genTyVar :: TM T.Type +genTyVar = T.TyVar T.Instantiable <$> genId + +runTM :: TM a -> Either TCError a +runTM m = evalState (runExceptT m) 1 + + +convertType :: Map Name S.AliasDef -> SourceRange -> S.Type -> TM T.Type +convertType aliases sr = fmap snd . convertType' aliases mempty sr + +convertTypeDef :: Map Name S.AliasDef -> S.TypeDef -> TM T.TypeDef +convertTypeDef aliases (S.TypeDef (name, sr) args ty) = do + (mapping, ty') <- convertType' aliases (Set.fromList (map fst args)) sr ty + let args' = [mapping Map.! n | (n, _) <- args] + return (T.TypeDef name args' ty') + +convertType' :: Map Name S.AliasDef -> Set Name -> SourceRange -> S.Type -> TM (Map Name Int, T.Type) +convertType' aliases extraVars sr origtype = do + rewritten <- rewrite origtype + let frees = Set.toList (extraVars <> freeVars rewritten) + nums <- traverse (const genId) frees + let mapping = Map.fromList (zip frees nums) + return (mapping, convert mapping rewritten) + where + rewrite :: S.Type -> TM S.Type + rewrite (S.TFun t1 t2) = S.TFun <$> rewrite t1 <*> rewrite t2 + rewrite S.TInt = return S.TInt + rewrite (S.TTup ts) = S.TTup <$> mapM rewrite ts + rewrite (S.TNamed n ts) + | Just (S.AliasDef _ args typ) <- Map.lookup n aliases = + if length args == length ts + then rewrite (subst (Map.fromList (zip (map fst args) ts)) typ) + else throwError (TypeArityError sr n (length args) (length ts)) + | otherwise = + S.TNamed n <$> mapM rewrite ts + rewrite (S.TUnion ts) = S.TUnion . Set.fromList <$> mapM rewrite (Set.toList ts) + rewrite (S.TyVar n) = return (S.TyVar n) + + -- Substitute type variables + subst :: Map Name S.Type -> S.Type -> S.Type + subst mp (S.TFun t1 t2) = S.TFun (subst mp t1) (subst mp t2) + subst _ S.TInt = S.TInt + subst mp (S.TTup ts) = S.TTup (map (subst mp) ts) + subst mp (S.TNamed n ts) = S.TNamed n (map (subst mp) ts) + subst mp (S.TUnion ts) = S.TUnion (Set.map (subst mp) ts) + subst mp orig@(S.TyVar n) = fromMaybe orig (Map.lookup n mp) + + freeVars :: S.Type -> Set Name + freeVars (S.TFun t1 t2) = freeVars t1 <> freeVars t2 + freeVars S.TInt = mempty + freeVars (S.TTup ts) = Set.unions (map freeVars ts) + freeVars (S.TNamed _ ts) = Set.unions (map freeVars ts) + freeVars (S.TUnion ts) = Set.unions (map freeVars (Set.toList ts)) + freeVars (S.TyVar n) = Set.singleton n + + convert :: Map Name Int -> S.Type -> T.Type + convert mp (S.TFun t1 t2) = T.TFun (convert mp t1) (convert mp t2) + convert _ S.TInt = T.TInt + convert mp (S.TTup ts) = T.TTup (map (convert mp) ts) + convert mp (S.TNamed n ts) = T.TNamed n (map (convert mp) ts) + convert mp (S.TUnion ts) = T.TUnion (Set.map (convert mp) ts) + -- TODO: Should this be Rigid? I really don't know how this works. + convert mp (S.TyVar n) = T.TyVar T.Rigid (mp Map.! n) -- cgit v1.2.3