aboutsummaryrefslogtreecommitdiff
path: root/typecheck/CC/Typecheck.hs
blob: 47f42e3deff8feec7a8bdb0f16c5ea206785416e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
module CC.Typecheck(runPass) where

import Control.Monad.State.Strict
import Data.List (intersect)
import qualified Data.Map.Strict as Map

import CC.Pretty
import CC.Source
import CC.Typed


-- | A type mismatch: the expression at the given source range was found to
-- have the first type, but the context required the second one.
data TypeError
    = TypeError SourceRange TypeT TypeT  -- ^ location, actual type, expected type
    deriving (Show)

-- | Render a type error as a single human-readable line.
instance Pretty TypeError where
    pretty (TypeError location actual expected) = concat
        [ "Type error: Expression at ", pretty location
        , " has type ", pretty actual
        , ", but should have type ", pretty expected
        ]

-- | A monad transformer that threads a counter used to hand out fresh
-- integer identifiers.
type IdSupplyT m = StateT Int m

-- | Return the current counter value as a fresh identifier and advance the
-- counter by one.
--
-- The successor is forced before it is stored. With the lazy
-- @state (\\i -> (i, i + 1))@ formulation, each call would leave an
-- unevaluated @(+ 1)@ thunk in the state, and those thunks chain up until
-- some identifier is finally inspected.
genId :: Monad m => IdSupplyT m Int
genId = do
    idval <- get
    put $! idval + 1
    return idval

-- | Make a type variable whose identifier is freshly drawn from 'genId'.
genTyVar :: Monad m => IdSupplyT m TypeT
genTyVar = do
    fresh <- genId
    return (TyVar fresh)

-- | The type-checking monad: a fresh-identifier supply layered over
-- 'Either', so producing a 'TypeError' aborts the remaining computation.
type TM a = IdSupplyT (Either TypeError) a

-- | Run a type-checking computation with fresh identifiers numbered from 1,
-- discarding the final counter. The result is either the computed value or
-- the 'TypeError' that aborted it.
runTM :: TM a -> Either TypeError a
runTM = flip evalStateT 1


-- | Entry point of the type-checking pass. The 'Context' argument is
-- currently ignored.
runPass :: Context -> Program -> Either TypeError ProgramT
runPass _ = runTM . typeCheck

-- | Type-check every top-level declaration of a program, in order.
typeCheck :: Program -> TM ProgramT
typeCheck (Program decls) = fmap ProgramT (traverse typeCheckDL decls)

-- | Type-check a single top-level declaration.
typeCheckDL :: Decl -> TM DeclT
typeCheckDL (Def def) = do
    checked <- typeCheckD def
    return (DefT checked)

-- | Type-check a function definition.
--
-- The resulting 'FunctionT' is tagged with the type of the checked body and
-- keeps only the argument names (their paired second components are dropped).
-- The first constructor field (presumably an optional type annotation) and
-- the source range of the function name are ignored for now.
--
-- NOTE(review): the variable mapping returned for the body is discarded, so
-- argument types are not yet connected to their uses in the body; confirm
-- this is a known TODO.
typeCheckD :: Def -> TM DefT
typeCheckD (Function _ (fname, _) args body) = do
    (typedBody, _) <- typeCheckE body
    let argNames = [argName | (argName, _) <- args]
    return (FunctionT (exprType typedBody) fname argNames typedBody)

-- | Infer the type of an expression.
--
-- Returns the typed expression together with a 'Mapping' from each variable
-- occurring in the expression to its inferred type and the source range of
-- its first occurrence.
typeCheckE :: Expr -> TM (ExprT, Mapping)
typeCheckE (Call _ func arg) = do
    (func', m1) <- typeCheckE func
    (arg', m2) <- typeCheckE arg
    m <- combine m1 m2

    let functype = exprType func'
        argtype = exprType arg'
    -- Fresh type variable standing for the result type of the call.
    tvar <- genTyVar
    -- NOTE(review): assumes 'TFunT' takes (argument type, result type), like
    -- @a -> b@; confirm against CC.Typed.
    apply <- unify (range func) functype (TFunT argtype tvar)
    let restype = apply tvar
        func'' = down apply func'
        arg'' = down apply arg'
        -- Keep the recorded variable types consistent with the substitution.
        m' = Map.map (\(t, r) -> (apply t, r)) m

    return (CallT restype func'' arg'', m')
typeCheckE (Int _ val) = return (IntT val, mempty)
typeCheckE expr@(Var _ name) = do
    tvar <- genTyVar
    -- Record this occurrence so enclosing expressions can check that every
    -- use of the same variable agrees on its type.
    return (VarT (Occ name tvar), Map.singleton name (tvar, range expr))

-- | For each variable occurring in a program fragment: its inferred type and
-- the source range of its first occurrence within that fragment.
type Mapping = Map.Map Name (TypeT, SourceRange)

-- | Merge the variable mappings of two adjacent program fragments, the left
-- one occurring first in the source.
--
-- A variable present in both mappings must have the same type in each.
-- Otherwise a 'TypeError' is raised for the variable with the smallest name,
-- located at its occurrence in the right fragment and expecting the type from
-- the left one. For shared variables the left entry (and so the earlier
-- source range) is kept.
--
-- NOTE(review): every 'Var' occurrence currently gets its own fresh type
-- variable, so with structural equality two uses of one variable presumably
-- always conflict here; this likely needs to unify instead of compare.
combine :: Mapping -> Mapping -> TM Mapping
combine mp1 mp2 = do
    -- 'Map.intersectionWith' visits only the shared names, in ascending key
    -- order, so the first disagreement found is the one reported.
    lift (sequence_ (Map.intersectionWith agree mp1 mp2))
    -- 'Map.union' is left-biased, so shared names keep the left entry.
    return (Map.union mp1 mp2)
  where
    agree (t1, _) (t2, sr2)
        | t1 == t2 = Right ()
        | otherwise = Left (TypeError sr2 t2 t1)

-- | Unify two types.
--
-- Callers use the result as a function applying the found substitution to a
-- type. The 'SourceRange' is presumably where a resulting 'TypeError' should
-- be reported.
--
-- Not implemented yet. The explicit 'error' makes any call fail with a
-- message naming what is missing, rather than a bare @Prelude.undefined@.
unify :: SourceRange -> TypeT -> TypeT -> TM (TypeT -> TypeT)
unify _ _ _ = error "CC.Typecheck.unify: not implemented yet"