diff options
Diffstat (limited to 'typecheck')
-rw-r--r-- | typecheck/CC/Typecheck.hs | 182 | ||||
-rw-r--r-- | typecheck/CC/Typecheck/Typedefs.hs | 50 | ||||
-rw-r--r-- | typecheck/CC/Typecheck/Types.hs | 103 |
3 files changed, 268 insertions, 67 deletions
diff --git a/typecheck/CC/Typecheck.hs b/typecheck/CC/Typecheck.hs index f61103e..824a714 100644 --- a/typecheck/CC/Typecheck.hs +++ b/typecheck/CC/Typecheck.hs @@ -1,4 +1,5 @@ {-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeSynonymInstances #-} module CC.Typecheck(runPass) where @@ -10,59 +11,47 @@ import Data.Maybe import qualified Data.Set as Set import Data.Set (Set) +import Debug.Trace + import qualified CC.AST.Source as S import qualified CC.AST.Typed as T import CC.Context -import CC.Pretty import CC.Types +import CC.Typecheck.Typedefs +import CC.Typecheck.Types -- Inspiration: https://github.com/kritzcreek/fby19 -data TCError = TypeError SourceRange T.Type T.Type - | RefError SourceRange Name +data Env = + Env (Map Name T.TypeScheme) -- Definitions in scope + (Map Name T.TypeDef) -- Type definitions + (Map Name S.AliasDef) -- Type aliases deriving (Show) -instance Pretty TCError where - pretty (TypeError sr real expect) = - "Type error: Expression at " ++ pretty sr ++ - " has type " ++ pretty real ++ - ", but should have type " ++ pretty expect - pretty (RefError sr name) = - "Reference error: Variable '" ++ name ++ "' out of scope at " ++ pretty sr - -type TM a = ExceptT TCError (State Int) a - -genId :: TM Int -genId = state (\idval -> (idval, idval + 1)) - -genTyVar :: TM T.Type -genTyVar = T.TyVar <$> genId - -runTM :: TM a -> Either TCError a -runTM m = evalState (runExceptT m) 1 - - -newtype Env = Env (Map Name T.TypeScheme) - newtype Subst = Subst (Map Int T.Type) class FreeTypeVars a where - freeTypeVars :: a -> Set Int + -- Free instantiable type variables + freeInstTypeVars :: a -> Set Int instance FreeTypeVars T.Type where - freeTypeVars (T.TFun t1 t2) = freeTypeVars t1 <> freeTypeVars t2 - freeTypeVars T.TInt = mempty - freeTypeVars (T.TTup ts) = Set.unions (map freeTypeVars ts) - freeTypeVars (T.TyVar var) = Set.singleton var + freeInstTypeVars (T.TFun t1 t2) = freeInstTypeVars t1 <> freeInstTypeVars t2 + freeInstTypeVars T.TInt = mempty + freeInstTypeVars (T.TTup ts) = Set.unions (map freeInstTypeVars ts) + freeInstTypeVars (T.TNamed _ ts) = Set.unions (map freeInstTypeVars ts) + freeInstTypeVars (T.TUnion ts) = Set.unions (map freeInstTypeVars (Set.toList ts)) + freeInstTypeVars (T.TyVar T.Instantiable var) = Set.singleton var + freeInstTypeVars (T.TyVar T.Rigid _) = mempty instance FreeTypeVars T.TypeScheme where - freeTypeVars (T.TypeScheme bnds ty) = foldr Set.delete (freeTypeVars ty) bnds + freeInstTypeVars (T.TypeScheme bnds ty) = + foldr Set.delete (freeInstTypeVars ty) bnds instance FreeTypeVars Env where - freeTypeVars (Env mp) = foldMap freeTypeVars (Map.elems mp) + freeInstTypeVars (Env mp _ _) = foldMap freeInstTypeVars (Map.elems mp) infixr >>! @@ -74,14 +63,20 @@ instance Substitute T.Type where T.TFun t1 t2 -> T.TFun (theta >>! t1) (theta >>! t2) T.TInt -> T.TInt T.TTup ts -> T.TTup (map (theta >>!) ts) - T.TyVar i -> fromMaybe ty (Map.lookup i mp) + T.TNamed n ts -> T.TNamed n (map (theta >>!) ts) + T.TUnion ts -> T.TUnion (Set.map (theta >>!) ts) + T.TyVar T.Instantiable i -> fromMaybe ty (Map.lookup i mp) + T.TyVar T.Rigid i + | i `Map.member` mp -> error "Attempt to substitute a rigid type variable" + | otherwise -> ty instance Substitute T.TypeScheme where Subst mp >>! T.TypeScheme bnds ty = T.TypeScheme bnds (Subst (foldr Map.delete mp bnds) >>! ty) instance Substitute Env where - theta >>! Env mp = Env (Map.map (theta >>!) mp) + theta >>! Env mp tdefs aliases = + Env (Map.map (theta >>!) mp) tdefs aliases -- TODO: make this instance unnecessary instance Substitute T.Expr where @@ -94,6 +89,7 @@ instance Substitute T.Expr where _ >>! expr@(T.Int _) = expr theta >>! T.Tup es = T.Tup (map (theta >>!) es) theta >>! T.Var (T.Occ name ty) = T.Var (T.Occ name (theta >>! ty)) + theta >>! T.Constr ty n = T.Constr (theta >>! ty) n instance Semigroup Subst where @@ -103,20 +99,43 @@ instance Monoid Subst where mempty = Subst mempty emptyEnv :: Env -emptyEnv = Env mempty +emptyEnv = Env mempty mempty mempty + +envAddDef :: Name -> T.TypeScheme -> Env -> Env +envAddDef name sty (Env mp tmp aliases) + | name `Map.member` mp = error "envAddDef on name already in environment" + | otherwise = + Env (Map.insert name sty mp) tmp aliases + +envFindDef :: Name -> Env -> Maybe T.TypeScheme +envFindDef name (Env mp _ _) = Map.lookup name mp + +envAddTypes :: Map Name T.TypeDef -> Env -> Env +envAddTypes l (Env mp tdefs aliases) = + let combined = l <> tdefs + in if Map.size combined == Map.size l + Map.size tdefs + then Env mp combined aliases + else error "envAddTypes on duplicate type names" + +envFindType :: Name -> Env -> Maybe T.TypeDef +envFindType name (Env _ tdefs _) = Map.lookup name tdefs -envAdd :: Name -> T.TypeScheme -> Env -> Env -envAdd name sty (Env mp) = Env (Map.insert name sty mp) +envAddAliases :: Map Name S.AliasDef -> Env -> Env +envAddAliases l (Env mp tdefs aliases) = + let combined = l <> aliases + in if Map.size combined == Map.size l + Map.size aliases + then Env mp tdefs combined + else error "envAddAliaes on duplicate type names" -envFind :: Name -> Env -> Maybe T.TypeScheme -envFind name (Env mp) = Map.lookup name mp +envAliases :: Env -> Map Name S.AliasDef +envAliases (Env _ _ aliases) = aliases substVar :: Int -> T.Type -> Subst substVar var ty = Subst (Map.singleton var ty) generalise :: Env -> T.Type -> T.TypeScheme generalise env ty = - T.TypeScheme (Set.toList (freeTypeVars ty Set.\\ freeTypeVars env)) ty + T.TypeScheme (Set.toList (freeInstTypeVars ty Set.\\ freeInstTypeVars env)) ty instantiate :: T.TypeScheme -> TM T.Type instantiate (T.TypeScheme bnds ty) = do @@ -124,6 +143,9 @@ instantiate (T.TypeScheme bnds ty) = do let theta = Subst (Map.fromList (zip bnds vars)) return (theta >>! ty) +freshenFrees :: Env -> T.Type -> TM T.Type +freshenFrees env = instantiate . generalise env + data UnifyContext = UnifyContext SourceRange T.Type T.Type unify :: SourceRange -> T.Type -> T.Type -> TM Subst @@ -131,25 +153,22 @@ unify sr t1 t2 = unify' (UnifyContext sr t1 t2) t1 t2 unify' :: UnifyContext -> T.Type -> T.Type -> TM Subst unify' _ T.TInt T.TInt = return mempty -unify' ctx (T.TFun t1 t2) (T.TFun u1 u2) = (<>) <$> unify' ctx t1 u1 <*> unify' ctx t2 u2 +unify' ctx (T.TFun t1 t2) (T.TFun u1 u2) = + (<>) <$> unify' ctx t1 u1 <*> unify' ctx t2 u2 unify' ctx (T.TTup ts) (T.TTup us) | length ts == length us = mconcat <$> zipWithM (unify' ctx) ts us -unify' _ (T.TyVar var) ty = return (substVar var ty) -unify' _ ty (T.TyVar var) = return (substVar var ty) +unify' _ (T.TyVar T.Instantiable var) ty = return (substVar var ty) +unify' _ ty (T.TyVar T.Instantiable var) = return (substVar var ty) +-- TODO: fix unify unify' (UnifyContext sr t1 t2) _ _ = throwError (TypeError sr t1 t2) -convertType :: S.Type -> T.Type -convertType (S.TFun t1 t2) = T.TFun (convertType t1) (convertType t2) -convertType S.TInt = T.TInt -convertType (S.TTup ts) = T.TTup (map convertType ts) - infer :: Env -> S.Expr -> TM (Subst, T.Expr) infer env expr = case expr of S.Lam _ [] body -> infer env body S.Lam sr args@(_:_:_) body -> infer env (foldr (S.Lam sr . pure) body args) S.Lam _ [(arg, _)] body -> do argVar <- genTyVar - let augEnv = envAdd arg (T.TypeScheme [] argVar) env + let augEnv = envAddDef arg (T.TypeScheme [] argVar) env (theta, body') <- infer augEnv body let argType = theta >>! argVar return (theta, T.Lam (T.TFun argType (T.exprType body')) @@ -157,7 +176,7 @@ infer env expr = case expr of S.Let _ (name, _) rhs body -> do (theta1, rhs') <- infer env rhs let varType = T.exprType rhs' - let augEnv = envAdd name (T.TypeScheme [] varType) env + let augEnv = envAddDef name (T.TypeScheme [] varType) env (theta2, body') <- infer augEnv body return (theta2 <> theta1, T.Let (T.Occ name varType) rhs' body') S.Call sr func arg -> do @@ -173,14 +192,22 @@ infer env expr = case expr of S.Int _ val -> return (mempty, T.Int val) S.Tup _ es -> fmap T.Tup <$> inferList env es S.Var sr name - | Just sty <- envFind name env -> do + | Just sty <- envFindDef name env -> do ty <- instantiate sty return (mempty, T.Var (T.Occ name ty)) | otherwise -> throwError (RefError sr name) + S.Constr sr name -> case envFindType name env of + Just (T.TypeDef typname params typ) -> do + restyp <- freshenFrees emptyEnv + (T.TNamed typname (map (T.TyVar T.Instantiable) params)) + return (mempty, T.Constr (T.TFun typ restyp) name) + _ -> + throwError (RefError sr name) S.Annot sr subex ty -> do (theta1, subex') <- infer env subex - theta2 <- unify sr (T.exprType subex') (convertType ty) + ty' <- convertType (envAliases env) sr ty + theta2 <- unify sr (T.exprType subex') ty' return (theta2 <> theta1, theta2 >>! subex') -- TODO: quadratic complexity inferList :: Env -> [S.Expr] -> TM (Subst, [T.Expr]) @@ -192,23 +219,44 @@ inferList env (expr : exprs) = do runPass :: Context -> S.Program -> Either TCError T.Program -runPass (Context _ (Builtins builtins)) prog = - let env = Env (Map.map (generalise emptyEnv) builtins) +runPass (Context _ (Builtins builtins _)) prog = + let env = Env (Map.map (generalise emptyEnv) builtins) mempty mempty in runTM (typeCheck env prog) typeCheck :: Env -> S.Program -> TM T.Program -typeCheck startEnv (S.Program decls) = - let defs = [(name, ty) - | S.Def (S.Function (Just ty) (name, _) _ _) <- decls] - env = foldl (\env' (name, ty) -> envAdd name (generalise env' (convertType ty)) env') - startEnv defs - in T.Program <$> mapM (typeCheckDef env . (\(S.Def def) -> def)) decls - -typeCheckDef :: Env -> S.Def -> TM T.Def -typeCheckDef env (S.Function mannot (name, sr) args@(_:_) body) = - typeCheckDef env (S.Function mannot (name, sr) [] (S.Lam sr args body)) -typeCheckDef env (S.Function (Just annot) (name, sr) [] body) = - typeCheckDef env (S.Function Nothing (name, sr) [] (S.Annot sr body annot)) -typeCheckDef env (S.Function Nothing (name, _) [] body) = do +typeCheck startEnv (S.Program decls) = do + traceM (show decls) + + let aliasdefs = [(n, def) + | S.DeclAlias def@(S.AliasDef (n, _) _ _) <- decls] + env1 = envAddAliases (Map.fromList aliasdefs) startEnv + + typedefs' <- checkTypedefs (envAliases env1) [def | S.DeclType def <- decls] + let typedefsMap = Map.fromList [(n, def) | def@(T.TypeDef n _ _) <- typedefs'] + + let funcdefs = [def | S.DeclFunc def <- decls] + typedfuncs <- sequence + [(name,) <$> convertType (envAliases env1) sr ty + | S.FuncDef (Just ty) (name, sr) _ _ <- funcdefs] + + let env2 = envAddTypes typedefsMap env1 + + traceM (show typedefsMap) + + let env = foldl (\env' (name, ty) -> + envAddDef name (generalise env' ty) env') + env2 typedfuncs + + traceM (show env) + + funcdefs' <- mapM (typeCheckFunc env) funcdefs + return (T.Program funcdefs' typedefsMap) + +typeCheckFunc :: Env -> S.FuncDef -> TM T.Def +typeCheckFunc env (S.FuncDef mannot (name, sr) args@(_:_) body) = + typeCheckFunc env (S.FuncDef mannot (name, sr) [] (S.Lam sr args body)) +typeCheckFunc env (S.FuncDef (Just annot) (name, sr) [] body) = + typeCheckFunc env (S.FuncDef Nothing (name, sr) [] (S.Annot sr body annot)) +typeCheckFunc env (S.FuncDef Nothing (name, _) [] body) = do (_, body') <- infer env body return (T.Def name body') diff --git a/typecheck/CC/Typecheck/Typedefs.hs b/typecheck/CC/Typecheck/Typedefs.hs new file mode 100644 index 0000000..ad9bdd8 --- /dev/null +++ b/typecheck/CC/Typecheck/Typedefs.hs @@ -0,0 +1,50 @@ +module CC.Typecheck.Typedefs(checkTypedefs) where + +import Control.Monad.Except +import Data.Foldable (traverse_) +import qualified Data.Map.Strict as Map +import Data.Map.Strict (Map) +import qualified Data.Set as Set + +import qualified CC.AST.Source as S +import qualified CC.AST.Typed as T +import CC.Typecheck.Types +import CC.Types + + +checkArity :: Map Name Int -> S.TypeDef -> TM () +checkArity typeArity (S.TypeDef (_, sr) args ty) = go ty + where + argNames = map fst args -- probably a small list + + go :: S.Type -> TM () + go (S.TFun t1 t2) = go t1 >> go t2 + go S.TInt = return () + go (S.TTup ts) = mapM_ go ts + go (S.TNamed n ts) + | Just arity <- Map.lookup n typeArity = + if length ts == arity + then mapM_ go ts + else throwError (TypeArityError sr n arity (length ts)) + | otherwise = throwError (RefError sr n) + go (S.TUnion ts) = traverse_ go ts + go (S.TyVar n) + | n `elem` argNames = return () + | otherwise = throwError (RefError sr n) + +checkTypedefs :: Map Name S.AliasDef -> [S.TypeDef] -> TM [T.TypeDef] +checkTypedefs aliases origdefs = do + let aliasArity = Map.map (\(S.AliasDef _ args _) -> length args) aliases + typeArity = Map.fromList [(n, length args) + | S.TypeDef (n, _) args _ <- origdefs] + + let dups = Set.fromList (map (\(S.TypeDef (n, _) _ _) -> n) origdefs) + Set.\\ Map.keysSet typeArity + when (not (Set.null dups)) $ + throwError (DupTypeError (Set.findMin dups)) + + let aliasdefs = [S.TypeDef name args typ + | S.AliasDef name args typ <- Map.elems aliases] + + mapM_ (checkArity (aliasArity <> typeArity)) (aliasdefs ++ origdefs) + mapM (convertTypeDef aliases) origdefs diff --git a/typecheck/CC/Typecheck/Types.hs b/typecheck/CC/Typecheck/Types.hs new file mode 100644 index 0000000..3f3c471 --- /dev/null +++ b/typecheck/CC/Typecheck/Types.hs @@ -0,0 +1,103 @@ +module CC.Typecheck.Types where + +import Control.Monad.State.Strict +import Control.Monad.Except +import qualified Data.Map.Strict as Map +import Data.Map.Strict (Map) +import Data.Maybe (fromMaybe) +import qualified Data.Set as Set +import Data.Set (Set) + +import qualified CC.AST.Source as S +import qualified CC.AST.Typed as T +import CC.Pretty +import CC.Types + + +data TCError = TypeError SourceRange T.Type T.Type + | RefError SourceRange Name + | TypeArityError SourceRange Name Int Int + | DupTypeError Name + deriving (Show) + +instance Pretty TCError where + pretty (TypeError sr real expect) = + "Type error: Expression at " ++ pretty sr ++ + " has type " ++ pretty real ++ + ", but should have type " ++ pretty expect + pretty (RefError sr name) = + "Reference error: Variable '" ++ name ++ "' out of scope at " ++ pretty sr + pretty (TypeArityError sr name wanted got) = + "Type error: Type '" ++ name ++ "' has arity " ++ show wanted ++ + " but gets " ++ show got ++ " type arguments at " ++ pretty sr + pretty (DupTypeError name) = + "Duplicate types: Type '" ++ name ++ "' defined multiple times" + +type TM a = ExceptT TCError (State Int) a + +genId :: TM Int +genId = state (\idval -> (idval, idval + 1)) + +genTyVar :: TM T.Type +genTyVar = T.TyVar T.Instantiable <$> genId + +runTM :: TM a -> Either TCError a +runTM m = evalState (runExceptT m) 1 + + +convertType :: Map Name S.AliasDef -> SourceRange -> S.Type -> TM T.Type +convertType aliases sr = fmap snd . convertType' aliases mempty sr + +convertTypeDef :: Map Name S.AliasDef -> S.TypeDef -> TM T.TypeDef +convertTypeDef aliases (S.TypeDef (name, sr) args ty) = do + (mapping, ty') <- convertType' aliases (Set.fromList (map fst args)) sr ty + let args' = [mapping Map.! n | (n, _) <- args] + return (T.TypeDef name args' ty') + +convertType' :: Map Name S.AliasDef -> Set Name -> SourceRange -> S.Type -> TM (Map Name Int, T.Type) +convertType' aliases extraVars sr origtype = do + rewritten <- rewrite origtype + let frees = Set.toList (extraVars <> freeVars rewritten) + nums <- traverse (const genId) frees + let mapping = Map.fromList (zip frees nums) + return (mapping, convert mapping rewritten) + where + rewrite :: S.Type -> TM S.Type + rewrite (S.TFun t1 t2) = S.TFun <$> rewrite t1 <*> rewrite t2 + rewrite S.TInt = return S.TInt + rewrite (S.TTup ts) = S.TTup <$> mapM rewrite ts + rewrite (S.TNamed n ts) + | Just (S.AliasDef _ args typ) <- Map.lookup n aliases = + if length args == length ts + then rewrite (subst (Map.fromList (zip (map fst args) ts)) typ) + else throwError (TypeArityError sr n (length args) (length ts)) + | otherwise = + S.TNamed n <$> mapM rewrite ts + rewrite (S.TUnion ts) = S.TUnion . Set.fromList <$> mapM rewrite (Set.toList ts) + rewrite (S.TyVar n) = return (S.TyVar n) + + -- Substitute type variables + subst :: Map Name S.Type -> S.Type -> S.Type + subst mp (S.TFun t1 t2) = S.TFun (subst mp t1) (subst mp t2) + subst _ S.TInt = S.TInt + subst mp (S.TTup ts) = S.TTup (map (subst mp) ts) + subst mp (S.TNamed n ts) = S.TNamed n (map (subst mp) ts) + subst mp (S.TUnion ts) = S.TUnion (Set.map (subst mp) ts) + subst mp orig@(S.TyVar n) = fromMaybe orig (Map.lookup n mp) + + freeVars :: S.Type -> Set Name + freeVars (S.TFun t1 t2) = freeVars t1 <> freeVars t2 + freeVars S.TInt = mempty + freeVars (S.TTup ts) = Set.unions (map freeVars ts) + freeVars (S.TNamed _ ts) = Set.unions (map freeVars ts) + freeVars (S.TUnion ts) = Set.unions (map freeVars (Set.toList ts)) + freeVars (S.TyVar n) = Set.singleton n + + convert :: Map Name Int -> S.Type -> T.Type + convert mp (S.TFun t1 t2) = T.TFun (convert mp t1) (convert mp t2) + convert _ S.TInt = T.TInt + convert mp (S.TTup ts) = T.TTup (map (convert mp) ts) + convert mp (S.TNamed n ts) = T.TNamed n (map (convert mp) ts) + convert mp (S.TUnion ts) = T.TUnion (Set.map (convert mp) ts) + -- TODO: Should this be Rigid? I really don't know how this works. + convert mp (S.TyVar n) = T.TyVar T.Rigid (mp Map.! n) |