diff options
author | Tom Smeding <tom.smeding@gmail.com> | 2020-07-24 18:34:28 +0200 |
---|---|---|
committer | Tom Smeding <tom.smeding@gmail.com> | 2020-07-24 19:02:21 +0200 |
commit | 5cb19c0df3f838965c1100731f9118f28f50796f (patch) | |
tree | 89a2cbfeed5de5f9a26238321610c3fd58e533b9 | |
parent | a9134688a2132c8f9abfff206f6e30614bb9aeff (diff) |
Working basic type checker using Algorithm W
-rw-r--r-- | typecheck/CC/Typecheck.hs | 219 | ||||
-rw-r--r-- | utils/CC/IdSupply.hs | 29 |
2 files changed, 162 insertions, 86 deletions
diff --git a/typecheck/CC/Typecheck.hs b/typecheck/CC/Typecheck.hs index 47f42e3..6c9a22b 100644 --- a/typecheck/CC/Typecheck.hs +++ b/typecheck/CC/Typecheck.hs @@ -1,83 +1,188 @@ +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE TypeSynonymInstances #-} module CC.Typecheck(runPass) where import Control.Monad.State.Strict -import Data.List (intersect) +import Control.Monad.Except import qualified Data.Map.Strict as Map +import Data.Map.Strict (Map) +import Data.Maybe +import qualified Data.Set as Set +import Data.Set (Set) import CC.Pretty import CC.Source import CC.Typed -data TypeError = TypeError SourceRange TypeT TypeT +-- Inspiration: https://github.com/kritzcreek/fby19 + + +data TCError = TypeError SourceRange TypeT TypeT + | RefError SourceRange Name deriving (Show) -instance Pretty TypeError where +instance Pretty TCError where pretty (TypeError sr real expect) = - "Type error: Expression at " ++ pretty sr ++ " has type " ++ pretty real ++ + "Type error: Expression at " ++ pretty sr ++ + " has type " ++ pretty real ++ ", but should have type " ++ pretty expect + pretty (RefError sr name) = + "Reference error: Variable '" ++ name ++ "' out of scope at " ++ pretty sr -type IdSupplyT m a = StateT Int m a +type TM a = ExceptT TCError (State Int) a -genId :: Monad m => IdSupplyT m Int +genId :: TM Int genId = state (\idval -> (idval, idval + 1)) -genTyVar :: Monad m => IdSupplyT m TypeT +genTyVar :: TM TypeT genTyVar = TyVar <$> genId -type TM a = IdSupplyT (Either TypeError) a +runTM :: TM a -> Either TCError a +runTM m = evalState (runExceptT m) 1 + + +data Scheme = Scheme [Int] TypeT + +newtype Env = Env (Map Name Scheme) + +newtype Subst = Subst (Map Int TypeT) + + +class FreeTypeVars a where + freeTypeVars :: a -> Set Int + +instance FreeTypeVars TypeT where + freeTypeVars (TFunT t1 t2) = freeTypeVars t1 <> freeTypeVars t2 + freeTypeVars TIntT = mempty + freeTypeVars (TyVar var) = Set.singleton var + +instance FreeTypeVars Scheme where + freeTypeVars (Scheme bnds ty) = foldr Set.delete (freeTypeVars ty) bnds + +instance FreeTypeVars Env where + freeTypeVars (Env mp) = foldMap freeTypeVars (Map.elems mp) + + +infixr >>! +class Substitute a where + (>>!) :: Subst -> a -> a + +instance Substitute TypeT where + theta@(Subst mp) >>! ty = case ty of + TFunT t1 t2 -> TFunT (theta >>! t1) (theta >>! t2) + TIntT -> TIntT + TyVar i -> fromMaybe ty (Map.lookup i mp) + +instance Substitute Scheme where + Subst mp >>! Scheme bnds ty = + Scheme bnds (Subst (foldr Map.delete mp bnds) >>! ty) + +instance Substitute Env where + theta >>! Env mp = Env (Map.map (theta >>!) mp) + +-- TODO: make this instance unnecessary +instance Substitute ExprT where + theta >>! LamT ty (Occ name ty2) body = + LamT (theta >>! ty) (Occ name (theta >>! ty2)) (theta >>! body) + theta >>! CallT ty e1 e2 = + CallT (theta >>! ty) (theta >>! e1) (theta >>! e2) + _ >>! expr@(IntT _) = expr + theta >>! VarT (Occ name ty) = VarT (Occ name (theta >>! ty)) + + +instance Semigroup Subst where + s2@(Subst m2) <> Subst m1 = Subst (Map.union (Map.map (s2 >>!) m1) m2) + +instance Monoid Subst where + mempty = Subst mempty + +emptyEnv :: Env +emptyEnv = Env mempty + +envAdd :: Name -> Scheme -> Env -> Env +envAdd name sty (Env mp) = Env (Map.insert name sty mp) + +envFind :: Name -> Env -> Maybe Scheme +envFind name (Env mp) = Map.lookup name mp + +substVar :: Int -> TypeT -> Subst +substVar var ty = Subst (Map.singleton var ty) + +generalise :: Env -> TypeT -> Scheme +generalise env ty = + Scheme (Set.toList (freeTypeVars ty Set.\\ freeTypeVars env)) ty + +instantiate :: Scheme -> TM TypeT +instantiate (Scheme bnds ty) = do + vars <- traverse (const genTyVar) bnds + let theta = Subst (Map.fromList (zip bnds vars)) + return (theta >>! ty) + +data UnifyContext = UnifyContext SourceRange TypeT TypeT + +unify :: SourceRange -> TypeT -> TypeT -> TM Subst +unify sr t1 t2 = unify' (UnifyContext sr t1 t2) t1 t2 -runTM :: TM a -> Either TypeError a -runTM m = evalStateT m 1 +unify' :: UnifyContext -> TypeT -> TypeT -> TM Subst +unify' _ TIntT TIntT = return mempty +unify' ctx (TFunT t1 t2) (TFunT u1 u2) = (<>) <$> unify' ctx t1 u1 <*> unify' ctx t2 u2 +unify' _ (TyVar var) ty = return (substVar var ty) +unify' _ ty (TyVar var) = return (substVar var ty) +unify' (UnifyContext sr t1 t2) _ _ = throwError (TypeError sr t1 t2) +convertType :: Type -> TypeT +convertType (TFun t1 t2) = TFunT (convertType t1) (convertType t2) +convertType TInt = TIntT -runPass :: Context -> Program -> Either TypeError ProgramT +infer :: Env -> Expr -> TM (Subst, ExprT) +infer env expr = case expr of + Lam _ [] body -> infer env body + Lam sr args@(_:_:_) body -> infer env (foldr (Lam sr . pure) body args) + Lam _ [(arg, _)] body -> do + argVar <- genTyVar + let augEnv = envAdd arg (Scheme [] argVar) env + (theta, body') <- infer augEnv body + let argType = theta >>! argVar + return (theta, LamT (TFunT argType (exprType body')) (Occ arg argType) body') + Call sr func arg -> do + (theta1, func') <- infer env func + (theta2, arg') <- infer (theta1 >>! env) arg + resVar <- genTyVar + theta3 <- unify sr (theta2 >>! exprType func') (TFunT (exprType arg') resVar) + return (theta3 <> theta2 <> theta1 + ,CallT (theta3 >>! resVar) + ((theta3 <> theta2) >>! func') -- TODO: quadratic complexity + (theta3 >>! arg')) -- TODO: quadratic complexity + Int _ val -> return (mempty, IntT val) + Var sr name + | Just sty <- envFind name env -> do + ty <- instantiate sty + return (mempty, VarT (Occ name ty)) + | otherwise -> + throwError (RefError sr name) + Annot sr subex ty -> do + (theta1, subex') <- infer env subex + theta2 <- unify sr (exprType subex') (convertType ty) + return (theta2 <> theta1, theta2 >>! subex') -- TODO: quadratic complexity + + +runPass :: Context -> Program -> Either TCError ProgramT runPass _ prog = runTM (typeCheck prog) typeCheck :: Program -> TM ProgramT -typeCheck (Program decls) = ProgramT <$> mapM typeCheckDL decls - -typeCheckDL :: Decl -> TM DeclT -typeCheckDL (Def def) = DefT <$> typeCheckD def - -typeCheckD :: Def -> TM DefT -typeCheckD (Function mt (fname, fnameR) args body) = do - (body', _) <- typeCheckE body - return (FunctionT (exprType body') fname (map fst args) body') - -typeCheckE :: Expr -> TM (ExprT, Mapping) -typeCheckE (Call sr func arg) = do - (func', m1) <- typeCheckE func - (arg', m2) <- typeCheckE arg - m <- combine m1 m2 - - let functype = exprType func' - argtype = exprType arg' - tvar <- genTyVar - apply <- unify (range func) functype (TFunT tvar argtype) - let restype = TFunT (apply tvar) (apply argtype) - func'' = down apply func' - arg'' = down apply arg' - - return (CallT restype func'' arg'') -typeCheckE (Int _ val) = return (IntT val, mempty) -typeCheckE (Var _ name) = VarT . Occ name <$> genTyVar - --- For each variable, its inferred type and the position of its first --- occurrence in a program fragment. -type Mapping = Map.Map Name (TypeT, SourceRange) - -combine :: Mapping -> Mapping -> TM Mapping -combine mp1 mp2 = do - let leftmap = Map.filterWithKey (\name _ -> not (Map.member name mp2)) mp1 - rightmap = Map.filterWithKey (\name _ -> not (Map.member name mp1)) mp2 - overlap = Map.keys mp1 `intersect` Map.keys mp2 - combine1 name (t1, sr1) (t2, sr2) - | t1 == t2 = Right (t1, sr1) - | otherwise = Left (TypeError sr2 t2 t1) - midpairs <- sequence [combine1 name (mp1 Map.! name) (mp2 Map.! name) - | name <- overlap] - return (Map.unions [leftmap, rightmap, Map.fromList midpairs]) - -unify :: SourceRange -> TypeT -> TypeT -> TM (TypeT -> TypeT) -unify = undefined +typeCheck (Program decls) = + let defs = [(name, ty) + | Def (Function (Just ty) (name, _) _ _) <- decls] + env = foldl (\env' (name, ty) -> envAdd name (generalise env' (convertType ty)) env') + emptyEnv defs + in ProgramT <$> mapM (typeCheckDef env . (\(Def def) -> def)) decls + +typeCheckDef :: Env -> Def -> TM DefT +typeCheckDef env (Function mannot (name, sr) args@(_:_) body) = + typeCheckDef env (Function mannot (name, sr) [] (Lam sr args body)) +typeCheckDef env (Function (Just annot) (name, sr) [] body) = + typeCheckDef env (Function Nothing (name, sr) [] (Annot sr body annot)) +typeCheckDef env (Function Nothing (name, _) [] body) = do + (_, body') <- infer env body + return (DefT name body') diff --git a/utils/CC/IdSupply.hs b/utils/CC/IdSupply.hs deleted file mode 100644 index 234f6cc..0000000 --- a/utils/CC/IdSupply.hs +++ /dev/null @@ -1,29 +0,0 @@ -module CC.IdSupply(IdSupply, runIdSupply, genId) where - -import Control.Monad.Trans - - -data IdSupply a = IdSupply (Int -> (Int, a)) - -instance Functor IdSupply where - fmap f (IdSupply act) = IdSupply (fmap f . act) - -instance Applicative IdSupply where - pure x = IdSupply (\idval -> (idval, x)) - IdSupply f <*> IdSupply x = - IdSupply (\idval -> let (idval', f') = f idval - in f' <$> x idval') - -instance Monad IdSupply where - IdSupply x >>= f = - IdSupply (\idval -> let (idval', x') = x idval - IdSupply res = f x' - in res idval') - -instance MonadTrans - -runIdSupply :: Int -> IdSupply a -> a -runIdSupply startid (IdSupply f) = snd (f startid) - -genId :: IdSupply Int -genId = IdSupply (\idval -> (idval + 1, idval)) |