{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeSynonymInstances #-}
-- | Type inference and checking: turns the source AST into the typed AST.
module CC.Typecheck(runPass) where

import Control.Monad.State.Strict
import Control.Monad.Except
import qualified Data.Map.Strict as Map
import Data.Map.Strict (Map)
import Data.Maybe (fromMaybe, catMaybes)
import qualified Data.Set as Set
import Data.Set (Set)
import Debug.Trace

import qualified CC.AST.Source as S
import qualified CC.AST.Typed as T
import CC.Context
import CC.Types
import CC.Typecheck.Typedefs
import CC.Typecheck.Types

-- Inspiration: https://github.com/kritzcreek/fby19


-- | The typing environment.
data Env = Env (Map Name T.TypeScheme)  -- Definitions in scope
               (Map Name T.TypeDef)     -- Type definitions
               (Map Name S.AliasDef)    -- Type aliases
    deriving (Show)

-- | A substitution from instantiable type variable ids to types.
newtype Subst = Subst (Map Int T.Type)


class FreeTypeVars a where
    -- | Free instantiable type variables
    freeInstTypeVars :: a -> Set Int

instance FreeTypeVars T.Type where
    freeInstTypeVars (T.TFun t1 t2) = freeInstTypeVars t1 <> freeInstTypeVars t2
    freeInstTypeVars T.TInt = mempty
    freeInstTypeVars (T.TTup ts) = Set.unions (map freeInstTypeVars ts)
    freeInstTypeVars (T.TNamed _ ts) = Set.unions (map freeInstTypeVars ts)
    freeInstTypeVars (T.TUnion ts) = Set.unions (map freeInstTypeVars (Set.toList ts))
    freeInstTypeVars (T.TyVar T.Instantiable var) = Set.singleton var
    freeInstTypeVars (T.TyVar T.Rigid _) = mempty

instance FreeTypeVars T.TypeScheme where
    -- Variables bound by the scheme are not free.
    freeInstTypeVars (T.TypeScheme bnds ty) = foldr Set.delete (freeInstTypeVars ty) bnds

instance FreeTypeVars Env where
    freeInstTypeVars (Env mp _ _) = foldMap freeInstTypeVars (Map.elems mp)


infixr >>!

class Substitute a where
    -- | Apply a substitution to every instantiable type variable inside.
    (>>!) :: Subst -> a -> a

instance Substitute T.Type where
    theta@(Subst mp) >>! ty = case ty of
        T.TFun t1 t2 -> T.TFun (theta >>! t1) (theta >>! t2)
        T.TInt -> T.TInt
        T.TTup ts -> T.TTup (map (theta >>!) ts)
        T.TNamed n ts -> T.TNamed n (map (theta >>!) ts)
        T.TUnion ts -> T.TUnion (Set.map (theta >>!) ts)
        T.TyVar T.Instantiable i -> fromMaybe ty (Map.lookup i mp)
        T.TyVar T.Rigid i
          | i `Map.member` mp -> error "Attempt to substitute a rigid type variable"
          | otherwise -> ty

instance Substitute T.TypeScheme where
    -- Bound variables of the scheme are shielded from the substitution.
    Subst mp >>! T.TypeScheme bnds ty =
        T.TypeScheme bnds (Subst (foldr Map.delete mp bnds) >>! ty)

instance Substitute Env where
    theta >>! Env mp tdefs aliases = Env (Map.map (theta >>!) mp) tdefs aliases

-- TODO: make this instance unnecessary
instance Substitute T.Expr where
    theta >>! T.Lam ty (T.Occ name ty2) body =
        T.Lam (theta >>! ty) (T.Occ name (theta >>! ty2)) (theta >>! body)
    theta >>! T.Let (T.Occ name ty) rhs body =
        T.Let (T.Occ name (theta >>! ty)) (theta >>! rhs) (theta >>! body)
    theta >>! T.Call ty e1 e2 =
        T.Call (theta >>! ty) (theta >>! e1) (theta >>! e2)
    _ >>! expr@(T.Int _) = expr
    theta >>! T.Tup es = T.Tup (map (theta >>!) es)
    theta >>! T.Var (T.Occ name ty) = T.Var (T.Occ name (theta >>! ty))
    theta >>! T.Constr ty n = T.Constr (theta >>! ty) n

-- | Composition: @s2 <> s1@ behaves like applying @s1@ first, then @s2@.
instance Semigroup Subst where
    s2@(Subst m2) <> Subst m1 = Subst (Map.union (Map.map (s2 >>!) m1) m2)

instance Monoid Subst where
    mempty = Subst mempty


emptyEnv :: Env
emptyEnv = Env mempty mempty mempty

-- | Add a definition to the environment; calls 'error' if the name is
-- already bound.
envAddDef :: Name -> T.TypeScheme -> Env -> Env
envAddDef name sty (Env mp tmp aliases)
  | name `Map.member` mp = error "envAddDef on name already in environment"
  | otherwise = Env (Map.insert name sty mp) tmp aliases

envFindDef :: Name -> Env -> Maybe T.TypeScheme
envFindDef name (Env mp _ _) = Map.lookup name mp

-- | Add type definitions; calls 'error' if any name is already present
-- (detected by the union being smaller than the sum of the sizes).
envAddTypes :: Map Name T.TypeDef -> Env -> Env
envAddTypes l (Env mp tdefs aliases) =
    let combined = l <> tdefs
    in if Map.size combined == Map.size l + Map.size tdefs
           then Env mp combined aliases
           else error "envAddTypes on duplicate type names"

envFindType :: Name -> Env -> Maybe T.TypeDef
envFindType name (Env _ tdefs _) = Map.lookup name tdefs

-- | Add type aliases; calls 'error' if any name is already present.
envAddAliases :: Map Name S.AliasDef -> Env -> Env
envAddAliases l (Env mp tdefs aliases) =
    let combined = l <> aliases
    in if Map.size combined == Map.size l + Map.size aliases
           then Env mp tdefs combined
           else error "envAddAliases on duplicate type names"

envAliases :: Env -> Map Name S.AliasDef
envAliases (Env _ _ aliases) = aliases


substVar :: Int -> T.Type -> Subst
substVar var ty = Subst (Map.singleton var ty)

-- | Rename the bound variables of a scheme to fresh instantiable variables.
freshenScheme :: T.TypeScheme -> TM T.TypeScheme
freshenScheme (T.TypeScheme bnds ty) = do
    vars <- traverse (const genId) bnds
    let theta = Subst (Map.fromList (zip bnds (map (T.TyVar T.Instantiable) vars)))
    return (T.TypeScheme vars (theta >>! ty))

-- | Quantify over the free instantiable variables of the type that are not
-- free in the environment.
generalise :: Env -> T.Type -> T.TypeScheme
generalise env ty =
    T.TypeScheme (Set.toList (freeInstTypeVars ty Set.\\ freeInstTypeVars env)) ty

-- | A copy of the scheme's type with fresh variables for the bound ones.
instantiate :: T.TypeScheme -> TM T.Type
instantiate scheme = (\(T.TypeScheme _ ty') -> ty') <$> freshenScheme scheme

-- | Turn every type variable (of either kind) into a rigid one.
replaceRigid :: T.Type -> T.Type
replaceRigid (T.TFun t1 t2) = T.TFun (replaceRigid t1) (replaceRigid t2)
replaceRigid T.TInt = T.TInt
replaceRigid (T.TTup ts) = T.TTup (map replaceRigid ts)
replaceRigid (T.TNamed n ts) = T.TNamed n (map replaceRigid ts)
replaceRigid (T.TUnion ts) = T.TUnion (Set.map replaceRigid ts)
replaceRigid (T.TyVar _ v) = T.TyVar T.Rigid v

-- | Check that every named type mentioned exists in the environment and is
-- applied to the right number of arguments.
--
-- Throws 'TypeArityError' on an arity mismatch, 'RefError' for an unknown name.
checkType :: Env -> SourceRange -> T.Type -> TM ()
checkType env sr (T.TFun t1 t2) = checkType env sr t1 >> checkType env sr t2
checkType _ _ T.TInt = return ()
checkType env sr (T.TTup ts) = mapM_ (checkType env sr) ts
checkType env sr (T.TNamed n ts) = do
    mapM_ (checkType env sr) ts
    case envFindType n env of
        Just (T.TypeDef _ args _)
          | length ts == length args -> return ()
          | otherwise -> throwError (TypeArityError sr n (length args) (length ts))
        Nothing -> throwError (RefError sr n)
checkType env sr (T.TUnion ts) = mapM_ (checkType env sr) (Set.toList ts)
checkType _ _ (T.TyVar _ _) = return ()


-- | The top-level types of a unification, kept for error reporting.
data UnifyContext = UnifyContext SourceRange T.Type T.Type

-- t1 = got type: what did we infer using existing information
-- t2 = wanted type: what should the type equal due to an annotation or language usage
-- Unions are only weakened towards t2: {a} U {a,b} works, but {a,b} U {a} is an error.
unify :: SourceRange -> T.Type -> T.Type -> TM Subst
unify sr t1 t2 = unify' (UnifyContext sr t1 t2) t1 t2

unify' :: UnifyContext -> T.Type -> T.Type -> TM Subst
unify' _ T.TInt T.TInt = return mempty
unify' ctx (T.TFun t1 t2) (T.TFun u1 u2) =
    -- This one is subtle: function arguments are contravariant, so we swap
    -- unification direction here.
    unifySeq ctx [(t2, u2), (u1, t1)]
unify' ctx (T.TTup ts) (T.TTup us)
  | length ts == length us = unifySeq ctx (zip ts us)
-- Identical variables (including a rigid variable against itself) unify
-- trivially; this must precede the binding clauses so the occurs check below
-- does not reject `a ~ a`.
unify' _ t1@(T.TyVar _ _) t2
  | t1 == t2 = return mempty
-- Binding a variable: the occurs check rejects infinite types such as
-- `a ~ a -> r` (e.g. from `\x -> x x`), including a union that contains `a`.
unify' ctx u1@(T.TyVar T.Instantiable var) u2
  | var `Set.member` freeInstTypeVars u2 = unifyFailure ctx u1 u2
  | otherwise = return (substVar var u2)
unify' ctx u1 u2@(T.TyVar T.Instantiable var)
  | var `Set.member` freeInstTypeVars u1 = unifyFailure ctx u1 u2
  | otherwise = return (substVar var u1)
unify' ctx (T.TNamed n1 ts) (T.TNamed n2 us)
  | n1 == n2, length ts == length us = unifySeq ctx (zip ts us)
unify' ctx (T.TUnion topts) (T.TUnion topus) =
    -- TODO: this is quadratic in the right union size. I'm not sure whether
    -- this is avoidable, but it can probably be improved by partitioning on
    -- the name of TNamed's.
    snd <$> foldM step (topus, mempty) (Set.toList topts)
  where
    -- Match the next left-hand member against the still-unmatched right-hand
    -- members, with everything learned from earlier members applied first.
    step (us, theta) ty = do
        (rest, theta') <- matchup (theta >>! ty) (Set.map (theta >>!) us)
        return (rest, theta' <> theta)

    -- If a match is found, returns the rest of the RHS types and the
    -- substitution; else, throws an error
    matchup :: T.Type -> Set T.Type -> TM (Set T.Type, Subst)
    matchup ty ts = do
        let splits = [(item, uncurry (<>) (Set.split item ts)) | item <- Set.toList ts]
        results <- forM splits $ \(item, rest) ->
            catchError (Just . (rest,) <$> unify' ctx ty item)
                       (const (return Nothing))
        case catMaybes results of
            [] -> failWith URNotInUnion
            [result] -> return result
            _ -> failWith URAmbiguousWeakening
      where
        failWith reason =
            let UnifyContext sr topt1 topt2 = ctx
            in throwError (UnifyError sr topt1 topt2 ty (T.TUnion ts) (Just reason))
unify' ctx ty (T.TUnion us) =
    unify' ctx (T.TUnion (Set.singleton ty)) (T.TUnion us)
unify' ctx (T.TUnion ts) ty
  | Set.size ts == 0 = return mempty
  | Set.size ts == 1 = unify' ctx (Set.findMin ts) ty
unify' ctx u1 u2 = unifyFailure ctx u1 u2

-- | Unify a list of (got, wanted) pairs left to right, applying the
-- substitution found so far to each later pair before unifying it, so that
-- conflicting bindings (e.g. @(a, a) ~ (Int, Bool)@) are detected instead of
-- being silently merged.
unifySeq :: UnifyContext -> [(T.Type, T.Type)] -> TM Subst
unifySeq ctx = foldM step mempty
  where
    step theta (t, u) = do
        theta' <- unify' ctx (theta >>! t) (theta >>! u)
        return (theta' <> theta)

-- | Report that the sub-types @u1@ and @u2@ of the unification in @ctx@ do
-- not match.
unifyFailure :: UnifyContext -> T.Type -> T.Type -> TM a
unifyFailure (UnifyContext sr t1 t2) u1 u2 = throwError (UnifyError sr t1 t2 u1 u2 Nothing)


-- | Infer the type of an expression.
--
-- Returns the substitution learned during inference together with the typed
-- expression; the returned expression already has that substitution applied.
infer :: Env -> S.Expr -> TM (Subst, T.Expr)
infer env expr = case expr of
    S.Lam _ [] body -> infer env body
    S.Lam sr args@(_:_:_) body -> infer env (foldr (S.Lam sr . pure) body args)
    S.Lam _ [(arg, _)] body -> do
        argVar <- genTyVar
        let augEnv = envAddDef arg (T.TypeScheme [] argVar) env
        (theta, body') <- infer augEnv body
        let argType = theta >>! argVar
        return (theta, T.Lam (T.TFun argType (T.exprType body')) (T.Occ arg argType) body')

    S.Let _ (name, _) rhs body -> do
        (theta1, rhs') <- infer env rhs
        let varType = T.exprType rhs'
            -- The body must see what was learned while inferring the rhs.
            augEnv = envAddDef name (T.TypeScheme [] varType) (theta1 >>! env)
        (theta2, body') <- infer augEnv body
        -- Bindings learned in the body may refine the rhs as well.
        return (theta2 <> theta1
               ,T.Let (T.Occ name (theta2 >>! varType)) (theta2 >>! rhs') body')

    S.Call sr func arg -> do
        (theta1, func') <- infer env func
        (theta2, arg') <- infer (theta1 >>! env) arg
        resVar <- genTyVar
        theta3 <- unify sr (theta2 >>! T.exprType func') (T.TFun (T.exprType arg') resVar)
        return (theta3 <> theta2 <> theta1
               ,T.Call (theta3 >>! resVar)
                       ((theta3 <> theta2) >>! func')  -- TODO: quadratic complexity
                       (theta3 >>! arg'))  -- TODO: quadratic complexity

    S.Int _ val -> return (mempty, T.Int val)

    S.Tup _ es -> fmap T.Tup <$> inferList env es

    S.Var sr name
      | Just sty <- envFindDef name env -> do
          ty <- instantiate sty
          return (mempty, T.Var (T.Occ name ty))
      | otherwise -> throwError (RefError sr name)

    S.Constr sr name -> case envFindType name env of
        Just (T.TypeDef typname params typ) -> do
            T.TypeScheme params' typ' <- freshenScheme (T.TypeScheme params typ)
            let restyp = T.TNamed typname (map (T.TyVar T.Instantiable) params')
            return (mempty, T.Constr (T.TFun typ' restyp) name)
        _ -> throwError (RefError sr name)

    S.Annot sr subex ty -> do
        (theta1, subex') <- infer env subex
        ty' <- convertType (envAliases env) sr ty
        checkType env sr ty'
        -- Make sure the type of the subexpression matches the type with rigid
        -- variables, then make it instantiable variables instead for the rest
        -- of the code. The rigid copy is built from fresh variable ids so its
        -- rigid variables can never share an id with an instantiable variable
        -- already present in the subexpression's type.
        skolemTy <- instantiate (T.TypeScheme (Set.toList (freeInstTypeVars ty')) ty')
        void $ unify sr (T.exprType subex') (replaceRigid skolemTy)
        theta2 <- unify sr (T.exprType subex') ty'
        return (theta2 <> theta1, theta2 >>! subex')  -- TODO: quadratic complexity

-- | Infer a list of expressions left to right, threading the substitution:
-- each later expression is inferred in the environment refined by the
-- earlier ones, and each earlier expression is refined by the later ones.
inferList :: Env -> [S.Expr] -> TM (Subst, [T.Expr])
inferList _ [] = return (mempty, [])
inferList env (expr : exprs) = do
    (theta, expr') <- infer env expr
    (theta', exprs') <- inferList (theta >>! env) exprs
    -- `theta` was found first, so it is applied first.
    return (theta' <> theta, (theta' >>! expr') : exprs')


-- | Type-check a program, starting from the builtins in the context
-- (generalised over all their free instantiable variables).
runPass :: Context -> S.Program -> Either TCError T.Program
runPass (Context _ (Builtins builtins _)) prog =
    let env = Env (Map.map (generalise emptyEnv) builtins) mempty mempty
    in runTM (typeCheck env prog)

-- | Register aliases, type definitions and annotated function signatures,
-- then type-check every function definition.
--
-- NOTE(review): only functions with a type annotation are added to the
-- environment, so unannotated functions cannot be referenced by others.
typeCheck :: Env -> S.Program -> TM T.Program
typeCheck startEnv (S.Program decls) = do
    let aliasdefs = [(n, def) | S.DeclAlias def@(S.AliasDef (n, _) _ _) <- decls]
        env1 = envAddAliases (Map.fromList aliasdefs) startEnv
    typedefs' <- checkTypedefs (envAliases env1) [def | S.DeclType def <- decls]
    let typedefsMap = Map.fromList [(n, def) | def@(T.TypeDef n _ _) <- typedefs']
        funcdefs = [def | S.DeclFunc def <- decls]
    typedfuncs <- sequence [ (name,) <$> convertType (envAliases env1) sr ty
                           | S.FuncDef (Just ty) (name, sr) _ _ <- funcdefs ]
    let env2 = envAddTypes typedefsMap env1
        env = foldl (\env' (name, ty) -> envAddDef name (generalise env' ty) env')
                    env2 typedfuncs
    funcdefs' <- mapM (typeCheckFunc env) funcdefs
    return (T.Program funcdefs' typedefsMap)

-- | Type-check one function: parameters are desugared into a lambda and an
-- annotation into an 'S.Annot' node before inference.
typeCheckFunc :: Env -> S.FuncDef -> TM T.Def
typeCheckFunc env (S.FuncDef mannot (name, sr) args@(_:_) body) =
    typeCheckFunc env (S.FuncDef mannot (name, sr) [] (S.Lam sr args body))
typeCheckFunc env (S.FuncDef (Just annot) (name, sr) [] body) =
    typeCheckFunc env (S.FuncDef Nothing (name, sr) [] (S.Annot sr body annot))
typeCheckFunc env (S.FuncDef Nothing (name, _) [] body) = do
    (_, body') <- infer env body
    return (T.Def name body')