{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TypeSynonymInstances #-}
-- | Type inference and checking pass.
--
-- Turns a source 'Program' into a typed 'ProgramT' by inferring a type for
-- every expression using substitution-based unification.
module CC.Typecheck(runPass) where

import Control.Monad (foldM)
import Control.Monad.State.Strict
import Control.Monad.Except
import Data.List (foldl')
import qualified Data.Map.Strict as Map
import Data.Map.Strict (Map)
import Data.Maybe
import qualified Data.Set as Set
import Data.Set (Set)

import CC.AST.Source
import CC.AST.Typed
import CC.Context
import CC.Pretty


-- Inspiration: https://github.com/kritzcreek/fby19

-- | Errors that the type checker can report.
data TCError
  = TypeError SourceRange TypeT TypeT  -- ^ location, actual type, expected type
  | RefError SourceRange Name          -- ^ location, name that is not in scope
  deriving (Show)

instance Pretty TCError where
  pretty (TypeError sr real expect) =
    "Type error: Expression at " ++ pretty sr
      ++ " has type " ++ pretty real
      ++ ", but should have type " ++ pretty expect
  pretty (RefError sr name) =
    "Reference error: Variable '" ++ name ++ "' out of scope at " ++ pretty sr

-- | The type checker monad: an 'Int' counter for fresh type variables, with
-- 'TCError' as the failure type.
type TM a = ExceptT TCError (State Int) a

-- | Return the current counter value and advance the counter by one.
genId :: TM Int
genId = state (\idval -> (idval, idval + 1))

-- | Produce a fresh type variable.
genTyVar :: TM TypeT
genTyVar = TyVar <$> genId

-- | Run a type checker computation with the fresh-variable counter starting
-- at 1.
runTM :: TM a -> Either TCError a
runTM m = evalState (runExceptT m) 1

-- | Typing environment: maps names to their type schemes.
newtype Env = Env (Map Name TypeSchemeT)

-- | Substitution: maps type variable ids to the types that replace them.
newtype Subst = Subst (Map Int TypeT)

-- | Things from which the set of free type variable ids can be collected.
class FreeTypeVars a where
  freeTypeVars :: a -> Set Int

instance FreeTypeVars TypeT where
  freeTypeVars (TFunT t1 t2) = freeTypeVars t1 <> freeTypeVars t2
  freeTypeVars TIntT = mempty
  freeTypeVars (TTupT ts) = Set.unions (map freeTypeVars ts)
  freeTypeVars (TyVar var) = Set.singleton var

instance FreeTypeVars TypeSchemeT where
  -- Variables bound by the scheme are not free.
  freeTypeVars (TypeSchemeT bnds ty) = foldr Set.delete (freeTypeVars ty) bnds

instance FreeTypeVars Env where
  freeTypeVars (Env mp) = foldMap freeTypeVars (Map.elems mp)

infixr >>!

-- | Things to which a substitution can be applied.
class Substitute a where
  -- | @theta >>! x@ applies the substitution @theta@ to @x@.
  (>>!) :: Subst -> a -> a

instance Substitute TypeT where
  theta@(Subst mp) >>! ty = case ty of
    TFunT t1 t2 -> TFunT (theta >>! t1) (theta >>! t2)
    TIntT -> TIntT
    TTupT ts -> TTupT (map (theta >>!) ts)
    -- Variables without an entry in the substitution are left untouched.
    TyVar i -> fromMaybe ty (Map.lookup i mp)

instance Substitute TypeSchemeT where
  -- The scheme's bound variables are removed from the substitution before it
  -- is applied to the body.
  Subst mp >>! TypeSchemeT bnds ty =
    TypeSchemeT bnds (Subst (foldr Map.delete mp bnds) >>! ty)

instance Substitute Env where
  theta >>! Env mp = Env (Map.map (theta >>!) mp)

-- TODO: make this instance unnecessary
instance Substitute ExprT where
  theta >>! LamT ty (Occ name ty2) body =
    LamT (theta >>! ty) (Occ name (theta >>! ty2)) (theta >>! body)
  theta >>! CallT ty e1 e2 = CallT (theta >>! ty) (theta >>! e1) (theta >>! e2)
  _ >>! expr@(IntT _) = expr
  theta >>! TupT es = TupT (map (theta >>!) es)
  theta >>! VarT (Occ name ty) = VarT (Occ name (theta >>! ty))

-- | Composition of substitutions: @(s2 <> s1) >>! x@ behaves like
-- @s2 >>! (s1 >>! x)@, i.e. the right operand is applied first.
instance Semigroup Subst where
  s2@(Subst m2) <> Subst m1 = Subst (Map.union (Map.map (s2 >>!) m1) m2)

instance Monoid Subst where
  mempty = Subst mempty

-- | The environment without any bindings.
emptyEnv :: Env
emptyEnv = Env mempty

-- | Bind a name to a type scheme, replacing any existing binding.
envAdd :: Name -> TypeSchemeT -> Env -> Env
envAdd name sty (Env mp) = Env (Map.insert name sty mp)

-- | Look up the type scheme bound to a name, if any.
envFind :: Name -> Env -> Maybe TypeSchemeT
envFind name (Env mp) = Map.lookup name mp

-- | The substitution that replaces exactly one variable.
substVar :: Int -> TypeT -> Subst
substVar var ty = Subst (Map.singleton var ty)

-- | Quantify over all type variables that are free in the type but not free
-- in the environment.
generalise :: Env -> TypeT -> TypeSchemeT
generalise env ty =
  TypeSchemeT (Set.toList (freeTypeVars ty Set.\\ freeTypeVars env)) ty

-- | Replace every bound variable of a scheme with a fresh type variable.
instantiate :: TypeSchemeT -> TM TypeT
instantiate (TypeSchemeT bnds ty) = do
  vars <- traverse (const genTyVar) bnds
  let theta = Subst (Map.fromList (zip bnds vars))
  return (theta >>! ty)

-- | What to report when unification fails: the source location and the two
-- top-level types that were being unified.
data UnifyContext = UnifyContext SourceRange TypeT TypeT

-- | Unify two types, yielding a substitution that makes them equal.
--
-- Throws a 'TypeError' carrying the given source range and the two original
-- types when no such substitution exists.
unify :: SourceRange -> TypeT -> TypeT -> TM Subst
unify sr t1 t2 = unify' (UnifyContext sr t1 t2) t1 t2

-- | Worker for 'unify'; the context is only used for error reporting.
unify' :: UnifyContext -> TypeT -> TypeT -> TM Subst
unify' _ TIntT TIntT = return mempty
unify' ctx (TFunT t1 t2) (TFunT u1 u2) = unifyAll ctx [(t1, u1), (t2, u2)]
unify' ctx (TTupT ts) (TTupT us)
  | length ts == length us = unifyAll ctx (zip ts us)
unify' ctx (TyVar var) ty = bindVar ctx var ty
unify' ctx ty (TyVar var) = bindVar ctx var ty
unify' (UnifyContext sr t1 t2) _ _ = throwError (TypeError sr t1 t2)

-- | Unify a list of type pairs from left to right.
--
-- The substitution accumulated so far is applied to each pair before it is
-- unified, so a binding discovered in an earlier pair constrains later pairs
-- instead of being silently overridden by a conflicting one.
unifyAll :: UnifyContext -> [(TypeT, TypeT)] -> TM Subst
unifyAll ctx = foldM step mempty
  where
    step theta (t, u) = do
      theta' <- unify' ctx (theta >>! t) (theta >>! u)
      return (theta' <> theta)

-- | Bind a type variable to a type.
--
-- Binding a variable to itself yields the empty substitution. Binding it to
-- a different type in which it occurs (occurs check) is rejected with a
-- 'TypeError', because that would describe an infinite type.
bindVar :: UnifyContext -> Int -> TypeT -> TM Subst
bindVar (UnifyContext sr t1 t2) var ty
  | TyVar var' <- ty, var' == var = return mempty
  | var `Set.member` freeTypeVars ty = throwError (TypeError sr t1 t2)
  | otherwise = return (substVar var ty)

-- | Convert a source-level type annotation into the checker's type
-- representation.
convertType :: Type -> TypeT
convertType (TFun t1 t2) = TFunT (convertType t1) (convertType t2)
convertType TInt = TIntT
convertType (TTup ts) = TTupT (map convertType ts)

-- | Infer the type of an expression in the given environment.
--
-- Returns the substitution discovered during inference together with the
-- typed expression.
infer :: Env -> Expr -> TM (Subst, ExprT)
infer env expr = case expr of
  -- A lambda without arguments is just its body.
  Lam _ [] body -> infer env body
  -- A lambda with several arguments is treated as nested one-argument
  -- lambdas.
  Lam sr args@(_:_:_) body -> infer env (foldr (Lam sr . pure) body args)
  Lam _ [(arg, _)] body -> do
    argVar <- genTyVar
    -- The argument is bound monomorphically (no quantified variables).
    let augEnv = envAdd arg (TypeSchemeT [] argVar) env
    (theta, body') <- infer augEnv body
    let argType = theta >>! argVar
    return (theta, LamT (TFunT argType (exprType body')) (Occ arg argType) body')
  Call sr func arg -> do
    (theta1, func') <- infer env func
    (theta2, arg') <- infer (theta1 >>! env) arg
    resVar <- genTyVar
    theta3 <- unify sr (theta2 >>! exprType func') (TFunT (exprType arg') resVar)
    return
      ( theta3 <> theta2 <> theta1
      , CallT
          (theta3 >>! resVar)
          ((theta3 <> theta2) >>! func')  -- TODO: quadratic complexity
          (theta3 >>! arg')               -- TODO: quadratic complexity
      )
  Int _ val -> return (mempty, IntT val)
  Tup _ es -> fmap TupT <$> inferList env es
  Var sr name
    | Just sty <- envFind name env -> do
        ty <- instantiate sty
        return (mempty, VarT (Occ name ty))
    | otherwise -> throwError (RefError sr name)
  Annot sr subex ty -> do
    (theta1, subex') <- infer env subex
    theta2 <- unify sr (exprType subex') (convertType ty)
    return (theta2 <> theta1, theta2 >>! subex')  -- TODO: quadratic complexity

-- | Infer the types of a list of expressions from left to right, threading
-- the substitution through the environment.
inferList :: Env -> [Expr] -> TM (Subst, [ExprT])
inferList _ [] = return (mempty, [])
inferList env (expr : exprs) = do
  (theta, expr') <- infer env expr
  (theta', res) <- inferList (theta >>! env) exprs
  -- theta' was found after theta, so it goes on the left of the composition,
  -- and it must also be applied to the already-typed head so that the head's
  -- types reflect what was learned from the tail.
  return (theta' <> theta, (theta' >>! expr') : res)  -- TODO: quadratic complexity

-- | Type check a whole program, with the context's builtins in scope.
runPass :: Context -> Program -> Either TCError ProgramT
runPass (Context _ (Builtins builtins)) prog =
  let env = Env (Map.fromList [(name, generalise emptyEnv ty) | (name, ty) <- builtins])
  in runTM (typeCheck env prog)

-- | Type check every declaration of a program.
--
-- Functions that carry a type annotation are added to the environment up
-- front, so every definition can refer to them.
typeCheck :: Env -> Program -> TM ProgramT
typeCheck startEnv (Program decls) =
  let defs = [(name, ty) | Def (Function (Just ty) (name, _) _ _) <- decls]
      env =
        foldl'
          (\env' (name, ty) -> envAdd name (generalise env' (convertType ty)) env')
          startEnv
          defs
  in ProgramT <$> mapM (typeCheckDef env . (\(Def def) -> def)) decls

-- | Type check a single function definition.
--
-- Arguments are first folded into a lambda around the body and a type
-- annotation is turned into an 'Annot' around the body, so only the
-- argument-less, annotation-less form has to be inferred directly.
typeCheckDef :: Env -> Def -> TM DefT
typeCheckDef env (Function mannot (name, sr) args@(_:_) body) =
  typeCheckDef env (Function mannot (name, sr) [] (Lam sr args body))
typeCheckDef env (Function (Just annot) (name, sr) [] body) =
  typeCheckDef env (Function Nothing (name, sr) [] (Annot sr body annot))
typeCheckDef env (Function Nothing (name, _) [] body) = do
  (_, body') <- infer env body
  return (DefT name body')