-- | Well-formedness checks for source-level type definitions and type
-- aliases: duplicate definition names, the arity of applied type names, and
-- the scoping of type variables. Valid definitions are then converted to
-- their typed form.
module CC.Typecheck.Typedefs (checkTypedefs) where

import Control.Monad (unless)
import Control.Monad.Except
import Data.Foldable (traverse_)
import qualified Data.Map.Strict as Map
import Data.Map.Strict (Map)
import qualified Data.Set as Set

import qualified CC.AST.Source as S
import qualified CC.AST.Typed as T
import CC.Typecheck.Types
import CC.Types

-- | Check that the body of a single definition is well formed:
--
-- * every applied type name is known (present in @typeArity@) and is given
--   exactly as many arguments as its recorded arity, and
-- * every type variable is one of the definition's own parameters.
--
-- Errors ('TypeArityError', 'RefError') carry the @sr@ value stored alongside
-- the definition's name.
checkArity :: Map Name Int -> S.TypeDef -> TM ()
checkArity typeArity (S.TypeDef (_, sr) args ty) = go ty
  where
    -- Parameter lists are expected to be short, so a linear 'elem' is fine.
    argNames = map fst args

    go :: S.Type -> TM ()
    go (S.TFun t1 t2) = go t1 >> go t2
    go S.TInt = return ()
    go (S.TTup ts) = traverse_ go ts
    go (S.TNamed n ts)
      | Just arity <- Map.lookup n typeArity =
          if length ts == arity
            then traverse_ go ts
            else throwError (TypeArityError sr n arity (length ts))
      | otherwise = throwError (RefError sr n)
    go (S.TUnion ts) = traverse_ go ts
    go (S.TyVar n)
      | n `elem` argNames = return ()
      | otherwise = throwError (RefError sr n)

-- | Validate the type definitions together with the type aliases, then
-- convert each type definition with 'convertTypeDef'.
--
-- Fails with 'DupTypeError' if some type definition name is declared more
-- than once (the smallest such name is reported), and with the errors of
-- 'checkArity' if any alias body or type definition body is ill formed.
-- Both alias names and type definition names are in scope for the arity
-- check; if a name is both, the alias arity is used because '<>' on 'Map'
-- is a left-biased union.
checkTypedefs :: Map Name S.AliasDef -> [S.TypeDef] -> TM [T.TypeDef]
checkTypedefs aliases origdefs = do
    let aliasArity = Map.map (\(S.AliasDef _ args _) -> length args) aliases
        typeArity =
          Map.fromList [(n, length args) | S.TypeDef (n, _) args _ <- origdefs]

    -- Count occurrences explicitly: 'Map.fromList' (used for 'typeArity')
    -- silently collapses repeated keys, so its key set is always exactly the
    -- set of names and cannot reveal duplicates.
    let nameCounts =
          Map.fromListWith (+) [(n, 1 :: Int) | S.TypeDef (n, _) _ _ <- origdefs]
        dups = Map.keysSet (Map.filter (> 1) nameCounts)
    unless (Set.null dups) $
      throwError (DupTypeError (Set.findMin dups))

    -- Aliases are checked with the same machinery by viewing them as typedefs.
    let aliasdefs =
          [S.TypeDef name args typ | S.AliasDef name args typ <- Map.elems aliases]
    mapM_ (checkArity (aliasArity <> typeArity)) (aliasdefs ++ origdefs)

    mapM (convertTypeDef aliases) origdefs