{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TypeSynonymInstances #-}
-- | Type inference for the source AST, producing the explicitly typed AST.
--
-- The algorithm is substitution based: inference returns a substitution
-- together with a typed expression, and unification computes substitutions
-- between types containing integer-numbered type variables.
module CC.Typecheck(runPass) where

import Control.Monad.State.Strict
import Control.Monad.Except
import Data.List (foldl')
import qualified Data.Map.Strict as Map
import Data.Map.Strict (Map)
import Data.Maybe
import qualified Data.Set as Set
import Data.Set (Set)

import qualified CC.AST.Source as S
import qualified CC.AST.Typed as T
import CC.Context
import CC.Pretty
import CC.Types

-- Inspiration: https://github.com/kritzcreek/fby19


-- | Errors reported by the type checker.
data TCError
    = TypeError SourceRange T.Type T.Type
      -- ^ The expression at the range has the first type, but the second
      -- type was expected.
    | RefError SourceRange Name
      -- ^ A variable was referenced that is not in scope.
  deriving (Show)

instance Pretty TCError where
    pretty (TypeError sr real expect) =
        "Type error: Expression at " ++ pretty sr
            ++ " has type " ++ pretty real
            ++ ", but should have type " ++ pretty expect
    pretty (RefError sr name) =
        "Reference error: Variable '" ++ name ++ "' out of scope at " ++ pretty sr


-- | The type-checking monad: a counter for fresh type variables, with failure.
type TM a = ExceptT TCError (State Int) a

-- | Return a fresh integer and advance the counter.
genId :: TM Int
genId = state (\idval -> (idval, idval + 1))

-- | Return a fresh type variable.
genTyVar :: TM T.Type
genTyVar = T.TyVar <$> genId

-- | Run a type-checking computation with the fresh-variable counter at 1.
runTM :: TM a -> Either TCError a
runTM m = evalState (runExceptT m) 1


-- | Typing environment: the type scheme of every variable in scope.
newtype Env = Env (Map Name T.TypeScheme)

-- | A substitution mapping type-variable ids to types.
newtype Subst = Subst (Map Int T.Type)


-- | Things that mention type variables.
class FreeTypeVars a where
    -- | The set of type-variable ids occurring free in the value.
    freeTypeVars :: a -> Set Int

instance FreeTypeVars T.Type where
    freeTypeVars (T.TFun t1 t2) = freeTypeVars t1 <> freeTypeVars t2
    freeTypeVars T.TInt = mempty
    freeTypeVars (T.TTup ts) = Set.unions (map freeTypeVars ts)
    freeTypeVars (T.TyVar var) = Set.singleton var

instance FreeTypeVars T.TypeScheme where
    -- Variables bound by the scheme are not free.
    freeTypeVars (T.TypeScheme bnds ty) = foldr Set.delete (freeTypeVars ty) bnds

instance FreeTypeVars Env where
    freeTypeVars (Env mp) = foldMap freeTypeVars (Map.elems mp)


infixr >>!

-- | Things a substitution can be applied to.
class Substitute a where
    -- | Apply the substitution. For types this is a single pass: a variable
    -- is replaced by its image, and that image is not substituted again.
    (>>!) :: Subst -> a -> a

instance Substitute T.Type where
    theta@(Subst mp) >>! ty = case ty of
        T.TFun t1 t2 -> T.TFun (theta >>! t1) (theta >>! t2)
        T.TInt -> T.TInt
        T.TTup ts -> T.TTup (map (theta >>!) ts)
        T.TyVar i -> fromMaybe ty (Map.lookup i mp)

instance Substitute T.TypeScheme where
    -- Bound variables of the scheme are shielded from the substitution.
    Subst mp >>! T.TypeScheme bnds ty =
        T.TypeScheme bnds (Subst (foldr Map.delete mp bnds) >>! ty)

instance Substitute Env where
    theta >>! Env mp = Env (Map.map (theta >>!) mp)

-- TODO: make this instance unnecessary
-- | Apply a substitution to every type annotation in a typed expression.
instance Substitute T.Expr where
    theta >>! T.Lam ty (T.Occ name ty2) body =
        T.Lam (theta >>! ty) (T.Occ name (theta >>! ty2)) (theta >>! body)
    theta >>! T.Let (T.Occ name ty) rhs body =
        T.Let (T.Occ name (theta >>! ty)) (theta >>! rhs) (theta >>! body)
    theta >>! T.Call ty e1 e2 =
        T.Call (theta >>! ty) (theta >>! e1) (theta >>! e2)
    _ >>! expr@(T.Int _) = expr
    theta >>! T.Tup es = T.Tup (map (theta >>!) es)
    theta >>! T.Var (T.Occ name ty) = T.Var (T.Occ name (theta >>! ty))

-- | Composition of substitutions: @s2 <> s1@ applies @s1@ first, then @s2@,
-- i.e. @(s2 <> s1) >>! t == s2 >>! (s1 >>! t)@ for every type @t@.
instance Semigroup Subst where
    s2@(Subst m2) <> Subst m1 = Subst (Map.union (Map.map (s2 >>!) m1) m2)

instance Monoid Subst where
    mempty = Subst mempty


-- | The environment with no variables in scope.
emptyEnv :: Env
emptyEnv = Env mempty

-- | Bind (or rebind) a variable in the environment.
envAdd :: Name -> T.TypeScheme -> Env -> Env
envAdd name sty (Env mp) = Env (Map.insert name sty mp)

-- | Look up the scheme of a variable, if it is in scope.
envFind :: Name -> Env -> Maybe T.TypeScheme
envFind name (Env mp) = Map.lookup name mp

-- | The substitution mapping a single variable to a type.
substVar :: Int -> T.Type -> Subst
substVar var ty = Subst (Map.singleton var ty)

-- | Quantify over every type variable that is free in the type but not free
-- in the environment.
generalise :: Env -> T.Type -> T.TypeScheme
generalise env ty =
    T.TypeScheme (Set.toList (freeTypeVars ty Set.\\ freeTypeVars env)) ty

-- | Replace the bound variables of a scheme by fresh type variables.
instantiate :: T.TypeScheme -> TM T.Type
instantiate (T.TypeScheme bnds ty) = do
    vars <- traverse (const genTyVar) bnds
    let theta = Subst (Map.fromList (zip bnds vars))
    return (theta >>! ty)


-- | The top-level pair of types being unified. Any failure deep inside a
-- structural unification is reported as a mismatch of these two types.
data UnifyContext = UnifyContext SourceRange T.Type T.Type

-- | Compute a unifier of the two types, or fail with a 'TypeError' at the
-- given source range.
unify :: SourceRange -> T.Type -> T.Type -> TM Subst
unify sr t1 t2 = unify' (UnifyContext sr t1 t2) t1 t2

-- | Structural unification; the context is only used for error reporting.
unify' :: UnifyContext -> T.Type -> T.Type -> TM Subst
unify' _ T.TInt T.TInt = return mempty
unify' ctx (T.TFun t1 t2) (T.TFun u1 u2) = unifyMany ctx [t1, t2] [u1, u2]
unify' ctx (T.TTup ts) (T.TTup us)
    | length ts == length us = unifyMany ctx ts us
unify' ctx (T.TyVar var) ty = bindVar ctx var ty
unify' ctx ty (T.TyVar var) = bindVar ctx var ty
unify' ctx _ _ = unifyFail ctx

-- | Unify two lists of types pairwise, left to right. The substitution found
-- for earlier pairs is applied to the remaining pairs before they are
-- unified. Otherwise two components could bind the same variable to
-- incompatible types, and merging the maps would silently drop one binding.
unifyMany :: UnifyContext -> [T.Type] -> [T.Type] -> TM Subst
unifyMany _ [] [] = return mempty
unifyMany ctx (t : ts) (u : us) = do
    theta1 <- unify' ctx t u
    theta2 <- unifyMany ctx (map (theta1 >>!) ts) (map (theta1 >>!) us)
    return (theta2 <> theta1)
unifyMany ctx _ _ = unifyFail ctx

-- | Bind a type variable to a type. Binding a variable to itself is the
-- empty substitution. The occurs check rejects a binding whose type
-- mentions the variable, which would describe an infinite type.
bindVar :: UnifyContext -> Int -> T.Type -> TM Subst
bindVar _ var (T.TyVar var') | var == var' = return mempty
bindVar ctx var ty
    | var `Set.member` freeTypeVars ty = unifyFail ctx
    | otherwise = return (substVar var ty)

-- | Report that the types in the unification context do not match.
unifyFail :: UnifyContext -> TM a
unifyFail (UnifyContext sr t1 t2) = throwError (TypeError sr t1 t2)


-- | Translate a source-level type annotation to a typed-AST type.
convertType :: S.Type -> T.Type
convertType (S.TFun t1 t2) = T.TFun (convertType t1) (convertType t2)
convertType S.TInt = T.TInt
convertType (S.TTup ts) = T.TTup (map convertType ts)


-- | Infer the type of an expression. Returns a substitution, to be applied to
-- the environment, together with the typed expression. That expression
-- already has the returned substitution applied to it.
infer :: Env -> S.Expr -> TM (Subst, T.Expr)
infer env expr = case expr of
    S.Lam _ [] body -> infer env body
    -- Multi-argument lambdas are curried into nested single-argument ones.
    S.Lam sr args@(_:_:_) body -> infer env (foldr (S.Lam sr . pure) body args)
    S.Lam _ [(arg, _)] body -> do
        argVar <- genTyVar
        let augEnv = envAdd arg (T.TypeScheme [] argVar) env
        (theta, body') <- infer augEnv body
        let argType = theta >>! argVar
        return (theta, T.Lam (T.TFun argType (T.exprType body')) (T.Occ arg argType) body')

    S.Let _ (name, _) rhs body -> do
        (theta1, rhs') <- infer env rhs
        let varType = T.exprType rhs'
            -- Let-bound variables are monomorphic (not generalised). The body
            -- must see the environment as refined while inferring the rhs.
            augEnv = envAdd name (T.TypeScheme [] varType) (theta1 >>! env)
        (theta2, body') <- infer augEnv body
        -- Bring the binder and rhs up to date with what the body inferred.
        return (theta2 <> theta1
               ,T.Let (T.Occ name (theta2 >>! varType))
                      (theta2 >>! rhs')  -- TODO: quadratic complexity
                      body')

    S.Call sr func arg -> do
        (theta1, func') <- infer env func
        (theta2, arg') <- infer (theta1 >>! env) arg
        resVar <- genTyVar
        theta3 <- unify sr (theta2 >>! T.exprType func') (T.TFun (T.exprType arg') resVar)
        return (theta3 <> theta2 <> theta1
               ,T.Call (theta3 >>! resVar)
                       ((theta3 <> theta2) >>! func')  -- TODO: quadratic complexity
                       (theta3 >>! arg'))  -- TODO: quadratic complexity

    S.Int _ val -> return (mempty, T.Int val)

    S.Tup _ es -> fmap T.Tup <$> inferList env es

    S.Var sr name
        | Just sty <- envFind name env -> do
            ty <- instantiate sty
            return (mempty, T.Var (T.Occ name ty))
        | otherwise -> throwError (RefError sr name)

    S.Annot sr subex ty -> do
        (theta1, subex') <- infer env subex
        theta2 <- unify sr (T.exprType subex') (convertType ty)
        return (theta2 <> theta1, theta2 >>! subex')  -- TODO: quadratic complexity

-- | Infer a list of expressions left to right. Each expression is inferred in
-- the environment refined by its predecessors. Every returned expression has
-- the final, composed substitution applied.
inferList :: Env -> [S.Expr] -> TM (Subst, [T.Expr])
inferList _ [] = return (mempty, [])
inferList env (expr : exprs) = do
    (theta1, expr') <- infer env expr
    (theta2, exprs') <- inferList (theta1 >>! env) exprs
    -- theta1 happened first, so it goes on the right of the composition.
    return (theta2 <> theta1, (theta2 >>! expr') : exprs')  -- TODO: quadratic complexity


-- | Type-check a whole program. The builtins from the context are generalised
-- over all of their type variables.
runPass :: Context -> S.Program -> Either TCError T.Program
runPass (Context _ (Builtins builtins)) prog =
    let env = Env (Map.map (generalise emptyEnv) builtins)
    in runTM (typeCheck env prog)

-- | Check every definition. Only definitions with a type annotation are
-- visible to the definitions of the program.
typeCheck :: Env -> S.Program -> TM T.Program
typeCheck startEnv (S.Program decls) =
    let defs = [(name, ty) | S.Def (S.Function (Just ty) (name, _) _ _) <- decls]
        env = foldl' (\env' (name, ty) -> envAdd name (generalise env' (convertType ty)) env')
                     startEnv defs
    in T.Program <$> mapM (typeCheckDef env . (\(S.Def def) -> def)) decls

-- | Check one definition. Arguments are desugared into a lambda, and an
-- annotation into an 'S.Annot' around the body.
typeCheckDef :: Env -> S.Def -> TM T.Def
typeCheckDef env (S.Function mannot (name, sr) args@(_:_) body) =
    typeCheckDef env (S.Function mannot (name, sr) [] (S.Lam sr args body))
typeCheckDef env (S.Function (Just annot) (name, sr) [] body) =
    typeCheckDef env (S.Function Nothing (name, sr) [] (S.Annot sr body annot))
typeCheckDef env (S.Function Nothing (name, _) [] body) = do
    (_, body') <- infer env body
    return (T.Def name body')