aboutsummaryrefslogtreecommitdiff
path: root/typecheck/CC/Typecheck
diff options
context:
space:
mode:
Diffstat (limited to 'typecheck/CC/Typecheck')
-rw-r--r--typecheck/CC/Typecheck/Typedefs.hs50
-rw-r--r--typecheck/CC/Typecheck/Types.hs103
2 files changed, 153 insertions, 0 deletions
diff --git a/typecheck/CC/Typecheck/Typedefs.hs b/typecheck/CC/Typecheck/Typedefs.hs
new file mode 100644
index 0000000..ad9bdd8
--- /dev/null
+++ b/typecheck/CC/Typecheck/Typedefs.hs
@@ -0,0 +1,50 @@
+module CC.Typecheck.Typedefs(checkTypedefs) where
+
+import Control.Monad.Except
+import Data.Foldable (traverse_)
+import qualified Data.Map.Strict as Map
+import Data.Map.Strict (Map)
+import qualified Data.Set as Set
+
+import qualified CC.AST.Source as S
+import qualified CC.AST.Typed as T
+import CC.Typecheck.Types
+import CC.Types
+
+
+checkArity :: Map Name Int -> S.TypeDef -> TM ()
+checkArity typeArity (S.TypeDef (_, sr) args ty) = go ty
+ where
+ argNames = map fst args -- probably a small list
+
+ go :: S.Type -> TM ()
+ go (S.TFun t1 t2) = go t1 >> go t2
+ go S.TInt = return ()
+ go (S.TTup ts) = mapM_ go ts
+ go (S.TNamed n ts)
+ | Just arity <- Map.lookup n typeArity =
+ if length ts == arity
+ then mapM_ go ts
+ else throwError (TypeArityError sr n arity (length ts))
+ | otherwise = throwError (RefError sr n)
+ go (S.TUnion ts) = traverse_ go ts
+ go (S.TyVar n)
+ | n `elem` argNames = return ()
+ | otherwise = throwError (RefError sr n)
+
+checkTypedefs :: Map Name S.AliasDef -> [S.TypeDef] -> TM [T.TypeDef]
+checkTypedefs aliases origdefs = do
+ let aliasArity = Map.map (\(S.AliasDef _ args _) -> length args) aliases
+ typeArity = Map.fromList [(n, length args)
+ | S.TypeDef (n, _) args _ <- origdefs]
+
+ let dups = Set.fromList (map (\(S.TypeDef (n, _) _ _) -> n) origdefs)
+ Set.\\ Map.keysSet typeArity
+ when (not (Set.null dups)) $
+ throwError (DupTypeError (Set.findMin dups))
+
+ let aliasdefs = [S.TypeDef name args typ
+ | S.AliasDef name args typ <- Map.elems aliases]
+
+ mapM_ (checkArity (aliasArity <> typeArity)) (aliasdefs ++ origdefs)
+ mapM (convertTypeDef aliases) origdefs
diff --git a/typecheck/CC/Typecheck/Types.hs b/typecheck/CC/Typecheck/Types.hs
new file mode 100644
index 0000000..3f3c471
--- /dev/null
+++ b/typecheck/CC/Typecheck/Types.hs
@@ -0,0 +1,103 @@
+module CC.Typecheck.Types where
+
+import Control.Monad.State.Strict
+import Control.Monad.Except
+import qualified Data.Map.Strict as Map
+import Data.Map.Strict (Map)
+import Data.Maybe (fromMaybe)
+import qualified Data.Set as Set
+import Data.Set (Set)
+
+import qualified CC.AST.Source as S
+import qualified CC.AST.Typed as T
+import CC.Pretty
+import CC.Types
+
+
+data TCError = TypeError SourceRange T.Type T.Type
+ | RefError SourceRange Name
+ | TypeArityError SourceRange Name Int Int
+ | DupTypeError Name
+ deriving (Show)
+
+instance Pretty TCError where
+ pretty (TypeError sr real expect) =
+ "Type error: Expression at " ++ pretty sr ++
+ " has type " ++ pretty real ++
+ ", but should have type " ++ pretty expect
+ pretty (RefError sr name) =
+ "Reference error: Variable '" ++ name ++ "' out of scope at " ++ pretty sr
+ pretty (TypeArityError sr name wanted got) =
+ "Type error: Type '" ++ name ++ "' has arity " ++ show wanted ++
+ " but gets " ++ show got ++ " type arguments at " ++ pretty sr
+ pretty (DupTypeError name) =
+ "Duplicate types: Type '" ++ name ++ "' defined multiple times"
+
+type TM a = ExceptT TCError (State Int) a
+
+genId :: TM Int
+genId = state (\idval -> (idval, idval + 1))
+
+genTyVar :: TM T.Type
+genTyVar = T.TyVar T.Instantiable <$> genId
+
+runTM :: TM a -> Either TCError a
+runTM m = evalState (runExceptT m) 1
+
+
+convertType :: Map Name S.AliasDef -> SourceRange -> S.Type -> TM T.Type
+convertType aliases sr = fmap snd . convertType' aliases mempty sr
+
+convertTypeDef :: Map Name S.AliasDef -> S.TypeDef -> TM T.TypeDef
+convertTypeDef aliases (S.TypeDef (name, sr) args ty) = do
+ (mapping, ty') <- convertType' aliases (Set.fromList (map fst args)) sr ty
+ let args' = [mapping Map.! n | (n, _) <- args]
+ return (T.TypeDef name args' ty')
+
+convertType' :: Map Name S.AliasDef -> Set Name -> SourceRange -> S.Type -> TM (Map Name Int, T.Type)
+convertType' aliases extraVars sr origtype = do
+ rewritten <- rewrite origtype
+ let frees = Set.toList (extraVars <> freeVars rewritten)
+ nums <- traverse (const genId) frees
+ let mapping = Map.fromList (zip frees nums)
+ return (mapping, convert mapping rewritten)
+ where
+ rewrite :: S.Type -> TM S.Type
+ rewrite (S.TFun t1 t2) = S.TFun <$> rewrite t1 <*> rewrite t2
+ rewrite S.TInt = return S.TInt
+ rewrite (S.TTup ts) = S.TTup <$> mapM rewrite ts
+ rewrite (S.TNamed n ts)
+ | Just (S.AliasDef _ args typ) <- Map.lookup n aliases =
+ if length args == length ts
+ then rewrite (subst (Map.fromList (zip (map fst args) ts)) typ)
+ else throwError (TypeArityError sr n (length args) (length ts))
+ | otherwise =
+ S.TNamed n <$> mapM rewrite ts
+ rewrite (S.TUnion ts) = S.TUnion . Set.fromList <$> mapM rewrite (Set.toList ts)
+ rewrite (S.TyVar n) = return (S.TyVar n)
+
+ -- Substitute type variables
+ subst :: Map Name S.Type -> S.Type -> S.Type
+ subst mp (S.TFun t1 t2) = S.TFun (subst mp t1) (subst mp t2)
+ subst _ S.TInt = S.TInt
+ subst mp (S.TTup ts) = S.TTup (map (subst mp) ts)
+ subst mp (S.TNamed n ts) = S.TNamed n (map (subst mp) ts)
+ subst mp (S.TUnion ts) = S.TUnion (Set.map (subst mp) ts)
+ subst mp orig@(S.TyVar n) = fromMaybe orig (Map.lookup n mp)
+
+ freeVars :: S.Type -> Set Name
+ freeVars (S.TFun t1 t2) = freeVars t1 <> freeVars t2
+ freeVars S.TInt = mempty
+ freeVars (S.TTup ts) = Set.unions (map freeVars ts)
+ freeVars (S.TNamed _ ts) = Set.unions (map freeVars ts)
+ freeVars (S.TUnion ts) = Set.unions (map freeVars (Set.toList ts))
+ freeVars (S.TyVar n) = Set.singleton n
+
+ convert :: Map Name Int -> S.Type -> T.Type
+ convert mp (S.TFun t1 t2) = T.TFun (convert mp t1) (convert mp t2)
+ convert _ S.TInt = T.TInt
+ convert mp (S.TTup ts) = T.TTup (map (convert mp) ts)
+ convert mp (S.TNamed n ts) = T.TNamed n (map (convert mp) ts)
+ convert mp (S.TUnion ts) = T.TUnion (Set.map (convert mp) ts)
+ -- TODO: Should this be Rigid? I really don't know how this works.
+ convert mp (S.TyVar n) = T.TyVar T.Rigid (mp Map.! n)