aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom.smeding@gmail.com>2020-07-26 23:02:09 +0200
committerTom Smeding <tom.smeding@gmail.com>2020-07-26 23:02:09 +0200
commit342c213f3caddd64db0eac5ae146912e00378371 (patch)
tree80f55eb7ccabf24ea0787db428595cdbf6caffe0
parent494b764274be4db53499fa4eb7decacb93c7bbe9 (diff)
WIP refactor and union types, type variables
-rw-r--r--ast/CC/AST/Source.hs37
-rw-r--r--ast/CC/AST/Typed.hs48
-rw-r--r--ast/CC/Context.hs2
-rw-r--r--backend/CC/Backend/Dumb.hs24
-rw-r--r--compcomp.cabal3
-rw-r--r--parser/CC/Parser.hs97
-rw-r--r--typecheck/CC/Typecheck.hs182
-rw-r--r--typecheck/CC/Typecheck/Typedefs.hs50
-rw-r--r--typecheck/CC/Typecheck/Types.hs103
9 files changed, 435 insertions, 111 deletions
diff --git a/ast/CC/AST/Source.hs b/ast/CC/AST/Source.hs
index e648759..e64e058 100644
--- a/ast/CC/AST/Source.hs
+++ b/ast/CC/AST/Source.hs
@@ -3,6 +3,8 @@ module CC.AST.Source(
module CC.Types
) where
+import qualified Data.Set as Set
+import Data.Set (Set)
import Data.List
import CC.Pretty
@@ -12,19 +14,32 @@ import CC.Types
data Program = Program [Decl]
deriving (Show, Read)
-data Decl = Def Def -- import?
+data Decl = DeclFunc FuncDef
+ | DeclType TypeDef
+ | DeclAlias AliasDef
deriving (Show, Read)
-data Def = Function (Maybe Type)
- (Name, SourceRange)
- [(Name, SourceRange)]
- Expr
+data FuncDef =
+ FuncDef (Maybe Type)
+ (Name, SourceRange)
+ [(Name, SourceRange)]
+ Expr
+ deriving (Show, Read)
+
+-- Named type with named arguments
+data TypeDef = TypeDef (Name, SourceRange) [(Name, SourceRange)] Type
+ deriving (Show, Read)
+
+data AliasDef = AliasDef (Name, SourceRange) [(Name, SourceRange)] Type
deriving (Show, Read)
data Type = TFun Type Type
| TInt
| TTup [Type]
- deriving (Show, Read)
+ | TNamed Name [Type] -- named type with type arguments
+ | TUnion (Set Type)
+ | TyVar Name
+ deriving (Eq, Ord, Show, Read)
data Expr = Lam SourceRange [(Name, SourceRange)] Expr
| Let SourceRange (Name, SourceRange) Expr Expr
@@ -32,15 +47,24 @@ data Expr = Lam SourceRange [(Name, SourceRange)] Expr
| Int SourceRange Int
| Tup SourceRange [Expr]
| Var SourceRange Name
+ | Constr SourceRange Name -- type constructor
| Annot SourceRange Expr Type
deriving (Show, Read)
+instance Semigroup Program where Program p1 <> Program p2 = Program (p1 <> p2)
+instance Monoid Program where mempty = Program mempty
+
instance Pretty Type where
prettyPrec _ TInt = "Int"
prettyPrec p (TFun a b) =
precParens p 2 (prettyPrec 3 a ++ " -> " ++ prettyPrec 2 b)
prettyPrec _ (TTup ts) =
"(" ++ intercalate ", " (map pretty ts) ++ ")"
+ prettyPrec _ (TNamed n ts) =
+ n ++ "[" ++ intercalate ", " (map pretty ts) ++ "]"
+ prettyPrec _ (TUnion ts) =
+ "{" ++ intercalate " | " (map pretty (Set.toList ts)) ++ "}"
+ prettyPrec _ (TyVar n) = "<" ++ n ++ ">"
instance HasRange Expr where
range (Lam sr _ _) = sr
@@ -49,4 +73,5 @@ instance HasRange Expr where
range (Int sr _) = sr
range (Tup sr _) = sr
range (Var sr _) = sr
+ range (Constr sr _) = sr
range (Annot sr _ _) = sr
diff --git a/ast/CC/AST/Typed.hs b/ast/CC/AST/Typed.hs
index b12b30a..cf67575 100644
--- a/ast/CC/AST/Typed.hs
+++ b/ast/CC/AST/Typed.hs
@@ -3,23 +3,36 @@ module CC.AST.Typed(
module CC.Types
) where
+import qualified Data.Map.Strict as Map
+import Data.Map.Strict (Map)
+import qualified Data.Set as Set
+import Data.Set (Set)
import Data.List
import CC.Pretty
import CC.Types
-data Program = Program [Def]
+data Program = Program [Def] (Map Name TypeDef)
deriving (Show, Read)
data Def = Def Name Expr
deriving (Show, Read)
+-- Named type with type parameters
+data TypeDef = TypeDef Name [Int] Type
+ deriving (Show, Read)
+
data Type = TFun Type Type
| TInt
| TTup [Type]
- | TyVar Int
- deriving (Show, Read)
+ | TNamed Name [Type] -- named type with type arguments
+ | TUnion (Set Type)
+ | TyVar Rigidity Int
+ deriving (Eq, Ord, Show, Read)
+
+data Rigidity = Rigid | Instantiable
+ deriving (Eq, Ord, Show, Read)
data TypeScheme = TypeScheme [Int] Type
deriving (Show, Read)
@@ -30,6 +43,7 @@ data Expr = Lam Type Occ Expr
| Int Int
| Tup [Expr]
| Var Occ
+ | Constr Type Name -- Type is 'argument -> TNamed'
deriving (Show, Read)
data Occ = Occ Name Type
@@ -42,6 +56,7 @@ exprType (Call typ _ _) = typ
exprType (Int _) = TInt
exprType (Tup es) = TTup (map exprType es)
exprType (Var (Occ _ typ)) = typ
+exprType (Constr typ _) = typ
instance Pretty Type where
prettyPrec _ TInt = "Int"
@@ -49,13 +64,18 @@ instance Pretty Type where
precParens p 2 (prettyPrec 3 a ++ " -> " ++ prettyPrec 2 b)
prettyPrec _ (TTup ts) =
"(" ++ intercalate ", " (map pretty ts) ++ ")"
- prettyPrec _ (TyVar i) = 't' : show i
+ prettyPrec _ (TNamed n ts) =
+ n ++ "[" ++ intercalate ", " (map pretty ts) ++ "]"
+ prettyPrec _ (TUnion ts) =
+ "{ " ++ intercalate " | " (map pretty (Set.toList ts)) ++ " }"
+ prettyPrec _ (TyVar Rigid i) = 't' : show i ++ "R"
+ prettyPrec _ (TyVar Instantiable i) = 't' : show i
instance Pretty TypeScheme where
prettyPrec p (TypeScheme bnds ty) =
precParens p 2
- ("forall " ++ intercalate " " (map (pretty . TyVar) bnds) ++ ". " ++
- prettyPrec 2 ty)
+ ("forall " ++ intercalate " " (map (pretty . TyVar Instantiable) bnds) ++
+ ". " ++ prettyPrec 2 ty)
instance Pretty Expr where
prettyPrec p (Lam ty (Occ n t) e) =
@@ -75,10 +95,22 @@ instance Pretty Expr where
prettyPrec _ (Tup es) = "(" ++ intercalate ", " (map pretty es) ++ ")"
prettyPrec p (Var (Occ n t)) =
precParens p 2 $
- show n ++ " :: " ++ pretty t
+ n ++ " :: " ++ pretty t
+ prettyPrec p (Constr t n) =
+ precParens p 2 $
+ n ++ " :: " ++ pretty t
instance Pretty Def where
pretty (Def n e) = n ++ " = " ++ pretty e
+instance Pretty TypeDef where
+ pretty (TypeDef n [] t) =
+ "type " ++ n ++ " = " ++ pretty t
+ pretty (TypeDef n vs t) =
+ "type " ++ n ++ " " ++ intercalate " " (map (pretty . TyVar Instantiable) vs) ++
+ " = " ++ pretty t
+
instance Pretty Program where
- pretty (Program defs) = concatMap ((++ "\n") . pretty) defs
+ pretty (Program defs tdefs) =
+ concatMap (++ "\n") $
+ map pretty (Map.elems tdefs) ++ map pretty defs
diff --git a/ast/CC/Context.hs b/ast/CC/Context.hs
index 68378d7..acfb614 100644
--- a/ast/CC/Context.hs
+++ b/ast/CC/Context.hs
@@ -9,4 +9,4 @@ import CC.AST.Typed
data Context = Context FilePath Builtins
-- | Information about builtins supported by the enabled backend
-data Builtins = Builtins (Map Name Type)
+data Builtins = Builtins (Map Name Type) String
diff --git a/backend/CC/Backend/Dumb.hs b/backend/CC/Backend/Dumb.hs
index 3210dab..822fb84 100644
--- a/backend/CC/Backend/Dumb.hs
+++ b/backend/CC/Backend/Dumb.hs
@@ -7,10 +7,20 @@ import CC.Context
builtins :: Builtins
-builtins = Builtins . Map.fromList $
- [ ("print", TFun TInt (TTup []))
- , ("fst", TFun (TTup [TyVar 1, TyVar 2]) (TyVar 1))
- , ("snd", TFun (TTup [TyVar 1, TyVar 2]) (TyVar 2))
- , ("_add", TFun TInt (TFun TInt TInt))
- , ("_sub", TFun TInt (TFun TInt TInt))
- , ("_mul", TFun TInt (TFun TInt TInt)) ]
+builtins =
+ let values = Map.fromList
+ [ ("print", TInt ==> TTup [])
+ , ("fst", TTup [t1, t2] ==> t1)
+ , ("snd", TTup [t1, t2] ==> t2)
+ , ("_add", TInt ==> TInt ==> TInt)
+ , ("_sub", TInt ==> TInt ==> TInt)
+ , ("_mul", TInt ==> TInt ==> TInt) ]
+ prelude = "type Nil = ()\n\
+ \type Cons a = (a, List a)\n\
+ \alias List a = { Nil | Cons a }\n"
+ in Builtins values prelude
+ where
+ t1 = TyVar Instantiable 1
+ t2 = TyVar Instantiable 2
+ infixr ==>
+ (==>) = TFun
diff --git a/compcomp.cabal b/compcomp.cabal
index 9e941aa..4b06f4b 100644
--- a/compcomp.cabal
+++ b/compcomp.cabal
@@ -20,7 +20,7 @@ executable compcomp
library cc-parser
import: deps
hs-source-dirs: parser
- build-depends: parsec, cc-ast, cc-utils
+ build-depends: containers, parsec, cc-ast, cc-utils
exposed-modules: CC.Parser
library cc-typecheck
@@ -28,6 +28,7 @@ library cc-typecheck
hs-source-dirs: typecheck
build-depends: containers, mtl, cc-ast, cc-utils
exposed-modules: CC.Typecheck
+ other-modules: CC.Typecheck.Typedefs, CC.Typecheck.Types
library cc-backend-dumb
import: deps
diff --git a/parser/CC/Parser.hs b/parser/CC/Parser.hs
index 2d2c4b7..66bc6cf 100644
--- a/parser/CC/Parser.hs
+++ b/parser/CC/Parser.hs
@@ -1,6 +1,7 @@
module CC.Parser(runPass, parseProgram) where
import Control.Monad
+import qualified Data.Set as Set
import Text.Parsec hiding (SourcePos, getPosition, token)
import qualified Text.Parsec
@@ -12,7 +13,10 @@ import CC.Pretty
type Parser a = Parsec String () a
runPass :: Context -> RawString -> Either (PrettyShow ParseError) Program
-runPass (Context path _) (RawString src) = fmapLeft PrettyShow (parseProgram path src)
+runPass (Context path (Builtins _ prelude)) (RawString src) = do
+ prog1 <- fmapLeft PrettyShow (parseProgram "<prelude>" prelude)
+ prog2 <- fmapLeft PrettyShow (parseProgram path src)
+ return (prog1 <> prog2)
where fmapLeft f (Left x) = Left (f x)
fmapLeft _ (Right x) = Right x
@@ -27,32 +31,63 @@ pProgram = do
return prog
pDecl :: Parser Decl
-pDecl = Def <$> pDef
+pDecl = choice
+ [ DeclType <$> pDeclType
+ , DeclAlias <$> pDeclAlias
+ , DeclFunc <$> pDeclFunc ]
-pDef :: Parser Def
-pDef = do
+pDeclFunc :: Parser FuncDef
+pDeclFunc = do
func <- try $ do
emptyLines
- name <- pName0 <?> "declaration head name"
+ name <- pName0 LowerCase <?> "declaration head name"
return name
mtyp <- optionMaybe $ do
symbol "::"
typ <- pType
whitespace >> void newline
emptyLines
- func' <- fst <$> pName0
+ (func', _) <- pName0 LowerCase
guard (fst func == func')
return typ
- args <- many pName
+ args <- many (pName LowerCase)
symbol "="
expr <- pExpr
- return (Function mtyp func args expr)
+ return (FuncDef mtyp func args expr)
+
+pDeclType :: Parser TypeDef
+pDeclType = (\(n, a, t) -> TypeDef n a t) <$> pTypedefLike "type"
+
+pDeclAlias :: Parser AliasDef
+pDeclAlias = (\(n, a, t) -> AliasDef n a t) <$> pTypedefLike "alias"
+
+pTypedefLike :: String -> Parser ((Name, SourceRange), [(Name, SourceRange)], Type)
+pTypedefLike keyword = do
+ try (emptyLines >> string keyword >> whitespace1)
+ name <- pName0 UpperCase
+ args <- many (pName LowerCase)
+ symbol "="
+ ty <- pType
+ return (name, args, ty)
pType :: Parser Type
-pType = chainr1 pTypeAtom (symbol "->" >> return TFun)
+pType = chainr1 pTypeTerm (symbol "->" >> return TFun)
+
+pTypeTerm :: Parser Type
+pTypeTerm = pTypeAtom <|> pTypeCall
pTypeAtom :: Parser Type
-pTypeAtom = (wordToken "Int" >> return TInt) <|> pParenType
+pTypeAtom = choice
+ [ wordToken "Int" >> return TInt
+ , TyVar . fst <$> pName LowerCase
+ , pParenType
+ , pUnionType ]
+
+pTypeCall :: Parser Type
+pTypeCall = do
+ (constr, _) <- pName UpperCase
+ args <- many pTypeAtom
+ return (TNamed constr args)
pParenType :: Parser Type
pParenType = do
@@ -63,6 +98,19 @@ pParenType = do
[ty] -> return ty
_ -> return (TTup tys)
+pUnionType :: Parser Type
+pUnionType = do
+ token "{"
+ tys <- pType `sepBy` token "|"
+ token "}"
+ case tys of
+ [] -> unexpected "empty union type"
+ [ty] -> return ty
+ _ -> let tyset = Set.fromList tys
+ in if Set.size tyset == length tys
+ then return (TUnion tyset)
+ else unexpected "duplicate types in union"
+
pExpr :: Parser Expr
pExpr = label (pLam <|> pLet <|> pCall) "expression"
where
@@ -84,7 +132,7 @@ pExpr = label (pLam <|> pLet <|> pCall) "expression"
p <- getPosition
void (char '\\')
return p
- names <- many1 pName
+ names <- many1 (pName LowerCase)
symbol "->"
body <- pExpr
p2 <- getPosition
@@ -100,7 +148,7 @@ pExpr = label (pLam <|> pLet <|> pCall) "expression"
where
afterKeyword p1 = do
whitespace1
- lhs <- pName0
+ lhs <- pName0 LowerCase
symbol "="
rhs <- pExpr
let fullRange rest = mergeRange (SourceRange p1 p1) (range rest)
@@ -118,7 +166,8 @@ pExpr = label (pLam <|> pLet <|> pCall) "expression"
pExprAtom :: Parser Expr
pExprAtom =
choice [ uncurry (flip Int) <$> pInt
- , uncurry (flip Var) <$> pName
+ , uncurry (flip Var) <$> pName LowerCase
+ , uncurry (flip Constr) <$> pName UpperCase
, pParenExpr ]
pParenExpr :: Parser Expr
@@ -141,11 +190,13 @@ pInt = try (whitespace >> pInt0)
p2 <- getPosition
return (num, SourceRange p1 p2)
-pName0 :: Parser (Name, SourceRange)
-pName0 = do
+data Case = LowerCase | UpperCase
+
+pName0 :: Case -> Parser (Name, SourceRange)
+pName0 vcase = do
p1 <- getPosition
s <- try $ do
- c <- pWordFirstChar
+ c <- pWordFirstChar vcase
cs <- many pWordMidChar
let s = c : cs
guard (s `notElem` ["let", "in"])
@@ -154,14 +205,18 @@ pName0 = do
notFollowedBy pWordMidChar
return (s, SourceRange p1 p2)
-pWordFirstChar :: Parser Char
-pWordFirstChar = letter <|> oneOf "_$#!"
+pWordFirstChar :: Case -> Parser Char
+pWordFirstChar LowerCase = lower <|> oneOf wordSymbols
+pWordFirstChar UpperCase = upper <|> oneOf wordSymbols
pWordMidChar :: Parser Char
-pWordMidChar = alphaNum <|> oneOf "_$#!"
+pWordMidChar = alphaNum <|> oneOf wordSymbols
+
+wordSymbols :: [Char]
+wordSymbols = "_$#!"
-pName :: Parser (Name, SourceRange)
-pName = try (whitespace >> pName0)
+pName :: Case -> Parser (Name, SourceRange)
+pName vcase = try (whitespace >> pName0 vcase)
symbol :: String -> Parser ()
symbol s = token s >> (eof <|> void space <|> void (oneOf "(){}[]"))
diff --git a/typecheck/CC/Typecheck.hs b/typecheck/CC/Typecheck.hs
index f61103e..824a714 100644
--- a/typecheck/CC/Typecheck.hs
+++ b/typecheck/CC/Typecheck.hs
@@ -1,4 +1,5 @@
{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeSynonymInstances #-}
module CC.Typecheck(runPass) where
@@ -10,59 +11,47 @@ import Data.Maybe
import qualified Data.Set as Set
import Data.Set (Set)
+import Debug.Trace
+
import qualified CC.AST.Source as S
import qualified CC.AST.Typed as T
import CC.Context
-import CC.Pretty
import CC.Types
+import CC.Typecheck.Typedefs
+import CC.Typecheck.Types
-- Inspiration: https://github.com/kritzcreek/fby19
-data TCError = TypeError SourceRange T.Type T.Type
- | RefError SourceRange Name
+data Env =
+ Env (Map Name T.TypeScheme) -- Definitions in scope
+ (Map Name T.TypeDef) -- Type definitions
+ (Map Name S.AliasDef) -- Type aliases
deriving (Show)
-instance Pretty TCError where
- pretty (TypeError sr real expect) =
- "Type error: Expression at " ++ pretty sr ++
- " has type " ++ pretty real ++
- ", but should have type " ++ pretty expect
- pretty (RefError sr name) =
- "Reference error: Variable '" ++ name ++ "' out of scope at " ++ pretty sr
-
-type TM a = ExceptT TCError (State Int) a
-
-genId :: TM Int
-genId = state (\idval -> (idval, idval + 1))
-
-genTyVar :: TM T.Type
-genTyVar = T.TyVar <$> genId
-
-runTM :: TM a -> Either TCError a
-runTM m = evalState (runExceptT m) 1
-
-
-newtype Env = Env (Map Name T.TypeScheme)
-
newtype Subst = Subst (Map Int T.Type)
class FreeTypeVars a where
- freeTypeVars :: a -> Set Int
+ -- Free instantiable type variables
+ freeInstTypeVars :: a -> Set Int
instance FreeTypeVars T.Type where
- freeTypeVars (T.TFun t1 t2) = freeTypeVars t1 <> freeTypeVars t2
- freeTypeVars T.TInt = mempty
- freeTypeVars (T.TTup ts) = Set.unions (map freeTypeVars ts)
- freeTypeVars (T.TyVar var) = Set.singleton var
+ freeInstTypeVars (T.TFun t1 t2) = freeInstTypeVars t1 <> freeInstTypeVars t2
+ freeInstTypeVars T.TInt = mempty
+ freeInstTypeVars (T.TTup ts) = Set.unions (map freeInstTypeVars ts)
+ freeInstTypeVars (T.TNamed _ ts) = Set.unions (map freeInstTypeVars ts)
+ freeInstTypeVars (T.TUnion ts) = Set.unions (map freeInstTypeVars (Set.toList ts))
+ freeInstTypeVars (T.TyVar T.Instantiable var) = Set.singleton var
+ freeInstTypeVars (T.TyVar T.Rigid _) = mempty
instance FreeTypeVars T.TypeScheme where
- freeTypeVars (T.TypeScheme bnds ty) = foldr Set.delete (freeTypeVars ty) bnds
+ freeInstTypeVars (T.TypeScheme bnds ty) =
+ foldr Set.delete (freeInstTypeVars ty) bnds
instance FreeTypeVars Env where
- freeTypeVars (Env mp) = foldMap freeTypeVars (Map.elems mp)
+ freeInstTypeVars (Env mp _ _) = foldMap freeInstTypeVars (Map.elems mp)
infixr >>!
@@ -74,14 +63,20 @@ instance Substitute T.Type where
T.TFun t1 t2 -> T.TFun (theta >>! t1) (theta >>! t2)
T.TInt -> T.TInt
T.TTup ts -> T.TTup (map (theta >>!) ts)
- T.TyVar i -> fromMaybe ty (Map.lookup i mp)
+ T.TNamed n ts -> T.TNamed n (map (theta >>!) ts)
+ T.TUnion ts -> T.TUnion (Set.map (theta >>!) ts)
+ T.TyVar T.Instantiable i -> fromMaybe ty (Map.lookup i mp)
+ T.TyVar T.Rigid i
+ | i `Map.member` mp -> error "Attempt to substitute a rigid type variable"
+ | otherwise -> ty
instance Substitute T.TypeScheme where
Subst mp >>! T.TypeScheme bnds ty =
T.TypeScheme bnds (Subst (foldr Map.delete mp bnds) >>! ty)
instance Substitute Env where
- theta >>! Env mp = Env (Map.map (theta >>!) mp)
+ theta >>! Env mp tdefs aliases =
+ Env (Map.map (theta >>!) mp) tdefs aliases
-- TODO: make this instance unnecessary
instance Substitute T.Expr where
@@ -94,6 +89,7 @@ instance Substitute T.Expr where
_ >>! expr@(T.Int _) = expr
theta >>! T.Tup es = T.Tup (map (theta >>!) es)
theta >>! T.Var (T.Occ name ty) = T.Var (T.Occ name (theta >>! ty))
+ theta >>! T.Constr ty n = T.Constr (theta >>! ty) n
instance Semigroup Subst where
@@ -103,20 +99,43 @@ instance Monoid Subst where
mempty = Subst mempty
emptyEnv :: Env
-emptyEnv = Env mempty
+emptyEnv = Env mempty mempty mempty
+
+envAddDef :: Name -> T.TypeScheme -> Env -> Env
+envAddDef name sty (Env mp tmp aliases)
+ | name `Map.member` mp = error "envAddDef on name already in environment"
+ | otherwise =
+ Env (Map.insert name sty mp) tmp aliases
+
+envFindDef :: Name -> Env -> Maybe T.TypeScheme
+envFindDef name (Env mp _ _) = Map.lookup name mp
+
+-- | Add a batch of type definitions to the environment.
+--
+-- The union is left-biased, so a name present in both maps would be dropped
+-- from one side; the size comparison below detects exactly that case and
+-- aborts via 'error'.
+envAddTypes :: Map Name T.TypeDef -> Env -> Env
+envAddTypes l (Env mp tdefs aliases) =
+ let combined = l <> tdefs
+ in if Map.size combined == Map.size l + Map.size tdefs
+ then Env mp combined aliases
+ else error "envAddTypes on duplicate type names"
+
+envFindType :: Name -> Env -> Maybe T.TypeDef
+envFindType name (Env _ tdefs _) = Map.lookup name tdefs
-envAdd :: Name -> T.TypeScheme -> Env -> Env
-envAdd name sty (Env mp) = Env (Map.insert name sty mp)
+-- | Add a batch of type aliases to the environment.
+--
+-- The new aliases must not share a name with an alias that is already
+-- present; a clash aborts via 'error'.
+envAddAliases :: Map Name S.AliasDef -> Env -> Env
+envAddAliases l (Env mp tdefs aliases) =
+    let combined = l <> aliases
+    -- The union is left-biased and drops clashing keys, so a clash shows up
+    -- as a shortfall in the combined size.
+    in if Map.size combined == Map.size l + Map.size aliases
+           then Env mp tdefs combined
+           else error "envAddAliases on duplicate alias names"
-envFind :: Name -> Env -> Maybe T.TypeScheme
-envFind name (Env mp) = Map.lookup name mp
+envAliases :: Env -> Map Name S.AliasDef
+envAliases (Env _ _ aliases) = aliases
substVar :: Int -> T.Type -> Subst
substVar var ty = Subst (Map.singleton var ty)
generalise :: Env -> T.Type -> T.TypeScheme
generalise env ty =
- T.TypeScheme (Set.toList (freeTypeVars ty Set.\\ freeTypeVars env)) ty
+ T.TypeScheme (Set.toList (freeInstTypeVars ty Set.\\ freeInstTypeVars env)) ty
instantiate :: T.TypeScheme -> TM T.Type
instantiate (T.TypeScheme bnds ty) = do
@@ -124,6 +143,9 @@ instantiate (T.TypeScheme bnds ty) = do
let theta = Subst (Map.fromList (zip bnds vars))
return (theta >>! ty)
+freshenFrees :: Env -> T.Type -> TM T.Type
+freshenFrees env = instantiate . generalise env
+
data UnifyContext = UnifyContext SourceRange T.Type T.Type
unify :: SourceRange -> T.Type -> T.Type -> TM Subst
@@ -131,25 +153,22 @@ unify sr t1 t2 = unify' (UnifyContext sr t1 t2) t1 t2
unify' :: UnifyContext -> T.Type -> T.Type -> TM Subst
unify' _ T.TInt T.TInt = return mempty
-unify' ctx (T.TFun t1 t2) (T.TFun u1 u2) = (<>) <$> unify' ctx t1 u1 <*> unify' ctx t2 u2
+unify' ctx (T.TFun t1 t2) (T.TFun u1 u2) =
+ (<>) <$> unify' ctx t1 u1 <*> unify' ctx t2 u2
unify' ctx (T.TTup ts) (T.TTup us)
| length ts == length us = mconcat <$> zipWithM (unify' ctx) ts us
-unify' _ (T.TyVar var) ty = return (substVar var ty)
-unify' _ ty (T.TyVar var) = return (substVar var ty)
+unify' _ (T.TyVar T.Instantiable var) ty = return (substVar var ty)
+unify' _ ty (T.TyVar T.Instantiable var) = return (substVar var ty)
+-- TODO: fix unify
unify' (UnifyContext sr t1 t2) _ _ = throwError (TypeError sr t1 t2)
-convertType :: S.Type -> T.Type
-convertType (S.TFun t1 t2) = T.TFun (convertType t1) (convertType t2)
-convertType S.TInt = T.TInt
-convertType (S.TTup ts) = T.TTup (map convertType ts)
-
infer :: Env -> S.Expr -> TM (Subst, T.Expr)
infer env expr = case expr of
S.Lam _ [] body -> infer env body
S.Lam sr args@(_:_:_) body -> infer env (foldr (S.Lam sr . pure) body args)
S.Lam _ [(arg, _)] body -> do
argVar <- genTyVar
- let augEnv = envAdd arg (T.TypeScheme [] argVar) env
+ let augEnv = envAddDef arg (T.TypeScheme [] argVar) env
(theta, body') <- infer augEnv body
let argType = theta >>! argVar
return (theta, T.Lam (T.TFun argType (T.exprType body'))
@@ -157,7 +176,7 @@ infer env expr = case expr of
S.Let _ (name, _) rhs body -> do
(theta1, rhs') <- infer env rhs
let varType = T.exprType rhs'
- let augEnv = envAdd name (T.TypeScheme [] varType) env
+ let augEnv = envAddDef name (T.TypeScheme [] varType) env
(theta2, body') <- infer augEnv body
return (theta2 <> theta1, T.Let (T.Occ name varType) rhs' body')
S.Call sr func arg -> do
@@ -173,14 +192,22 @@ infer env expr = case expr of
S.Int _ val -> return (mempty, T.Int val)
S.Tup _ es -> fmap T.Tup <$> inferList env es
S.Var sr name
- | Just sty <- envFind name env -> do
+ | Just sty <- envFindDef name env -> do
ty <- instantiate sty
return (mempty, T.Var (T.Occ name ty))
| otherwise ->
throwError (RefError sr name)
+ S.Constr sr name -> case envFindType name env of
+ Just (T.TypeDef typname params typ) -> do
+ restyp <- freshenFrees emptyEnv
+ (T.TNamed typname (map (T.TyVar T.Instantiable) params))
+ return (mempty, T.Constr (T.TFun typ restyp) name)
+ _ ->
+ throwError (RefError sr name)
S.Annot sr subex ty -> do
(theta1, subex') <- infer env subex
- theta2 <- unify sr (T.exprType subex') (convertType ty)
+ ty' <- convertType (envAliases env) sr ty
+ theta2 <- unify sr (T.exprType subex') ty'
return (theta2 <> theta1, theta2 >>! subex') -- TODO: quadratic complexity
inferList :: Env -> [S.Expr] -> TM (Subst, [T.Expr])
@@ -192,23 +219,44 @@ inferList env (expr : exprs) = do
runPass :: Context -> S.Program -> Either TCError T.Program
-runPass (Context _ (Builtins builtins)) prog =
- let env = Env (Map.map (generalise emptyEnv) builtins)
+runPass (Context _ (Builtins builtins _)) prog =
+ let env = Env (Map.map (generalise emptyEnv) builtins) mempty mempty
in runTM (typeCheck env prog)
typeCheck :: Env -> S.Program -> TM T.Program
-typeCheck startEnv (S.Program decls) =
- let defs = [(name, ty)
- | S.Def (S.Function (Just ty) (name, _) _ _) <- decls]
- env = foldl (\env' (name, ty) -> envAdd name (generalise env' (convertType ty)) env')
- startEnv defs
- in T.Program <$> mapM (typeCheckDef env . (\(S.Def def) -> def)) decls
-
-typeCheckDef :: Env -> S.Def -> TM T.Def
-typeCheckDef env (S.Function mannot (name, sr) args@(_:_) body) =
- typeCheckDef env (S.Function mannot (name, sr) [] (S.Lam sr args body))
-typeCheckDef env (S.Function (Just annot) (name, sr) [] body) =
- typeCheckDef env (S.Function Nothing (name, sr) [] (S.Annot sr body annot))
-typeCheckDef env (S.Function Nothing (name, _) [] body) = do
+typeCheck startEnv (S.Program decls) = do
+ traceM (show decls)
+
+ let aliasdefs = [(n, def)
+ | S.DeclAlias def@(S.AliasDef (n, _) _ _) <- decls]
+ env1 = envAddAliases (Map.fromList aliasdefs) startEnv
+
+ typedefs' <- checkTypedefs (envAliases env1) [def | S.DeclType def <- decls]
+ let typedefsMap = Map.fromList [(n, def) | def@(T.TypeDef n _ _) <- typedefs']
+
+ let funcdefs = [def | S.DeclFunc def <- decls]
+ typedfuncs <- sequence
+ [(name,) <$> convertType (envAliases env1) sr ty
+ | S.FuncDef (Just ty) (name, sr) _ _ <- funcdefs]
+
+ let env2 = envAddTypes typedefsMap env1
+
+ traceM (show typedefsMap)
+
+ let env = foldl (\env' (name, ty) ->
+ envAddDef name (generalise env' ty) env')
+ env2 typedfuncs
+
+ traceM (show env)
+
+ funcdefs' <- mapM (typeCheckFunc env) funcdefs
+ return (T.Program funcdefs' typedefsMap)
+
+typeCheckFunc :: Env -> S.FuncDef -> TM T.Def
+typeCheckFunc env (S.FuncDef mannot (name, sr) args@(_:_) body) =
+ typeCheckFunc env (S.FuncDef mannot (name, sr) [] (S.Lam sr args body))
+typeCheckFunc env (S.FuncDef (Just annot) (name, sr) [] body) =
+ typeCheckFunc env (S.FuncDef Nothing (name, sr) [] (S.Annot sr body annot))
+typeCheckFunc env (S.FuncDef Nothing (name, _) [] body) = do
(_, body') <- infer env body
return (T.Def name body')
diff --git a/typecheck/CC/Typecheck/Typedefs.hs b/typecheck/CC/Typecheck/Typedefs.hs
new file mode 100644
index 0000000..ad9bdd8
--- /dev/null
+++ b/typecheck/CC/Typecheck/Typedefs.hs
@@ -0,0 +1,50 @@
+module CC.Typecheck.Typedefs(checkTypedefs) where
+
+import Control.Monad.Except
+import Data.Foldable (traverse_)
+import qualified Data.Map.Strict as Map
+import Data.Map.Strict (Map)
+import qualified Data.Set as Set
+
+import qualified CC.AST.Source as S
+import qualified CC.AST.Typed as T
+import CC.Typecheck.Types
+import CC.Types
+
+
+-- | Verify that the body of a type definition only refers to known named
+-- types, applied to the right number of arguments, and only to type
+-- variables bound by the definition itself.
+--
+-- The map gives the arity of every known named type. Every error is reported
+-- at the source range attached to the definition's name.
+checkArity :: Map Name Int -> S.TypeDef -> TM ()
+checkArity typeArity (S.TypeDef (_, sr) args ty) = check ty
+  where
+    boundVars = map fst args  -- probably a small list
+
+    check :: S.Type -> TM ()
+    check typ = case typ of
+        S.TFun t1 t2 -> check t1 >> check t2
+        S.TInt -> return ()
+        S.TTup ts -> mapM_ check ts
+        S.TNamed n ts -> case Map.lookup n typeArity of
+            Nothing -> throwError (RefError sr n)
+            Just arity
+                | length ts == arity -> mapM_ check ts
+                | otherwise -> throwError (TypeArityError sr n arity (length ts))
+        S.TUnion ts -> traverse_ check ts
+        S.TyVar n
+            | n `elem` boundVars -> return ()
+            | otherwise -> throwError (RefError sr n)
+
+-- | Check all type definitions of a program and convert them to their typed
+-- representation.
+--
+-- Fails with 'DupTypeError' when a type name is defined more than once, and
+-- with the errors of 'checkArity' when an alias or type definition refers to
+-- an unknown type, an unbound type variable, or applies a named type to the
+-- wrong number of arguments.
+checkTypedefs :: Map Name S.AliasDef -> [S.TypeDef] -> TM [T.TypeDef]
+checkTypedefs aliases origdefs = do
+    let aliasArity = Map.map (\(S.AliasDef _ args _) -> length args) aliases
+        typeArity = Map.fromList [(n, length args)
+                                 | S.TypeDef (n, _) args _ <- origdefs]
+
+    -- Count how often each name is defined. Subtracting the keys of
+    -- 'typeArity' from the set of defined names cannot detect anything: both
+    -- are built from the same list of names, so the difference is always
+    -- empty.
+    let occurrences = Map.fromListWith (+)
+                          [(n, 1 :: Int) | S.TypeDef (n, _) _ _ <- origdefs]
+        dups = Map.keysSet (Map.filter (> 1) occurrences)
+    when (not (Set.null dups)) $
+        throwError (DupTypeError (Set.findMin dups))
+
+    -- Aliases have the same shape as type definitions, so they are checked by
+    -- the same rules.
+    let aliasdefs = [S.TypeDef name args typ
+                    | S.AliasDef name args typ <- Map.elems aliases]
+
+    mapM_ (checkArity (aliasArity <> typeArity)) (aliasdefs ++ origdefs)
+    mapM (convertTypeDef aliases) origdefs
diff --git a/typecheck/CC/Typecheck/Types.hs b/typecheck/CC/Typecheck/Types.hs
new file mode 100644
index 0000000..3f3c471
--- /dev/null
+++ b/typecheck/CC/Typecheck/Types.hs
@@ -0,0 +1,103 @@
+module CC.Typecheck.Types where
+
+import Control.Monad.State.Strict
+import Control.Monad.Except
+import qualified Data.Map.Strict as Map
+import Data.Map.Strict (Map)
+import Data.Maybe (fromMaybe)
+import qualified Data.Set as Set
+import Data.Set (Set)
+
+import qualified CC.AST.Source as S
+import qualified CC.AST.Typed as T
+import CC.Pretty
+import CC.Types
+
+
+-- | Errors reported by the type checker.
+--
+-- * 'TypeError': location, then the type the expression has, then the type
+--   it should have.
+-- * 'RefError': location and the name that is out of scope.
+-- * 'TypeArityError': location, type name, its declared arity, and the
+--   number of type arguments actually supplied.
+-- * 'DupTypeError': the type name that is defined more than once.
+data TCError = TypeError SourceRange T.Type T.Type
+ | RefError SourceRange Name
+ | TypeArityError SourceRange Name Int Int
+ | DupTypeError Name
+ deriving (Show)
+
+-- | Human-readable rendering of type checker errors.
+instance Pretty TCError where
+    pretty err = case err of
+        TypeError sr real expect -> concat
+            [ "Type error: Expression at ", pretty sr
+            , " has type ", pretty real
+            , ", but should have type ", pretty expect ]
+        RefError sr name -> concat
+            [ "Reference error: Variable '", name
+            , "' out of scope at ", pretty sr ]
+        TypeArityError sr name wanted got -> concat
+            [ "Type error: Type '", name, "' has arity ", show wanted
+            , " but gets ", show got, " type arguments at ", pretty sr ]
+        DupTypeError name ->
+            "Duplicate types: Type '" ++ name ++ "' defined multiple times"
+
+-- | The type checker monad: an 'Int' state (used by 'genId' as the supply of
+-- fresh ids) with 'TCError' exceptions layered on top.
+type TM a = ExceptT TCError (State Int) a
+
+-- | Hand out the current counter value and advance the counter by one.
+genId :: TM Int
+genId = do
+    next <- get
+    put (next + 1)
+    return next
+
+-- | Create an instantiable type variable numbered with a fresh id.
+genTyVar :: TM T.Type
+genTyVar = T.TyVar T.Instantiable <$> genId
+
+-- | Run a type checker computation with the id counter starting at 1,
+-- returning either the error it threw or its result.
+runTM :: TM a -> Either TCError a
+runTM m = evalState (runExceptT m) 1
+
+
+-- | Convert a source type to a typed-AST type, expanding aliases.
+--
+-- Runs 'convertType'' with no extra type variables and discards the
+-- resulting name-to-id mapping. The source range is used for error reports.
+convertType :: Map Name S.AliasDef -> SourceRange -> S.Type -> TM T.Type
+convertType aliases sr = fmap snd . convertType' aliases mempty sr
+
+-- | Convert a source type definition to its typed form.
+--
+-- The definition's parameters are passed to 'convertType'' as extra
+-- variables, so each of them is present in the returned mapping even when
+-- the body never mentions it; the numeric parameter list is read off that
+-- mapping in declaration order.
+convertTypeDef :: Map Name S.AliasDef -> S.TypeDef -> TM T.TypeDef
+convertTypeDef aliases (S.TypeDef (name, sr) args ty) = do
+    let argNames = map fst args
+    (mapping, ty') <- convertType' aliases (Set.fromList argNames) sr ty
+    return (T.TypeDef name (map (mapping Map.!) argNames) ty')
+
+-- | Convert a source type to a typed-AST type, also returning the mapping
+-- from type variable names to the numeric ids assigned to them.
+--
+-- Works in three steps: expand aliases ('rewrite'), assign a fresh id to
+-- every name in @extraVars@ and every type variable occurring in the
+-- expanded type, then translate the constructors ('convert'). The source
+-- range is only used when reporting an alias arity mismatch.
+convertType' :: Map Name S.AliasDef -> Set Name -> SourceRange -> S.Type -> TM (Map Name Int, T.Type)
+convertType' aliases extraVars sr origtype = do
+ rewritten <- rewrite origtype
+ let frees = Set.toList (extraVars <> freeVars rewritten)
+ nums <- traverse (const genId) frees
+ let mapping = Map.fromList (zip frees nums)
+ return (mapping, convert mapping rewritten)
+ where
+ -- Expand aliases: a named type whose name is an alias is replaced by the
+ -- alias body with the arguments substituted in, and the result is
+ -- rewritten again. A wrong number of arguments is a TypeArityError.
+ rewrite :: S.Type -> TM S.Type
+ rewrite (S.TFun t1 t2) = S.TFun <$> rewrite t1 <*> rewrite t2
+ rewrite S.TInt = return S.TInt
+ rewrite (S.TTup ts) = S.TTup <$> mapM rewrite ts
+ rewrite (S.TNamed n ts)
+ | Just (S.AliasDef _ args typ) <- Map.lookup n aliases =
+ if length args == length ts
+ then rewrite (subst (Map.fromList (zip (map fst args) ts)) typ)
+ else throwError (TypeArityError sr n (length args) (length ts))
+ | otherwise =
+ S.TNamed n <$> mapM rewrite ts
+ rewrite (S.TUnion ts) = S.TUnion . Set.fromList <$> mapM rewrite (Set.toList ts)
+ rewrite (S.TyVar n) = return (S.TyVar n)
+
+ -- Substitute type variables
+ subst :: Map Name S.Type -> S.Type -> S.Type
+ subst mp (S.TFun t1 t2) = S.TFun (subst mp t1) (subst mp t2)
+ subst _ S.TInt = S.TInt
+ subst mp (S.TTup ts) = S.TTup (map (subst mp) ts)
+ subst mp (S.TNamed n ts) = S.TNamed n (map (subst mp) ts)
+ subst mp (S.TUnion ts) = S.TUnion (Set.map (subst mp) ts)
+ subst mp orig@(S.TyVar n) = fromMaybe orig (Map.lookup n mp)
+
+ -- Collect the names of all type variables occurring in a type
+ freeVars :: S.Type -> Set Name
+ freeVars (S.TFun t1 t2) = freeVars t1 <> freeVars t2
+ freeVars S.TInt = mempty
+ freeVars (S.TTup ts) = Set.unions (map freeVars ts)
+ freeVars (S.TNamed _ ts) = Set.unions (map freeVars ts)
+ freeVars (S.TUnion ts) = Set.unions (map freeVars (Set.toList ts))
+ freeVars (S.TyVar n) = Set.singleton n
+
+ -- Translate constructors one-to-one; type variable names are looked up in
+ -- the mapping, which the caller above builds to contain every variable of
+ -- the rewritten type.
+ convert :: Map Name Int -> S.Type -> T.Type
+ convert mp (S.TFun t1 t2) = T.TFun (convert mp t1) (convert mp t2)
+ convert _ S.TInt = T.TInt
+ convert mp (S.TTup ts) = T.TTup (map (convert mp) ts)
+ convert mp (S.TNamed n ts) = T.TNamed n (map (convert mp) ts)
+ convert mp (S.TUnion ts) = T.TUnion (Set.map (convert mp) ts)
+ -- TODO: Should this be Rigid? I really don't know how this works.
+ convert mp (S.TyVar n) = T.TyVar T.Rigid (mp Map.! n)