aboutsummaryrefslogtreecommitdiff
path: root/typecheck
diff options
context:
space:
mode:
Diffstat (limited to 'typecheck')
-rw-r--r--typecheck/CC/Typecheck.hs182
-rw-r--r--typecheck/CC/Typecheck/Typedefs.hs50
-rw-r--r--typecheck/CC/Typecheck/Types.hs103
3 files changed, 268 insertions, 67 deletions
diff --git a/typecheck/CC/Typecheck.hs b/typecheck/CC/Typecheck.hs
index f61103e..824a714 100644
--- a/typecheck/CC/Typecheck.hs
+++ b/typecheck/CC/Typecheck.hs
@@ -1,4 +1,5 @@
{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeSynonymInstances #-}
module CC.Typecheck(runPass) where
@@ -10,59 +11,47 @@ import Data.Maybe
import qualified Data.Set as Set
import Data.Set (Set)
+import Debug.Trace
+
import qualified CC.AST.Source as S
import qualified CC.AST.Typed as T
import CC.Context
-import CC.Pretty
import CC.Types
+import CC.Typecheck.Typedefs
+import CC.Typecheck.Types
-- Inspiration: https://github.com/kritzcreek/fby19
-data TCError = TypeError SourceRange T.Type T.Type
- | RefError SourceRange Name
+data Env =
+ Env (Map Name T.TypeScheme) -- Definitions in scope
+ (Map Name T.TypeDef) -- Type definitions
+ (Map Name S.AliasDef) -- Type aliases
deriving (Show)
-instance Pretty TCError where
- pretty (TypeError sr real expect) =
- "Type error: Expression at " ++ pretty sr ++
- " has type " ++ pretty real ++
- ", but should have type " ++ pretty expect
- pretty (RefError sr name) =
- "Reference error: Variable '" ++ name ++ "' out of scope at " ++ pretty sr
-
-type TM a = ExceptT TCError (State Int) a
-
-genId :: TM Int
-genId = state (\idval -> (idval, idval + 1))
-
-genTyVar :: TM T.Type
-genTyVar = T.TyVar <$> genId
-
-runTM :: TM a -> Either TCError a
-runTM m = evalState (runExceptT m) 1
-
-
-newtype Env = Env (Map Name T.TypeScheme)
-
newtype Subst = Subst (Map Int T.Type)
class FreeTypeVars a where
- freeTypeVars :: a -> Set Int
+ -- Free instantiable type variables
+ freeInstTypeVars :: a -> Set Int
instance FreeTypeVars T.Type where
- freeTypeVars (T.TFun t1 t2) = freeTypeVars t1 <> freeTypeVars t2
- freeTypeVars T.TInt = mempty
- freeTypeVars (T.TTup ts) = Set.unions (map freeTypeVars ts)
- freeTypeVars (T.TyVar var) = Set.singleton var
+ freeInstTypeVars (T.TFun t1 t2) = freeInstTypeVars t1 <> freeInstTypeVars t2
+ freeInstTypeVars T.TInt = mempty
+ freeInstTypeVars (T.TTup ts) = Set.unions (map freeInstTypeVars ts)
+ freeInstTypeVars (T.TNamed _ ts) = Set.unions (map freeInstTypeVars ts)
+ freeInstTypeVars (T.TUnion ts) = Set.unions (map freeInstTypeVars (Set.toList ts))
+ freeInstTypeVars (T.TyVar T.Instantiable var) = Set.singleton var
+ freeInstTypeVars (T.TyVar T.Rigid _) = mempty
instance FreeTypeVars T.TypeScheme where
- freeTypeVars (T.TypeScheme bnds ty) = foldr Set.delete (freeTypeVars ty) bnds
+ freeInstTypeVars (T.TypeScheme bnds ty) =
+ foldr Set.delete (freeInstTypeVars ty) bnds
instance FreeTypeVars Env where
- freeTypeVars (Env mp) = foldMap freeTypeVars (Map.elems mp)
+ freeInstTypeVars (Env mp _ _) = foldMap freeInstTypeVars (Map.elems mp)
infixr >>!
@@ -74,14 +63,20 @@ instance Substitute T.Type where
T.TFun t1 t2 -> T.TFun (theta >>! t1) (theta >>! t2)
T.TInt -> T.TInt
T.TTup ts -> T.TTup (map (theta >>!) ts)
- T.TyVar i -> fromMaybe ty (Map.lookup i mp)
+ T.TNamed n ts -> T.TNamed n (map (theta >>!) ts)
+ T.TUnion ts -> T.TUnion (Set.map (theta >>!) ts)
+ T.TyVar T.Instantiable i -> fromMaybe ty (Map.lookup i mp)
+ T.TyVar T.Rigid i
+ | i `Map.member` mp -> error "Attempt to substitute a rigid type variable"
+ | otherwise -> ty
instance Substitute T.TypeScheme where
Subst mp >>! T.TypeScheme bnds ty =
T.TypeScheme bnds (Subst (foldr Map.delete mp bnds) >>! ty)
instance Substitute Env where
- theta >>! Env mp = Env (Map.map (theta >>!) mp)
+ theta >>! Env mp tdefs aliases =
+ Env (Map.map (theta >>!) mp) tdefs aliases
-- TODO: make this instance unnecessary
instance Substitute T.Expr where
@@ -94,6 +89,7 @@ instance Substitute T.Expr where
_ >>! expr@(T.Int _) = expr
theta >>! T.Tup es = T.Tup (map (theta >>!) es)
theta >>! T.Var (T.Occ name ty) = T.Var (T.Occ name (theta >>! ty))
+ theta >>! T.Constr ty n = T.Constr (theta >>! ty) n
instance Semigroup Subst where
@@ -103,20 +99,43 @@ instance Monoid Subst where
mempty = Subst mempty
emptyEnv :: Env
-emptyEnv = Env mempty
+emptyEnv = Env mempty mempty mempty
+
+envAddDef :: Name -> T.TypeScheme -> Env -> Env
+envAddDef name sty (Env mp tmp aliases)
+ | name `Map.member` mp = error "envAddDef on name already in environment"
+ | otherwise =
+ Env (Map.insert name sty mp) tmp aliases
+
+envFindDef :: Name -> Env -> Maybe T.TypeScheme
+envFindDef name (Env mp _ _) = Map.lookup name mp
+
+envAddTypes :: Map Name T.TypeDef -> Env -> Env
+envAddTypes l (Env mp tdefs aliases) =
+ let combined = l <> tdefs
+ in if Map.size combined == Map.size l + Map.size tdefs
+ then Env mp combined aliases
+ else error "envAddTypes on duplicate type names"
+
+envFindType :: Name -> Env -> Maybe T.TypeDef
+envFindType name (Env _ tdefs _) = Map.lookup name tdefs
-envAdd :: Name -> T.TypeScheme -> Env -> Env
-envAdd name sty (Env mp) = Env (Map.insert name sty mp)
+envAddAliases :: Map Name S.AliasDef -> Env -> Env
+envAddAliases l (Env mp tdefs aliases) =
+ let combined = l <> aliases
+ in if Map.size combined == Map.size l + Map.size aliases
+ then Env mp tdefs combined
+ else error "envAddAliaes on duplicate type names"
-envFind :: Name -> Env -> Maybe T.TypeScheme
-envFind name (Env mp) = Map.lookup name mp
+envAliases :: Env -> Map Name S.AliasDef
+envAliases (Env _ _ aliases) = aliases
substVar :: Int -> T.Type -> Subst
substVar var ty = Subst (Map.singleton var ty)
generalise :: Env -> T.Type -> T.TypeScheme
generalise env ty =
- T.TypeScheme (Set.toList (freeTypeVars ty Set.\\ freeTypeVars env)) ty
+ T.TypeScheme (Set.toList (freeInstTypeVars ty Set.\\ freeInstTypeVars env)) ty
instantiate :: T.TypeScheme -> TM T.Type
instantiate (T.TypeScheme bnds ty) = do
@@ -124,6 +143,9 @@ instantiate (T.TypeScheme bnds ty) = do
let theta = Subst (Map.fromList (zip bnds vars))
return (theta >>! ty)
+freshenFrees :: Env -> T.Type -> TM T.Type
+freshenFrees env = instantiate . generalise env
+
data UnifyContext = UnifyContext SourceRange T.Type T.Type
unify :: SourceRange -> T.Type -> T.Type -> TM Subst
@@ -131,25 +153,22 @@ unify sr t1 t2 = unify' (UnifyContext sr t1 t2) t1 t2
unify' :: UnifyContext -> T.Type -> T.Type -> TM Subst
unify' _ T.TInt T.TInt = return mempty
-unify' ctx (T.TFun t1 t2) (T.TFun u1 u2) = (<>) <$> unify' ctx t1 u1 <*> unify' ctx t2 u2
+unify' ctx (T.TFun t1 t2) (T.TFun u1 u2) =
+ (<>) <$> unify' ctx t1 u1 <*> unify' ctx t2 u2
unify' ctx (T.TTup ts) (T.TTup us)
| length ts == length us = mconcat <$> zipWithM (unify' ctx) ts us
-unify' _ (T.TyVar var) ty = return (substVar var ty)
-unify' _ ty (T.TyVar var) = return (substVar var ty)
+unify' _ (T.TyVar T.Instantiable var) ty = return (substVar var ty)
+unify' _ ty (T.TyVar T.Instantiable var) = return (substVar var ty)
+-- TODO: fix unify
unify' (UnifyContext sr t1 t2) _ _ = throwError (TypeError sr t1 t2)
-convertType :: S.Type -> T.Type
-convertType (S.TFun t1 t2) = T.TFun (convertType t1) (convertType t2)
-convertType S.TInt = T.TInt
-convertType (S.TTup ts) = T.TTup (map convertType ts)
-
infer :: Env -> S.Expr -> TM (Subst, T.Expr)
infer env expr = case expr of
S.Lam _ [] body -> infer env body
S.Lam sr args@(_:_:_) body -> infer env (foldr (S.Lam sr . pure) body args)
S.Lam _ [(arg, _)] body -> do
argVar <- genTyVar
- let augEnv = envAdd arg (T.TypeScheme [] argVar) env
+ let augEnv = envAddDef arg (T.TypeScheme [] argVar) env
(theta, body') <- infer augEnv body
let argType = theta >>! argVar
return (theta, T.Lam (T.TFun argType (T.exprType body'))
@@ -157,7 +176,7 @@ infer env expr = case expr of
S.Let _ (name, _) rhs body -> do
(theta1, rhs') <- infer env rhs
let varType = T.exprType rhs'
- let augEnv = envAdd name (T.TypeScheme [] varType) env
+ let augEnv = envAddDef name (T.TypeScheme [] varType) env
(theta2, body') <- infer augEnv body
return (theta2 <> theta1, T.Let (T.Occ name varType) rhs' body')
S.Call sr func arg -> do
@@ -173,14 +192,22 @@ infer env expr = case expr of
S.Int _ val -> return (mempty, T.Int val)
S.Tup _ es -> fmap T.Tup <$> inferList env es
S.Var sr name
- | Just sty <- envFind name env -> do
+ | Just sty <- envFindDef name env -> do
ty <- instantiate sty
return (mempty, T.Var (T.Occ name ty))
| otherwise ->
throwError (RefError sr name)
+ S.Constr sr name -> case envFindType name env of
+ Just (T.TypeDef typname params typ) -> do
+ restyp <- freshenFrees emptyEnv
+ (T.TNamed typname (map (T.TyVar T.Instantiable) params))
+ return (mempty, T.Constr (T.TFun typ restyp) name)
+ _ ->
+ throwError (RefError sr name)
S.Annot sr subex ty -> do
(theta1, subex') <- infer env subex
- theta2 <- unify sr (T.exprType subex') (convertType ty)
+ ty' <- convertType (envAliases env) sr ty
+ theta2 <- unify sr (T.exprType subex') ty'
return (theta2 <> theta1, theta2 >>! subex') -- TODO: quadratic complexity
inferList :: Env -> [S.Expr] -> TM (Subst, [T.Expr])
@@ -192,23 +219,44 @@ inferList env (expr : exprs) = do
runPass :: Context -> S.Program -> Either TCError T.Program
-runPass (Context _ (Builtins builtins)) prog =
- let env = Env (Map.map (generalise emptyEnv) builtins)
+runPass (Context _ (Builtins builtins _)) prog =
+ let env = Env (Map.map (generalise emptyEnv) builtins) mempty mempty
in runTM (typeCheck env prog)
typeCheck :: Env -> S.Program -> TM T.Program
-typeCheck startEnv (S.Program decls) =
- let defs = [(name, ty)
- | S.Def (S.Function (Just ty) (name, _) _ _) <- decls]
- env = foldl (\env' (name, ty) -> envAdd name (generalise env' (convertType ty)) env')
- startEnv defs
- in T.Program <$> mapM (typeCheckDef env . (\(S.Def def) -> def)) decls
-
-typeCheckDef :: Env -> S.Def -> TM T.Def
-typeCheckDef env (S.Function mannot (name, sr) args@(_:_) body) =
- typeCheckDef env (S.Function mannot (name, sr) [] (S.Lam sr args body))
-typeCheckDef env (S.Function (Just annot) (name, sr) [] body) =
- typeCheckDef env (S.Function Nothing (name, sr) [] (S.Annot sr body annot))
-typeCheckDef env (S.Function Nothing (name, _) [] body) = do
+typeCheck startEnv (S.Program decls) = do
+ traceM (show decls)
+
+ let aliasdefs = [(n, def)
+ | S.DeclAlias def@(S.AliasDef (n, _) _ _) <- decls]
+ env1 = envAddAliases (Map.fromList aliasdefs) startEnv
+
+ typedefs' <- checkTypedefs (envAliases env1) [def | S.DeclType def <- decls]
+ let typedefsMap = Map.fromList [(n, def) | def@(T.TypeDef n _ _) <- typedefs']
+
+ let funcdefs = [def | S.DeclFunc def <- decls]
+ typedfuncs <- sequence
+ [(name,) <$> convertType (envAliases env1) sr ty
+ | S.FuncDef (Just ty) (name, sr) _ _ <- funcdefs]
+
+ let env2 = envAddTypes typedefsMap env1
+
+ traceM (show typedefsMap)
+
+ let env = foldl (\env' (name, ty) ->
+ envAddDef name (generalise env' ty) env')
+ env2 typedfuncs
+
+ traceM (show env)
+
+ funcdefs' <- mapM (typeCheckFunc env) funcdefs
+ return (T.Program funcdefs' typedefsMap)
+
+typeCheckFunc :: Env -> S.FuncDef -> TM T.Def
+typeCheckFunc env (S.FuncDef mannot (name, sr) args@(_:_) body) =
+ typeCheckFunc env (S.FuncDef mannot (name, sr) [] (S.Lam sr args body))
+typeCheckFunc env (S.FuncDef (Just annot) (name, sr) [] body) =
+ typeCheckFunc env (S.FuncDef Nothing (name, sr) [] (S.Annot sr body annot))
+typeCheckFunc env (S.FuncDef Nothing (name, _) [] body) = do
(_, body') <- infer env body
return (T.Def name body')
diff --git a/typecheck/CC/Typecheck/Typedefs.hs b/typecheck/CC/Typecheck/Typedefs.hs
new file mode 100644
index 0000000..ad9bdd8
--- /dev/null
+++ b/typecheck/CC/Typecheck/Typedefs.hs
@@ -0,0 +1,50 @@
+module CC.Typecheck.Typedefs(checkTypedefs) where
+
+import Control.Monad.Except
+import Data.Foldable (traverse_)
+import qualified Data.Map.Strict as Map
+import Data.Map.Strict (Map)
+import qualified Data.Set as Set
+
+import qualified CC.AST.Source as S
+import qualified CC.AST.Typed as T
+import CC.Typecheck.Types
+import CC.Types
+
+
+checkArity :: Map Name Int -> S.TypeDef -> TM ()
+checkArity typeArity (S.TypeDef (_, sr) args ty) = go ty
+ where
+ argNames = map fst args -- probably a small list
+
+ go :: S.Type -> TM ()
+ go (S.TFun t1 t2) = go t1 >> go t2
+ go S.TInt = return ()
+ go (S.TTup ts) = mapM_ go ts
+ go (S.TNamed n ts)
+ | Just arity <- Map.lookup n typeArity =
+ if length ts == arity
+ then mapM_ go ts
+ else throwError (TypeArityError sr n arity (length ts))
+ | otherwise = throwError (RefError sr n)
+ go (S.TUnion ts) = traverse_ go ts
+ go (S.TyVar n)
+ | n `elem` argNames = return ()
+ | otherwise = throwError (RefError sr n)
+
+checkTypedefs :: Map Name S.AliasDef -> [S.TypeDef] -> TM [T.TypeDef]
+checkTypedefs aliases origdefs = do
+ let aliasArity = Map.map (\(S.AliasDef _ args _) -> length args) aliases
+ typeArity = Map.fromList [(n, length args)
+ | S.TypeDef (n, _) args _ <- origdefs]
+
+ let dups = Set.fromList (map (\(S.TypeDef (n, _) _ _) -> n) origdefs)
+ Set.\\ Map.keysSet typeArity
+ when (not (Set.null dups)) $
+ throwError (DupTypeError (Set.findMin dups))
+
+ let aliasdefs = [S.TypeDef name args typ
+ | S.AliasDef name args typ <- Map.elems aliases]
+
+ mapM_ (checkArity (aliasArity <> typeArity)) (aliasdefs ++ origdefs)
+ mapM (convertTypeDef aliases) origdefs
diff --git a/typecheck/CC/Typecheck/Types.hs b/typecheck/CC/Typecheck/Types.hs
new file mode 100644
index 0000000..3f3c471
--- /dev/null
+++ b/typecheck/CC/Typecheck/Types.hs
@@ -0,0 +1,103 @@
+module CC.Typecheck.Types where
+
+import Control.Monad.State.Strict
+import Control.Monad.Except
+import qualified Data.Map.Strict as Map
+import Data.Map.Strict (Map)
+import Data.Maybe (fromMaybe)
+import qualified Data.Set as Set
+import Data.Set (Set)
+
+import qualified CC.AST.Source as S
+import qualified CC.AST.Typed as T
+import CC.Pretty
+import CC.Types
+
+
+data TCError = TypeError SourceRange T.Type T.Type
+ | RefError SourceRange Name
+ | TypeArityError SourceRange Name Int Int
+ | DupTypeError Name
+ deriving (Show)
+
+instance Pretty TCError where
+ pretty (TypeError sr real expect) =
+ "Type error: Expression at " ++ pretty sr ++
+ " has type " ++ pretty real ++
+ ", but should have type " ++ pretty expect
+ pretty (RefError sr name) =
+ "Reference error: Variable '" ++ name ++ "' out of scope at " ++ pretty sr
+ pretty (TypeArityError sr name wanted got) =
+ "Type error: Type '" ++ name ++ "' has arity " ++ show wanted ++
+ " but gets " ++ show got ++ " type arguments at " ++ pretty sr
+ pretty (DupTypeError name) =
+ "Duplicate types: Type '" ++ name ++ "' defined multiple times"
+
+type TM a = ExceptT TCError (State Int) a
+
+genId :: TM Int
+genId = state (\idval -> (idval, idval + 1))
+
+genTyVar :: TM T.Type
+genTyVar = T.TyVar T.Instantiable <$> genId
+
+runTM :: TM a -> Either TCError a
+runTM m = evalState (runExceptT m) 1
+
+
+convertType :: Map Name S.AliasDef -> SourceRange -> S.Type -> TM T.Type
+convertType aliases sr = fmap snd . convertType' aliases mempty sr
+
+convertTypeDef :: Map Name S.AliasDef -> S.TypeDef -> TM T.TypeDef
+convertTypeDef aliases (S.TypeDef (name, sr) args ty) = do
+ (mapping, ty') <- convertType' aliases (Set.fromList (map fst args)) sr ty
+ let args' = [mapping Map.! n | (n, _) <- args]
+ return (T.TypeDef name args' ty')
+
+convertType' :: Map Name S.AliasDef -> Set Name -> SourceRange -> S.Type -> TM (Map Name Int, T.Type)
+convertType' aliases extraVars sr origtype = do
+ rewritten <- rewrite origtype
+ let frees = Set.toList (extraVars <> freeVars rewritten)
+ nums <- traverse (const genId) frees
+ let mapping = Map.fromList (zip frees nums)
+ return (mapping, convert mapping rewritten)
+ where
+ rewrite :: S.Type -> TM S.Type
+ rewrite (S.TFun t1 t2) = S.TFun <$> rewrite t1 <*> rewrite t2
+ rewrite S.TInt = return S.TInt
+ rewrite (S.TTup ts) = S.TTup <$> mapM rewrite ts
+ rewrite (S.TNamed n ts)
+ | Just (S.AliasDef _ args typ) <- Map.lookup n aliases =
+ if length args == length ts
+ then rewrite (subst (Map.fromList (zip (map fst args) ts)) typ)
+ else throwError (TypeArityError sr n (length args) (length ts))
+ | otherwise =
+ S.TNamed n <$> mapM rewrite ts
+ rewrite (S.TUnion ts) = S.TUnion . Set.fromList <$> mapM rewrite (Set.toList ts)
+ rewrite (S.TyVar n) = return (S.TyVar n)
+
+ -- Substitute type variables
+ subst :: Map Name S.Type -> S.Type -> S.Type
+ subst mp (S.TFun t1 t2) = S.TFun (subst mp t1) (subst mp t2)
+ subst _ S.TInt = S.TInt
+ subst mp (S.TTup ts) = S.TTup (map (subst mp) ts)
+ subst mp (S.TNamed n ts) = S.TNamed n (map (subst mp) ts)
+ subst mp (S.TUnion ts) = S.TUnion (Set.map (subst mp) ts)
+ subst mp orig@(S.TyVar n) = fromMaybe orig (Map.lookup n mp)
+
+ freeVars :: S.Type -> Set Name
+ freeVars (S.TFun t1 t2) = freeVars t1 <> freeVars t2
+ freeVars S.TInt = mempty
+ freeVars (S.TTup ts) = Set.unions (map freeVars ts)
+ freeVars (S.TNamed _ ts) = Set.unions (map freeVars ts)
+ freeVars (S.TUnion ts) = Set.unions (map freeVars (Set.toList ts))
+ freeVars (S.TyVar n) = Set.singleton n
+
+ convert :: Map Name Int -> S.Type -> T.Type
+ convert mp (S.TFun t1 t2) = T.TFun (convert mp t1) (convert mp t2)
+ convert _ S.TInt = T.TInt
+ convert mp (S.TTup ts) = T.TTup (map (convert mp) ts)
+ convert mp (S.TNamed n ts) = T.TNamed n (map (convert mp) ts)
+ convert mp (S.TUnion ts) = T.TUnion (Set.map (convert mp) ts)
+ -- TODO: Should this be Rigid? I really don't know how this works.
+ convert mp (S.TyVar n) = T.TyVar T.Rigid (mp Map.! n)