aboutsummaryrefslogtreecommitdiff
path: root/typecheck/CC/Typecheck/Typedefs.hs
blob: ad9bdd83f4cb86315361edca9f1a5d305df2d68e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
module CC.Typecheck.Typedefs(checkTypedefs) where

import Control.Monad.Except
import Data.Foldable (traverse_)
import qualified Data.Map.Strict as Map
import Data.Map.Strict (Map)
import qualified Data.Set as Set

import qualified CC.AST.Source as S
import qualified CC.AST.Typed as T
import CC.Typecheck.Types
import CC.Types


-- | Verify the body of one type definition against a table of known type
-- arities.
--
-- Every named type must be a key of @typeArity@ and be applied to exactly
-- that many arguments. Every type variable must be one of the definition's
-- own parameters. The first violation met during the traversal is thrown as
-- a 'RefError' or 'TypeArityError'. Either error is located at the position
-- paired with the definition's name (@sr@), not at the offending sub-type.
checkArity :: Map Name Int -> S.TypeDef -> TM ()
checkArity typeArity (S.TypeDef (_, sr) args ty) = visit ty
  where
    -- Names of the definition's parameters. These are presumably short
    -- lists, so a linear membership test is adequate.
    params = map fst args

    visit :: S.Type -> TM ()
    visit S.TInt = return ()
    visit (S.TFun argTy resTy) = visit argTy >> visit resTy
    visit (S.TTup elems) = traverse_ visit elems
    visit (S.TUnion alts) = traverse_ visit alts
    visit (S.TNamed name tyArgs) =
        case Map.lookup name typeArity of
            Nothing -> throwError (RefError sr name)
            Just expected ->
                let given = length tyArgs
                in if given /= expected
                       then throwError (TypeArityError sr name expected given)
                       else traverse_ visit tyArgs
    visit (S.TyVar var)
        | var `notElem` params = throwError (RefError sr var)
        | otherwise = return ()

-- | Validate a list of type definitions, together with the type aliases in
-- scope, and convert the definitions to their typed form.
--
-- * Throws 'DupTypeError' if a name is defined more than once in
--   @origdefs@. The smallest such name is the one reported.
-- * Runs 'checkArity' on every alias and every definition against the
--   combined arity table of aliases and definitions.
-- * Returns 'convertTypeDef' applied to each element of @origdefs@, in
--   order. Aliases are checked but not converted here.
checkTypedefs :: Map Name S.AliasDef -> [S.TypeDef] -> TM [T.TypeDef]
checkTypedefs aliases origdefs = do
    let aliasArity = Map.map (\(S.AliasDef _ args _) -> length args) aliases
        typeArity = Map.fromList [(n, length args)
                                 | S.TypeDef (n, _) args _ <- origdefs]

    -- 'Map.fromList' keeps only the last binding for a repeated key, so
    -- the keys of 'typeArity' cannot reveal duplicates. Count how many
    -- times each name is defined instead.
    let defCounts = Map.fromListWith (+)
                        [(n, 1 :: Int) | S.TypeDef (n, _) _ _ <- origdefs]
        dups = Map.keysSet (Map.filter (> 1) defCounts)
    unless (Set.null dups) $
        throwError (DupTypeError (Set.findMin dups))

    -- Reuse 'checkArity' for aliases by viewing each one as a TypeDef.
    let aliasdefs = [S.TypeDef name args typ
                    | S.AliasDef name args typ <- Map.elems aliases]

    -- NOTE(review): '<>' on 'Map' is a left-biased union. A name defined
    -- both as an alias and as a type is therefore not rejected here, and it
    -- is checked with the alias's arity. Confirm this collision is handled
    -- elsewhere.
    mapM_ (checkArity (aliasArity <> typeArity)) (aliasdefs ++ origdefs)
    mapM (convertTypeDef aliases) origdefs