{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeSynonymInstances #-}
-- | Type inference pass: turns a source program ('S.Program') into a typed
-- program ('T.Program'), or fails with a 'TCError'.
module CC.Typecheck(runPass) where

import Control.Monad.State.Strict
import Control.Monad.Except
import Data.List (foldl')
import qualified Data.Map.Strict as Map
import Data.Map.Strict (Map)
import Data.Maybe
import qualified Data.Set as Set
import Data.Set (Set)
import Debug.Trace

import qualified CC.AST.Source as S
import qualified CC.AST.Typed as T
import CC.Context
import CC.Types
import CC.Typecheck.Typedefs
import CC.Typecheck.Types


-- Inspiration: https://github.com/kritzcreek/fby19


-- | The typing environment threaded through inference.
data Env = Env (Map Name T.TypeScheme)  -- Definitions in scope
               (Map Name T.TypeDef)     -- Type definitions
               (Map Name S.AliasDef)    -- Type aliases
  deriving (Show)

-- | A substitution from instantiable type variable ids to types.
newtype Subst = Subst (Map Int T.Type)


class FreeTypeVars a where
    -- | Free instantiable type variables
    freeInstTypeVars :: a -> Set Int

instance FreeTypeVars T.Type where
    freeInstTypeVars (T.TFun t1 t2) = freeInstTypeVars t1 <> freeInstTypeVars t2
    freeInstTypeVars T.TInt = mempty
    freeInstTypeVars (T.TTup ts) = Set.unions (map freeInstTypeVars ts)
    freeInstTypeVars (T.TNamed _ ts) = Set.unions (map freeInstTypeVars ts)
    freeInstTypeVars (T.TUnion ts) = Set.unions (map freeInstTypeVars (Set.toList ts))
    freeInstTypeVars (T.TyVar T.Instantiable var) = Set.singleton var
    -- Rigid variables are never counted as free instantiable variables.
    freeInstTypeVars (T.TyVar T.Rigid _) = mempty

instance FreeTypeVars T.TypeScheme where
    -- Variables bound by the scheme are not free.
    freeInstTypeVars (T.TypeScheme bnds ty) = foldr Set.delete (freeInstTypeVars ty) bnds

instance FreeTypeVars Env where
    -- Only the definitions in scope contribute free variables.
    freeInstTypeVars (Env mp _ _) = foldMap freeInstTypeVars (Map.elems mp)


infixr >>!
class Substitute a where
    -- | Apply a substitution.
    (>>!) :: Subst -> a -> a

instance Substitute T.Type where
    theta@(Subst mp) >>! ty = case ty of
        T.TFun t1 t2 -> T.TFun (theta >>! t1) (theta >>! t2)
        T.TInt -> T.TInt
        T.TTup ts -> T.TTup (map (theta >>!) ts)
        T.TNamed n ts -> T.TNamed n (map (theta >>!) ts)
        T.TUnion ts -> T.TUnion (Set.map (theta >>!) ts)
        T.TyVar T.Instantiable i -> fromMaybe ty (Map.lookup i mp)
        -- A substitution must never have a rigid variable's id in its domain.
        T.TyVar T.Rigid i
            | i `Map.member` mp -> error "Attempt to substitute a rigid type variable"
            | otherwise -> ty

instance Substitute T.TypeScheme where
    -- Variables bound by the scheme are shielded from the substitution.
    Subst mp >>! T.TypeScheme bnds ty =
        T.TypeScheme bnds (Subst (foldr Map.delete mp bnds) >>! ty)

instance Substitute Env where
    -- Only the definitions in scope are affected; type definitions and
    -- aliases are passed through unchanged.
    theta >>! Env mp tdefs aliases = Env (Map.map (theta >>!) mp) tdefs aliases

-- TODO: make this instance unnecessary
instance Substitute T.Expr where
    theta >>! T.Lam ty (T.Occ name ty2) body =
        T.Lam (theta >>! ty) (T.Occ name (theta >>! ty2)) (theta >>! body)
    theta >>! T.Let (T.Occ name ty) rhs body =
        T.Let (T.Occ name (theta >>! ty)) (theta >>! rhs) (theta >>! body)
    theta >>! T.Call ty e1 e2 =
        T.Call (theta >>! ty) (theta >>! e1) (theta >>! e2)
    _ >>! expr@(T.Int _) = expr
    theta >>! T.Tup es = T.Tup (map (theta >>!) es)
    theta >>! T.Var (T.Occ name ty) = T.Var (T.Occ name (theta >>! ty))
    theta >>! T.Constr ty n = T.Constr (theta >>! ty) n

-- | Composition of substitutions: @s2 <> s1@ behaves like applying @s1@
-- first and @s2@ afterwards. Where both bind the same variable, the
-- (@s2@-substituted) binding from @s1@ wins because 'Map.union' is
-- left-biased.
instance Semigroup Subst where
    s2@(Subst m2) <> Subst m1 = Subst (Map.union (Map.map (s2 >>!) m1) m2)

instance Monoid Subst where
    mempty = Subst mempty


-- | The environment without any definitions, types or aliases.
emptyEnv :: Env
emptyEnv = Env mempty mempty mempty

-- | Add a definition to the environment. Calls 'error' if the name is
-- already present.
envAddDef :: Name -> T.TypeScheme -> Env -> Env
envAddDef name sty (Env mp tmp aliases)
    | name `Map.member` mp = error "envAddDef on name already in environment"
    | otherwise = Env (Map.insert name sty mp) tmp aliases

-- | Look up the type scheme of a definition in scope.
envFindDef :: Name -> Env -> Maybe T.TypeScheme
envFindDef name (Env mp _ _) = Map.lookup name mp

-- | Add type definitions to the environment. Calls 'error' if any of the
-- new names is already defined.
envAddTypes :: Map Name T.TypeDef -> Env -> Env
envAddTypes l (Env mp tdefs aliases) =
    let combined = l <> tdefs
    -- The union is smaller than the sum of the sizes iff some key occurs
    -- in both maps.
    in if Map.size combined == Map.size l + Map.size tdefs
           then Env mp combined aliases
           else error "envAddTypes on duplicate type names"

-- | Look up a type definition by name.
envFindType :: Name -> Env -> Maybe T.TypeDef
envFindType name (Env _ tdefs _) = Map.lookup name tdefs

-- | Add type aliases to the environment. Calls 'error' if any of the new
-- names is already an alias.
envAddAliases :: Map Name S.AliasDef -> Env -> Env
envAddAliases l (Env mp tdefs aliases) =
    let combined = l <> aliases
    -- Same size-based duplicate detection as in 'envAddTypes'.
    in if Map.size combined == Map.size l + Map.size aliases
           then Env mp tdefs combined
           else error "envAddAliases on duplicate type names"

-- | All type aliases in the environment.
envAliases :: Env -> Map Name S.AliasDef
envAliases (Env _ _ aliases) = aliases


-- | The substitution that maps a single variable to a type.
substVar :: Int -> T.Type -> Subst
substVar var ty = Subst (Map.singleton var ty)

-- | Quantify over all instantiable variables that are free in the type but
-- not free in the environment.
generalise :: Env -> T.Type -> T.TypeScheme
generalise env ty =
    T.TypeScheme (Set.toList (freeInstTypeVars ty Set.\\ freeInstTypeVars env)) ty

-- | Replace every variable bound by the scheme with a variable obtained
-- from 'genTyVar'.
instantiate :: T.TypeScheme -> TM T.Type
instantiate (T.TypeScheme bnds ty) = do
    vars <- traverse (const genTyVar) bnds
    let theta = Subst (Map.fromList (zip bnds vars))
    return (theta >>! ty)

-- | Generalise the type with respect to the environment and immediately
-- instantiate the resulting scheme again.
freshenFrees :: Env -> T.Type -> TM T.Type
freshenFrees env = instantiate . generalise env

-- | Turn every type variable in the type into a rigid one, keeping its id.
replaceRigid :: T.Type -> T.Type
replaceRigid (T.TFun t1 t2) = T.TFun (replaceRigid t1) (replaceRigid t2)
replaceRigid T.TInt = T.TInt
replaceRigid (T.TTup ts) = T.TTup (map replaceRigid ts)
replaceRigid (T.TNamed n ts) = T.TNamed n (map replaceRigid ts)
replaceRigid (T.TUnion ts) = T.TUnion (Set.map replaceRigid ts)
replaceRigid (T.TyVar _ v) = T.TyVar T.Rigid v


-- | Source location and the two top-level types being unified; used to
-- build the 'TypeError' when unification fails somewhere inside them.
data UnifyContext = UnifyContext SourceRange T.Type T.Type

-- | Compute a substitution that makes the two types equal, or throw a
-- 'TypeError' mentioning both types and the given source range.
unify :: SourceRange -> T.Type -> T.Type -> TM Subst
unify sr t1 t2 = unify' (UnifyContext sr t1 t2) t1 t2

-- | Worker for 'unify'; the context is only used for error reporting.
unify' :: UnifyContext -> T.Type -> T.Type -> TM Subst
unify' _ T.TInt T.TInt = return mempty
unify' ctx (T.TFun t1 t2) (T.TFun u1 u2) = unifyMany ctx [t1, t2] [u1, u2]
unify' ctx (T.TTup ts) (T.TTup us) = unifyMany ctx ts us
unify' ctx (T.TNamed n1 ts) (T.TNamed n2 us)
    | n1 == n2 = unifyMany ctx ts us
-- Unions are only accepted when they are already identical.
-- TODO: proper unification of union types
unify' _ (T.TUnion ts) (T.TUnion us)
    | ts == us = return mempty
-- A rigid variable only unifies with itself. This case is reachable because
-- 'unifyMany' substitutes earlier results into later components, e.g. when
-- checking @a -> a@ against @r -> r@ with @r@ rigid.
unify' _ (T.TyVar T.Rigid v1) (T.TyVar T.Rigid v2)
    | v1 == v2 = return mempty
unify' ctx (T.TyVar T.Instantiable var) ty = bindVar ctx var ty
unify' ctx ty (T.TyVar T.Instantiable var) = bindVar ctx var ty
unify' (UnifyContext sr t1 t2) _ _ = throwError (TypeError sr t1 t2)

-- | Unify two lists of types pairwise from left to right. The substitution
-- found for each pair is applied to the remaining pairs before they are
-- unified, so that conflicting requirements on one variable are detected
-- instead of being silently dropped when the results are composed. Lists of
-- different lengths are a type error.
unifyMany :: UnifyContext -> [T.Type] -> [T.Type] -> TM Subst
unifyMany _ [] [] = return mempty
unifyMany ctx (t : ts) (u : us) = do
    theta1 <- unify' ctx t u
    theta2 <- unifyMany ctx (map (theta1 >>!) ts) (map (theta1 >>!) us)
    return (theta2 <> theta1)
unifyMany (UnifyContext sr t1 t2) _ _ = throwError (TypeError sr t1 t2)

-- | Bind an instantiable variable to a type. Binding a variable to itself
-- yields the empty substitution; binding it to a type that contains it
-- (occurs check) is a type error, since that would describe an infinite
-- type.
bindVar :: UnifyContext -> Int -> T.Type -> TM Subst
bindVar (UnifyContext sr t1 t2) var ty
    | T.TyVar T.Instantiable var' <- ty, var' == var = return mempty
    | var `Set.member` freeInstTypeVars ty = throwError (TypeError sr t1 t2)
    | otherwise = return (substVar var ty)


-- | Infer the type of an expression. The returned expression already has
-- the returned substitution applied to it.
infer :: Env -> S.Expr -> TM (Subst, T.Expr)
infer env expr = case expr of
    -- A lambda without arguments is just its body.
    S.Lam _ [] body -> infer env body

    -- A multi-argument lambda is curried into nested one-argument lambdas.
    S.Lam sr args@(_:_:_) body -> infer env (foldr (S.Lam sr . pure) body args)

    S.Lam _ [(arg, _)] body -> do
        argVar <- genTyVar
        let augEnv = envAddDef arg (T.TypeScheme [] argVar) env
        (theta, body') <- infer augEnv body
        let argType = theta >>! argVar
        return (theta, T.Lam (T.TFun argType (T.exprType body')) (T.Occ arg argType) body')

    -- The bound variable is monomorphic (empty quantifier list).
    S.Let _ (name, _) rhs body -> do
        (theta1, rhs') <- infer env rhs
        let varType = T.exprType rhs'
            -- The body is checked under the environment refined by what was
            -- learned from the right-hand side.
            augEnv = envAddDef name (T.TypeScheme [] varType) (theta1 >>! env)
        (theta2, body') <- infer augEnv body
        -- Whatever the body taught us also refines the right-hand side.
        return (theta2 <> theta1
               ,T.Let (T.Occ name (theta2 >>! varType)) (theta2 >>! rhs') body')

    S.Call sr func arg -> do
        (theta1, func') <- infer env func
        (theta2, arg') <- infer (theta1 >>! env) arg
        resVar <- genTyVar
        theta3 <- unify sr (theta2 >>! T.exprType func') (T.TFun (T.exprType arg') resVar)
        return (theta3 <> theta2 <> theta1
               ,T.Call (theta3 >>! resVar)
                       ((theta3 <> theta2) >>! func')  -- TODO: quadratic complexity
                       (theta3 >>! arg'))              -- TODO: quadratic complexity

    S.Int _ val -> return (mempty, T.Int val)

    -- The inner fmap maps over the second component of the result pair.
    S.Tup _ es -> fmap T.Tup <$> inferList env es

    S.Var sr name
        | Just sty <- envFindDef name env -> do
            ty <- instantiate sty
            return (mempty, T.Var (T.Occ name ty))
        | otherwise ->
            throwError (RefError sr name)

    S.Constr sr name ->
        case envFindType name env of
            Just (T.TypeDef typname params typ) -> do
                -- NOTE(review): only the result type is given fresh
                -- variables here; 'typ' keeps whatever variables the type
                -- definition uses. If 'typ' mentions 'params', argument and
                -- result type are no longer linked -- confirm whether both
                -- should be instantiated together.
                restyp <- freshenFrees emptyEnv (T.TNamed typname (map (T.TyVar T.Instantiable) params))
                return (mempty, T.Constr (T.TFun typ restyp) name)
            _ -> throwError (RefError sr name)

    S.Annot sr subex ty -> do
        (theta1, subex') <- infer env subex
        ty' <- convertType (envAliases env) sr ty
        -- Make sure the type of the subexpression matches the type with rigid
        -- variables, then make it instantiable variables instead for the rest
        -- of the code.
        void $ unify sr (T.exprType subex') (replaceRigid ty')
        theta2 <- unify sr (T.exprType subex') ty'
        return (theta2 <> theta1, theta2 >>! subex')  -- TODO: quadratic complexity

-- | Infer the types of a list of expressions from left to right, passing
-- the substitution of each expression on to the environment of the next.
-- As with 'infer', the returned expressions have the returned substitution
-- applied.
inferList :: Env -> [S.Expr] -> TM (Subst, [T.Expr])
inferList _ [] = return (mempty, [])
inferList env (expr : exprs) = do
    (theta, expr') <- infer env expr
    (theta', res) <- inferList (theta >>! env) exprs
    -- theta was computed first, so it goes on the right of the composition;
    -- the later substitution must still be applied to the earlier result.
    return (theta' <> theta, (theta' >>! expr') : res)


-- | Type check a program, starting from an environment that contains the
-- generalised types of the builtins in the context.
runPass :: Context -> S.Program -> Either TCError T.Program
runPass (Context _ (Builtins builtins _)) prog =
    let env = Env (Map.map (generalise emptyEnv) builtins) mempty mempty
    in runTM (typeCheck env prog)

-- | Collect aliases, type definitions and annotated function signatures
-- into the environment, then check every function definition against it.
typeCheck :: Env -> S.Program -> TM T.Program
typeCheck startEnv (S.Program decls) = do
    -- NOTE(review): the traceM calls in this function look like leftover
    -- debugging output -- confirm whether they should stay.
    traceM (show decls)

    -- NOTE(review): Map.fromList keeps only the last of several aliases
    -- with the same name, so such duplicates are not reported here.
    let aliasdefs = [(n, def) | S.DeclAlias def@(S.AliasDef (n, _) _ _) <- decls]
        env1 = envAddAliases (Map.fromList aliasdefs) startEnv

    typedefs' <- checkTypedefs (envAliases env1) [def | S.DeclType def <- decls]
    let typedefsMap = Map.fromList [(n, def) | def@(T.TypeDef n _ _) <- typedefs']

    let funcdefs = [def | S.DeclFunc def <- decls]
    -- Only functions that carry a type annotation are added to the
    -- environment up front.
    typedfuncs <- sequence [(name,) <$> convertType (envAliases env1) sr ty
                           | S.FuncDef (Just ty) (name, sr) _ _ <- funcdefs]

    let env2 = envAddTypes typedefsMap env1
    traceM (show typedefsMap)

    let env = foldl' (\env' (name, ty) -> envAddDef name (generalise env' ty) env')
                     env2 typedfuncs
    traceM (show env)

    funcdefs' <- mapM (typeCheckFunc env) funcdefs
    return (T.Program funcdefs' typedefsMap)

-- | Type check one function definition. Arguments are first folded into a
-- lambda, then an annotation (if any) is folded into an 'S.Annot' node
-- around the body, and finally the resulting expression is inferred.
typeCheckFunc :: Env -> S.FuncDef -> TM T.Def
typeCheckFunc env (S.FuncDef mannot (name, sr) args@(_:_) body) =
    typeCheckFunc env (S.FuncDef mannot (name, sr) [] (S.Lam sr args body))
typeCheckFunc env (S.FuncDef (Just annot) (name, sr) [] body) =
    typeCheckFunc env (S.FuncDef Nothing (name, sr) [] (S.Annot sr body annot))
typeCheckFunc env (S.FuncDef Nothing (name, _) [] body) = do
    (_, body') <- infer env body
    return (T.Def name body')