{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TypeSynonymInstances #-}

-- | Type-checking pass.
--
-- Inference works by unification: every inference step returns a
-- substitution on type variables, which is composed with the substitutions
-- of earlier steps and applied to the typed expression tree it builds.
module CC.Typecheck(runPass) where

import Control.Monad.State.Strict
import Control.Monad.Except
import qualified Data.Map.Strict as Map
import Data.Map.Strict (Map)
import Data.Maybe
import qualified Data.Set as Set
import Data.Set (Set)

import CC.Pretty
import CC.Source
import CC.Typed

-- Inspiration: https://github.com/kritzcreek/fby19

-- | Errors reported by the type checker.
data TCError
  = TypeError SourceRange TypeT TypeT
    -- ^ Expression range, the type it has, and the type it should have.
  | RefError SourceRange Name
    -- ^ A variable was used where no binding for it is in scope.
  deriving (Show)

instance Pretty TCError where
  pretty (TypeError sr real expect) =
    "Type error: Expression at " ++ pretty sr ++ " has type " ++ pretty real
      ++ ", but should have type " ++ pretty expect
  pretty (RefError sr name) =
    "Reference error: Variable '" ++ name ++ "' out of scope at " ++ pretty sr

-- | Type-checking monad: an 'Int' counter for fresh type variables, with
-- early exit on the first 'TCError'.
type TM a = ExceptT TCError (State Int) a

-- | Return the current counter value and advance the counter.
genId :: TM Int
genId = state (\idval -> (idval, idval + 1))

-- | A type variable that has not been handed out before.
genTyVar :: TM TypeT
genTyVar = TyVar <$> genId

-- | Run a type-checking computation with the counter starting at 1.
runTM :: TM a -> Either TCError a
runTM m = evalState (runExceptT m) 1

-- | A type scheme: the listed type variables are quantified over.
data Scheme = Scheme [Int] TypeT

-- | Typing environment: variable names to their schemes.
newtype Env = Env (Map Name Scheme)

-- | Substitution from type-variable ids to types.
--
-- Application ('>>!') is not recursive, so substitutions are kept
-- idempotent: no variable in the domain may occur in the range.
newtype Subst = Subst (Map Int TypeT)

class FreeTypeVars a where
  freeTypeVars :: a -> Set Int

instance FreeTypeVars TypeT where
  freeTypeVars (TFunT t1 t2) = freeTypeVars t1 <> freeTypeVars t2
  freeTypeVars TIntT = mempty
  freeTypeVars (TyVar var) = Set.singleton var

instance FreeTypeVars Scheme where
  -- Quantified variables are not free.
  freeTypeVars (Scheme bnds ty) = foldr Set.delete (freeTypeVars ty) bnds

instance FreeTypeVars Env where
  freeTypeVars (Env mp) = foldMap freeTypeVars (Map.elems mp)

infixr >>!

-- | Things a substitution can be applied to.
class Substitute a where
  (>>!) :: Subst -> a -> a

instance Substitute TypeT where
  theta@(Subst mp) >>! ty = case ty of
    TFunT t1 t2 -> TFunT (theta >>! t1) (theta >>! t2)
    TIntT -> TIntT
    TyVar i -> fromMaybe ty (Map.lookup i mp)

instance Substitute Scheme where
  -- Quantified variables are shielded from the substitution.
  Subst mp >>! Scheme bnds ty = Scheme bnds (Subst (foldr Map.delete mp bnds) >>! ty)

instance Substitute Env where
  theta >>! Env mp = Env (Map.map (theta >>!) mp)

-- TODO: make this instance unnecessary
instance Substitute ExprT where
  theta >>! LamT ty (Occ name ty2) body =
    LamT (theta >>! ty) (Occ name (theta >>! ty2)) (theta >>! body)
  theta >>! CallT ty e1 e2 = CallT (theta >>! ty) (theta >>! e1) (theta >>! e2)
  _ >>! expr@(IntT _) = expr
  theta >>! VarT (Occ name ty) = VarT (Occ name (theta >>! ty))

-- | @s2 <> s1@ is composition: applying it is the same as applying @s1@
-- and then @s2@.
instance Semigroup Subst where
  s2@(Subst m2) <> Subst m1 = Subst (Map.union (Map.map (s2 >>!) m1) m2)

instance Monoid Subst where
  mempty = Subst mempty

emptyEnv :: Env
emptyEnv = Env mempty

-- | Add (or replace) a binding in the environment.
envAdd :: Name -> Scheme -> Env -> Env
envAdd name sty (Env mp) = Env (Map.insert name sty mp)

envFind :: Name -> Env -> Maybe Scheme
envFind name (Env mp) = Map.lookup name mp

-- | The substitution binding a single variable.
substVar :: Int -> TypeT -> Subst
substVar var ty = Subst (Map.singleton var ty)

-- | Quantify over the variables free in the type but not in the environment.
generalise :: Env -> TypeT -> Scheme
generalise env ty = Scheme (Set.toList (freeTypeVars ty Set.\\ freeTypeVars env)) ty

-- | Replace each quantified variable of a scheme with a fresh type variable.
instantiate :: Scheme -> TM TypeT
instantiate (Scheme bnds ty) = do
  vars <- traverse (const genTyVar) bnds
  let theta = Subst (Map.fromList (zip bnds vars))
  return (theta >>! ty)

-- | The full pair of types passed to 'unify'. A failure found deep inside
-- a function type is reported against these whole types.
data UnifyContext = UnifyContext SourceRange TypeT TypeT

-- | Find a substitution that makes both types equal.
--
-- On failure, throws a 'TypeError' at the given range, with the first
-- type as the actual type and the second as the expected one.
unify :: SourceRange -> TypeT -> TypeT -> TM Subst
unify sr t1 t2 = unify' (UnifyContext sr t1 t2) t1 t2

unify' :: UnifyContext -> TypeT -> TypeT -> TM Subst
unify' _ TIntT TIntT = return mempty
unify' ctx (TFunT t1 t2) (TFunT u1 u2) = do
  theta1 <- unify' ctx t1 u1
  -- The result types must be unified under what the argument types
  -- already fixed; unifying them independently could bind one variable
  -- to two different types, and the composition would silently drop one.
  theta2 <- unify' ctx (theta1 >>! t2) (theta1 >>! u2)
  return (theta2 <> theta1)
unify' ctx (TyVar var) ty = bindVar ctx var ty
unify' ctx ty (TyVar var) = bindVar ctx var ty
unify' (UnifyContext sr t1 t2) _ _ = throwError (TypeError sr t1 t2)

-- | Bind a type variable to a type.
--
-- Binding a variable to itself needs no substitution. Binding it to a
-- type that contains it would describe an infinite type, so it is
-- rejected (occurs check); this also keeps substitutions idempotent.
bindVar :: UnifyContext -> Int -> TypeT -> TM Subst
bindVar _ var (TyVar var') | var == var' = return mempty
bindVar (UnifyContext sr t1 t2) var ty
  | var `Set.member` freeTypeVars ty = throwError (TypeError sr t1 t2)
  | otherwise = return (substVar var ty)

-- | Convert a source-level type annotation into a checker type.
convertType :: Type -> TypeT
convertType (TFun t1 t2) = TFunT (convertType t1) (convertType t2)
convertType TInt = TIntT

-- | Infer the type of an expression.
--
-- Returns the substitution learned while inferring, plus the typed
-- expression with that substitution already applied.
infer :: Env -> Expr -> TM (Subst, ExprT)
infer env expr = case expr of
  Lam _ [] body -> infer env body
  -- A multi-argument lambda is checked as nested single-argument lambdas.
  Lam sr args@(_:_:_) body -> infer env (foldr (Lam sr . pure) body args)
  Lam _ [(arg, _)] body -> do
    argVar <- genTyVar
    let augEnv = envAdd arg (Scheme [] argVar) env
    (theta, body') <- infer augEnv body
    let argType = theta >>! argVar
    return (theta, LamT (TFunT argType (exprType body')) (Occ arg argType) body')
  Call sr func arg -> do
    (theta1, func') <- infer env func
    (theta2, arg') <- infer (theta1 >>! env) arg
    resVar <- genTyVar
    theta3 <- unify sr (theta2 >>! exprType func') (TFunT (exprType arg') resVar)
    return ( theta3 <> theta2 <> theta1
           , CallT (theta3 >>! resVar)
                   ((theta3 <> theta2) >>! func') -- TODO: quadratic complexity
                   (theta3 >>! arg')              -- TODO: quadratic complexity
           )
  Int _ val -> return (mempty, IntT val)
  Var sr name
    | Just sty <- envFind name env -> do
        ty <- instantiate sty
        return (mempty, VarT (Occ name ty))
    | otherwise -> throwError (RefError sr name)
  Annot sr subex ty -> do
    (theta1, subex') <- infer env subex
    theta2 <- unify sr (exprType subex') (convertType ty)
    return (theta2 <> theta1, theta2 >>! subex') -- TODO: quadratic complexity

-- | Type-check a whole program, annotating every expression with its type.
runPass :: Context -> Program -> Either TCError ProgramT
runPass _ prog = runTM (typeCheck prog)

-- | Check every definition of a program.
--
-- Only definitions with a type annotation are put into the shared
-- environment, so only those can be referenced by name from other
-- definitions (or from their own bodies).
typeCheck :: Program -> TM ProgramT
typeCheck (Program decls) =
  let defs = [(name, ty) | Def (Function (Just ty) (name, _) _ _) <- decls]
      env = foldl (\env' (name, ty) -> envAdd name (generalise env' (convertType ty)) env') emptyEnv defs
  in ProgramT <$> mapM (typeCheckDef env . (\(Def def) -> def)) decls

-- | Check a single definition.
--
-- Parameters are turned into a lambda around the body, and an annotation
-- into an 'Annot' node checked against the body, before inference runs.
typeCheckDef :: Env -> Def -> TM DefT
typeCheckDef env (Function mannot (name, sr) args@(_:_) body) =
  typeCheckDef env (Function mannot (name, sr) [] (Lam sr args body))
typeCheckDef env (Function (Just annot) (name, sr) [] body) =
  typeCheckDef env (Function Nothing (name, sr) [] (Annot sr body annot))
typeCheckDef env (Function Nothing (name, _) [] body) = do
  (_, body') <- infer env body
  return (DefT name body')