aboutsummaryrefslogtreecommitdiff
path: root/typecheck/CC/Typecheck/Typedefs.hs
diff options
context:
space:
mode:
Diffstat (limited to 'typecheck/CC/Typecheck/Typedefs.hs')
-rw-r--r--typecheck/CC/Typecheck/Typedefs.hs50
1 files changed, 50 insertions, 0 deletions
diff --git a/typecheck/CC/Typecheck/Typedefs.hs b/typecheck/CC/Typecheck/Typedefs.hs
new file mode 100644
index 0000000..ad9bdd8
--- /dev/null
+++ b/typecheck/CC/Typecheck/Typedefs.hs
@@ -0,0 +1,50 @@
+module CC.Typecheck.Typedefs(checkTypedefs) where
+
+import Control.Monad.Except
+import Data.Foldable (traverse_)
+import qualified Data.Map.Strict as Map
+import Data.Map.Strict (Map)
+import qualified Data.Set as Set
+
+import qualified CC.AST.Source as S
+import qualified CC.AST.Typed as T
+import CC.Typecheck.Types
+import CC.Types
+
+
+checkArity :: Map Name Int -> S.TypeDef -> TM ()
+checkArity typeArity (S.TypeDef (_, sr) args ty) = go ty
+ where
+ argNames = map fst args -- probably a small list
+
+ go :: S.Type -> TM ()
+ go (S.TFun t1 t2) = go t1 >> go t2
+ go S.TInt = return ()
+ go (S.TTup ts) = mapM_ go ts
+ go (S.TNamed n ts)
+ | Just arity <- Map.lookup n typeArity =
+ if length ts == arity
+ then mapM_ go ts
+ else throwError (TypeArityError sr n arity (length ts))
+ | otherwise = throwError (RefError sr n)
+ go (S.TUnion ts) = traverse_ go ts
+ go (S.TyVar n)
+ | n `elem` argNames = return ()
+ | otherwise = throwError (RefError sr n)
+
+checkTypedefs :: Map Name S.AliasDef -> [S.TypeDef] -> TM [T.TypeDef]
+checkTypedefs aliases origdefs = do
+ let aliasArity = Map.map (\(S.AliasDef _ args _) -> length args) aliases
+ typeArity = Map.fromList [(n, length args)
+ | S.TypeDef (n, _) args _ <- origdefs]
+
+ let dups = Set.fromList (map (\(S.TypeDef (n, _) _ _) -> n) origdefs)
+ Set.\\ Map.keysSet typeArity
+ when (not (Set.null dups)) $
+ throwError (DupTypeError (Set.findMin dups))
+
+ let aliasdefs = [S.TypeDef name args typ
+ | S.AliasDef name args typ <- Map.elems aliases]
+
+ mapM_ (checkArity (aliasArity <> typeArity)) (aliasdefs ++ origdefs)
+ mapM (convertTypeDef aliases) origdefs