-- | Type-checking pass: turns an untyped 'Program' into a typed 'ProgramT',
-- or reports the first 'TypeError' encountered.
module CC.Typecheck(runPass) where

import Control.Monad.State.Strict
import Data.List (intersect)
import qualified Data.Map.Strict as Map

import CC.Pretty
import CC.Source
import CC.Typed


-- | A type mismatch: the source range of the offending expression, the type
-- it actually has, and the type it was expected to have (in that order).
data TypeError = TypeError SourceRange TypeT TypeT
  deriving (Show)

instance Pretty TypeError where
  pretty (TypeError sr real expect) =
    "Type error: Expression at " ++ pretty sr ++
    " has type " ++ pretty real ++
    ", but should have type " ++ pretty expect


-- | State monad transformer carrying an 'Int' counter used to mint fresh ids.
type IdSupplyT m a = StateT Int m a

-- | Return the current counter value and advance the counter by one.
genId :: Monad m => IdSupplyT m Int
genId = state (\idval -> (idval, idval + 1))

-- | Produce a type variable ('TyVar') built from a fresh id.
genTyVar :: Monad m => IdSupplyT m TypeT
genTyVar = TyVar <$> genId


-- | The type-checker monad: a fresh-id supply layered over
-- @Either TypeError@, so a 'Left' aborts the whole computation.
type TM a = IdSupplyT (Either TypeError) a

-- | Run a type-checker computation with the id counter starting at 1.
runTM :: TM a -> Either TypeError a
runTM m = evalStateT m 1


-- | Entry point of the pass. The 'Context' argument is currently ignored.
runPass :: Context -> Program -> Either TypeError ProgramT
runPass _ prog = runTM (typeCheck prog)


-- | Type-check every top-level declaration, in order.
typeCheck :: Program -> TM ProgramT
typeCheck (Program decls) = ProgramT <$> mapM typeCheckDL decls

-- | Type-check one top-level declaration.
-- NOTE(review): only the 'Def' constructor is handled here; if 'Decl' has
-- other constructors this match is partial -- confirm against CC.Source.
typeCheckDL :: Decl -> TM DeclT
typeCheckDL (Def def) = DefT <$> typeCheckD def

-- | Type-check a function definition. The resulting 'FunctionT' carries the
-- inferred type of the body, the function name, and the first component of
-- each argument pair.
--
-- The first field of 'Function', the second component of the name pair and
-- the variable mapping inferred for the body are not used yet.
typeCheckD :: Def -> TM DefT
typeCheckD (Function _ (fname, _) args body) = do
  (body', _) <- typeCheckE body
  return (FunctionT (exprType body') fname (map fst args) body')

-- | Type-check an expression. Returns the typed expression together with a
-- mapping of the variables occurring in it.
typeCheckE :: Expr -> TM (ExprT, Mapping)
typeCheckE (Call _ func arg) = do
  (func', m1) <- typeCheckE func
  (arg', m2) <- typeCheckE arg
  m <- combine m1 m2
  let functype = exprType func'
      argtype = exprType arg'
  -- Fresh variable standing for the result of the call: the function's type
  -- is required to be @TFunT tvar argtype@.
  tvar <- genTyVar
  apply <- unify (range func) functype (TFunT tvar argtype)
  let -- The call expression has the function's *result* type, not the type
      -- of the function itself.
      restype = apply tvar
      func'' = down apply func'
      arg'' = down apply arg'
      -- Keep the recorded variable types in step with what unification
      -- just learned.
      m' = Map.map (\(ty, pos) -> (apply ty, pos)) m
  return (CallT restype func'' arg'', m')
typeCheckE (Int _ val) = return (IntT val, mempty)
typeCheckE expr@(Var _ name) = do
  -- Each variable occurrence gets a fresh type variable, and is recorded in
  -- the mapping together with the position of this occurrence.
  -- NOTE(review): assumes the variable's name has type 'Name' -- confirm.
  tvar <- genTyVar
  return (VarT (Occ name tvar), Map.singleton name (tvar, range expr))


-- For each variable, its inferred type and the position of its first
-- occurrence in a program fragment.
type Mapping = Map.Map Name (TypeT, SourceRange)

-- | Merge the variable mappings of two program fragments.
--
-- Variables present in only one mapping are kept unchanged. For a variable
-- present in both, the two recorded types must be equal; the entry of the
-- first mapping (and hence its source position) is kept. If the types
-- differ, the computation fails with a 'TypeError' located at the second
-- mapping's occurrence, reporting its type as the actual one and the first
-- mapping's type as the expected one.
combine :: Mapping -> Mapping -> TM Mapping
combine mp1 mp2 = do
  -- Validate every shared variable, in ascending key order; the first
  -- mismatch aborts via the underlying 'Either'.
  lift (sequence_ (Map.intersectionWith agree mp1 mp2))
  -- 'Map.union' is left-biased, so shared variables keep mp1's entry.
  return (Map.union mp1 mp2)
  where
    agree (t1, _) (t2, sr2)
      | t1 == t2 = Right ()
      | otherwise = Left (TypeError sr2 t2 t1)


-- | Unify two types; the 'SourceRange' is the position supplied by the
-- caller. The result is a function from types to types that the caller
-- applies to existing types after unification.
--
-- TODO: not implemented yet; evaluating it raises an error.
unify :: SourceRange -> TypeT -> TypeT -> TM (TypeT -> TypeT)
unify = error "CC.Typecheck.unify: not implemented"