aboutsummaryrefslogtreecommitdiff
path: root/typecheck/CC/Typecheck/Types.hs
blob: dc0740d7b96fa3d2b59e1893d1df90b2c0b79212 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
module CC.Typecheck.Types where

import Control.Monad.State.Strict
import Control.Monad.Except
import qualified Data.Map.Strict as Map
import Data.Map.Strict (Map)
import Data.Maybe (fromMaybe)
import qualified Data.Set as Set
import Data.Set (Set)

import qualified CC.AST.Source as S
import qualified CC.AST.Typed as T
import CC.Pretty
import CC.Types


-- | Errors the type-checking phase can report.
data TCError
    = UnifyError SourceRange T.Type T.Type T.Type T.Type (Maybe UnifyReason)
      -- ^ Expression location, its actual type, the expected type, the two
      -- types being unified when unification failed, and an optional reason.
    | RefError SourceRange Name
      -- ^ A variable name referenced out of scope at the given location.
    | TypeArityError SourceRange Name Int Int
      -- ^ Location, type name, its declared arity, and the number of type
      -- arguments it was actually given.
    | DupTypeError Name
      -- ^ A type name that is defined more than once.
    deriving (Show)

-- | Optional extra explanation attached to a 'UnifyError'.
data UnifyReason
    = URNotInUnion
      -- ^ The type was not found among the members of a union.
    | URAmbiguousWeakening
      -- ^ The type unifies with more than one member of a union.
    deriving (Show)

-- | Human-readable, single-line rendering of each type-checking error.
instance Pretty TCError where
    pretty err = case err of
        UnifyError sr real expect unifyt1 unifyt2 mreason -> concat
            [ "Type error: Expression at ", pretty sr
            , " has type ", pretty real
            , ", but should have type ", pretty expect
            , " (when unifying ", pretty unifyt1, " and ", pretty unifyt2, ")"
            -- The reason suffix is omitted entirely when no reason is given.
            , foldMap (\r -> " (reason: " ++ pretty r ++ ")") mreason
            ]
        RefError sr name -> concat
            [ "Reference error: Variable '", name
            , "' out of scope at ", pretty sr
            ]
        TypeArityError sr name wanted got -> concat
            [ "Type error: Type '", name, "' has arity ", show wanted
            , " but gets ", show got, " type arguments at ", pretty sr
            ]
        DupTypeError name -> concat
            [ "Duplicate types: Type '", name, "' defined multiple times" ]

-- | Short phrase used as the "(reason: ...)" suffix of a unification error.
instance Pretty UnifyReason where
    pretty reason = case reason of
        URNotInUnion         -> "type not found in union"
        URAmbiguousWeakening -> "type unifies with multiple items in the union"

-- | Type-checking monad: can fail with a 'TCError' and threads an 'Int'
-- counter used to hand out fresh ids (see 'genId').
type TM = ExceptT TCError (State Int)

-- | Return the current counter value and advance the counter by one.
genId :: TM Int
genId = do
    current <- get
    put (current + 1)
    return current

-- | Allocate a fresh 'T.Instantiable' type variable numbered by 'genId'.
genTyVar :: TM T.Type
genTyVar = do
    ident <- genId
    return (T.TyVar T.Instantiable ident)

-- | Run a 'TM' computation with the id counter starting at 1, discarding
-- the final counter value.
runTM :: TM a -> Either TCError a
runTM = flip evalState 1 . runExceptT


-- | Convert a source type to a typed type, expanding aliases and numbering
-- its free type variables. The name-to-id mapping is discarded.
convertType :: Map Name S.AliasDef -> SourceRange -> S.Type -> TM T.Type
convertType aliases sr ty = do
    (_, converted) <- convertType' aliases Set.empty sr ty
    return converted

-- | Convert a source type definition. The definition's parameter names are
-- always given ids (even if unused in the body), and the resulting
-- 'T.TypeDef' lists those ids in the original parameter order.
convertTypeDef :: Map Name S.AliasDef -> S.TypeDef -> TM T.TypeDef
convertTypeDef aliases (S.TypeDef (name, sr) args ty) = do
    let paramNames = map fst args
    (mapping, body) <- convertType' aliases (Set.fromList paramNames) sr ty
    -- Every parameter name was passed as an extra variable above, so it is
    -- guaranteed to be a key of 'mapping'.
    return (T.TypeDef name (map (mapping Map.!) paramNames) body)

-- | Convert a source type into a typed type.
--
-- First every alias application is expanded (see @rewrite@). Then each
-- variable name in @extraVars@ or free in the expanded type gets a fresh id
-- from 'genId', allocated in ascending name order. The returned map records
-- the id chosen for each of those names.
--
-- Throws 'TypeArityError' (located at @sr@, the range of the whole type)
-- when an alias is applied to the wrong number of arguments. This includes
-- errors inside alias arguments that the alias body never uses.
--
-- NOTE(review): alias expansion does not detect cycles. An alias whose
-- expansion mentions itself makes this loop forever. Confirm that alias
-- definitions are checked for cycles before reaching here.
convertType' :: Map Name S.AliasDef -> Set Name -> SourceRange -> S.Type -> TM (Map Name Int, T.Type)
convertType' aliases extraVars sr origtype = do
    rewritten <- rewrite origtype
    let frees = Set.toList (extraVars <> freeVars rewritten)
    nums <- traverse (const genId) frees
    let mapping = Map.fromList (zip frees nums)
    return (mapping, convert mapping rewritten)
  where
    -- Expand all alias applications. The result mentions no alias names.
    rewrite :: S.Type -> TM S.Type
    rewrite (S.TFun t1 t2) = S.TFun <$> rewrite t1 <*> rewrite t2
    rewrite S.TInt = return S.TInt
    rewrite (S.TTup ts) = S.TTup <$> mapM rewrite ts
    rewrite (S.TNamed n ts)
      | Just (S.AliasDef _ args typ) <- Map.lookup n aliases =
          if length args == length ts
              then do
                  -- Expand the arguments before substituting them. Otherwise
                  -- an argument the alias body never uses would be dropped
                  -- unchecked, hiding arity errors inside it. Re-rewriting
                  -- them below is a no-op: expanded types contain no alias
                  -- names.
                  ts' <- mapM rewrite ts
                  rewrite (subst (Map.fromList (zip (map fst args) ts')) typ)
              else throwError (TypeArityError sr n (length args) (length ts))
      | otherwise =
          S.TNamed n <$> mapM rewrite ts
    rewrite (S.TUnion ts) = S.TUnion . Set.fromList <$> mapM rewrite (Set.toList ts)
    rewrite (S.TyVar n) = return (S.TyVar n)

    -- Substitute type variables (simultaneously) according to the map.
    -- Variables not in the map are left unchanged.
    subst :: Map Name S.Type -> S.Type -> S.Type
    subst mp (S.TFun t1 t2) = S.TFun (subst mp t1) (subst mp t2)
    subst _  S.TInt = S.TInt
    subst mp (S.TTup ts) = S.TTup (map (subst mp) ts)
    subst mp (S.TNamed n ts) = S.TNamed n (map (subst mp) ts)
    subst mp (S.TUnion ts) = S.TUnion (Set.map (subst mp) ts)
    subst mp orig@(S.TyVar n) = fromMaybe orig (Map.lookup n mp)

    -- All type-variable names occurring in a type.
    freeVars :: S.Type -> Set Name
    freeVars (S.TFun t1 t2) = freeVars t1 <> freeVars t2
    freeVars S.TInt = mempty
    freeVars (S.TTup ts) = Set.unions (map freeVars ts)
    freeVars (S.TNamed _ ts) = Set.unions (map freeVars ts)
    freeVars (S.TUnion ts) = Set.unions (map freeVars (Set.toList ts))
    freeVars (S.TyVar n) = Set.singleton n

    -- Structurally translate an alias-free source type, numbering variables
    -- via the supplied mapping.
    convert :: Map Name Int -> S.Type -> T.Type
    convert mp (S.TFun t1 t2) = T.TFun (convert mp t1) (convert mp t2)
    convert _  S.TInt = T.TInt
    convert mp (S.TTup ts) = T.TTup (map (convert mp) ts)
    convert mp (S.TNamed n ts) = T.TNamed n (map (convert mp) ts)
    convert mp (S.TUnion ts) = T.TUnion (Set.map (convert mp) ts)
    -- Safe lookup: 'mapping' is built from every free variable of the
    -- rewritten type, which is the only type passed to 'convert'.
    convert mp (S.TyVar n) = T.TyVar T.Instantiable (mp Map.! n)