From 342c213f3caddd64db0eac5ae146912e00378371 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Sun, 26 Jul 2020 23:02:09 +0200 Subject: WIP refactor and union types, type variables --- typecheck/CC/Typecheck/Typedefs.hs | 50 ++++++++++++++++++ typecheck/CC/Typecheck/Types.hs | 103 +++++++++++++++++++++++++++++++++++++ 2 files changed, 153 insertions(+) create mode 100644 typecheck/CC/Typecheck/Typedefs.hs create mode 100644 typecheck/CC/Typecheck/Types.hs (limited to 'typecheck/CC/Typecheck') diff --git a/typecheck/CC/Typecheck/Typedefs.hs b/typecheck/CC/Typecheck/Typedefs.hs new file mode 100644 index 0000000..ad9bdd8 --- /dev/null +++ b/typecheck/CC/Typecheck/Typedefs.hs @@ -0,0 +1,50 @@ +module CC.Typecheck.Typedefs(checkTypedefs) where + +import Control.Monad.Except +import Data.Foldable (traverse_) +import qualified Data.Map.Strict as Map +import Data.Map.Strict (Map) +import qualified Data.Set as Set + +import qualified CC.AST.Source as S +import qualified CC.AST.Typed as T +import CC.Typecheck.Types +import CC.Types + + +checkArity :: Map Name Int -> S.TypeDef -> TM () +checkArity typeArity (S.TypeDef (_, sr) args ty) = go ty + where + argNames = map fst args -- probably a small list + + go :: S.Type -> TM () + go (S.TFun t1 t2) = go t1 >> go t2 + go S.TInt = return () + go (S.TTup ts) = mapM_ go ts + go (S.TNamed n ts) + | Just arity <- Map.lookup n typeArity = + if length ts == arity + then mapM_ go ts + else throwError (TypeArityError sr n arity (length ts)) + | otherwise = throwError (RefError sr n) + go (S.TUnion ts) = traverse_ go ts + go (S.TyVar n) + | n `elem` argNames = return () + | otherwise = throwError (RefError sr n) + +checkTypedefs :: Map Name S.AliasDef -> [S.TypeDef] -> TM [T.TypeDef] +checkTypedefs aliases origdefs = do + let aliasArity = Map.map (\(S.AliasDef _ args _) -> length args) aliases + typeArity = Map.fromList [(n, length args) + | S.TypeDef (n, _) args _ <- origdefs] + + let dups = Set.fromList (map (\(S.TypeDef (n, _) _ _) -> n) origdefs) + Set.\\ Map.keysSet typeArity + when (not (Set.null dups)) $ + throwError (DupTypeError (Set.findMin dups)) + + let aliasdefs = [S.TypeDef name args typ + | S.AliasDef name args typ <- Map.elems aliases] + + mapM_ (checkArity (aliasArity <> typeArity)) (aliasdefs ++ origdefs) + mapM (convertTypeDef aliases) origdefs diff --git a/typecheck/CC/Typecheck/Types.hs b/typecheck/CC/Typecheck/Types.hs new file mode 100644 index 0000000..3f3c471 --- /dev/null +++ b/typecheck/CC/Typecheck/Types.hs @@ -0,0 +1,103 @@ +module CC.Typecheck.Types where + +import Control.Monad.State.Strict +import Control.Monad.Except +import qualified Data.Map.Strict as Map +import Data.Map.Strict (Map) +import Data.Maybe (fromMaybe) +import qualified Data.Set as Set +import Data.Set (Set) + +import qualified CC.AST.Source as S +import qualified CC.AST.Typed as T +import CC.Pretty +import CC.Types + + +data TCError = TypeError SourceRange T.Type T.Type + | RefError SourceRange Name + | TypeArityError SourceRange Name Int Int + | DupTypeError Name + deriving (Show) + +instance Pretty TCError where + pretty (TypeError sr real expect) = + "Type error: Expression at " ++ pretty sr ++ + " has type " ++ pretty real ++ + ", but should have type " ++ pretty expect + pretty (RefError sr name) = + "Reference error: Variable '" ++ name ++ "' out of scope at " ++ pretty sr + pretty (TypeArityError sr name wanted got) = + "Type error: Type '" ++ name ++ "' has arity " ++ show wanted ++ + " but gets " ++ show got ++ " type arguments at " ++ pretty sr + pretty (DupTypeError name) = + "Duplicate types: Type '" ++ name ++ "' defined multiple times" + +type TM a = ExceptT TCError (State Int) a + +genId :: TM Int +genId = state (\idval -> (idval, idval + 1)) + +genTyVar :: TM T.Type +genTyVar = T.TyVar T.Instantiable <$> genId + +runTM :: TM a -> Either TCError a +runTM m = evalState (runExceptT m) 1 + + +convertType :: Map Name S.AliasDef -> SourceRange -> S.Type -> TM T.Type +convertType aliases sr = fmap snd . convertType' aliases mempty sr + +convertTypeDef :: Map Name S.AliasDef -> S.TypeDef -> TM T.TypeDef +convertTypeDef aliases (S.TypeDef (name, sr) args ty) = do + (mapping, ty') <- convertType' aliases (Set.fromList (map fst args)) sr ty + let args' = [mapping Map.! n | (n, _) <- args] + return (T.TypeDef name args' ty') + +convertType' :: Map Name S.AliasDef -> Set Name -> SourceRange -> S.Type -> TM (Map Name Int, T.Type) +convertType' aliases extraVars sr origtype = do + rewritten <- rewrite origtype + let frees = Set.toList (extraVars <> freeVars rewritten) + nums <- traverse (const genId) frees + let mapping = Map.fromList (zip frees nums) + return (mapping, convert mapping rewritten) + where + rewrite :: S.Type -> TM S.Type + rewrite (S.TFun t1 t2) = S.TFun <$> rewrite t1 <*> rewrite t2 + rewrite S.TInt = return S.TInt + rewrite (S.TTup ts) = S.TTup <$> mapM rewrite ts + rewrite (S.TNamed n ts) + | Just (S.AliasDef _ args typ) <- Map.lookup n aliases = + if length args == length ts + then rewrite (subst (Map.fromList (zip (map fst args) ts)) typ) + else throwError (TypeArityError sr n (length args) (length ts)) + | otherwise = + S.TNamed n <$> mapM rewrite ts + rewrite (S.TUnion ts) = S.TUnion . Set.fromList <$> mapM rewrite (Set.toList ts) + rewrite (S.TyVar n) = return (S.TyVar n) + + -- Substitute type variables + subst :: Map Name S.Type -> S.Type -> S.Type + subst mp (S.TFun t1 t2) = S.TFun (subst mp t1) (subst mp t2) + subst _ S.TInt = S.TInt + subst mp (S.TTup ts) = S.TTup (map (subst mp) ts) + subst mp (S.TNamed n ts) = S.TNamed n (map (subst mp) ts) + subst mp (S.TUnion ts) = S.TUnion (Set.map (subst mp) ts) + subst mp orig@(S.TyVar n) = fromMaybe orig (Map.lookup n mp) + + freeVars :: S.Type -> Set Name + freeVars (S.TFun t1 t2) = freeVars t1 <> freeVars t2 + freeVars S.TInt = mempty + freeVars (S.TTup ts) = Set.unions (map freeVars ts) + freeVars (S.TNamed _ ts) = Set.unions (map freeVars ts) + freeVars (S.TUnion ts) = Set.unions (map freeVars (Set.toList ts)) + freeVars (S.TyVar n) = Set.singleton n + + convert :: Map Name Int -> S.Type -> T.Type + convert mp (S.TFun t1 t2) = T.TFun (convert mp t1) (convert mp t2) + convert _ S.TInt = T.TInt + convert mp (S.TTup ts) = T.TTup (map (convert mp) ts) + convert mp (S.TNamed n ts) = T.TNamed n (map (convert mp) ts) + convert mp (S.TUnion ts) = T.TUnion (Set.map (convert mp) ts) + -- TODO: Should this be Rigid? I really don't know how this works. + convert mp (S.TyVar n) = T.TyVar T.Rigid (mp Map.! n) -- cgit v1.2.3-70-g09d2