diff options
Diffstat (limited to 'typecheck/CC/Typecheck')
| -rw-r--r-- | typecheck/CC/Typecheck/Typedefs.hs | 50 | ||||
| -rw-r--r-- | typecheck/CC/Typecheck/Types.hs | 103 | 
2 files changed, 153 insertions, 0 deletions
diff --git a/typecheck/CC/Typecheck/Typedefs.hs b/typecheck/CC/Typecheck/Typedefs.hs new file mode 100644 index 0000000..ad9bdd8 --- /dev/null +++ b/typecheck/CC/Typecheck/Typedefs.hs @@ -0,0 +1,50 @@ +module CC.Typecheck.Typedefs(checkTypedefs) where + +import Control.Monad.Except +import Data.Foldable (traverse_) +import qualified Data.Map.Strict as Map +import Data.Map.Strict (Map) +import qualified Data.Set as Set + +import qualified CC.AST.Source as S +import qualified CC.AST.Typed as T +import CC.Typecheck.Types +import CC.Types + + +checkArity :: Map Name Int -> S.TypeDef -> TM () +checkArity typeArity (S.TypeDef (_, sr) args ty) = go ty +  where +    argNames = map fst args  -- probably a small list + +    go :: S.Type -> TM () +    go (S.TFun t1 t2) = go t1 >> go t2 +    go S.TInt = return () +    go (S.TTup ts) = mapM_ go ts +    go (S.TNamed n ts) +      | Just arity <- Map.lookup n typeArity = +          if length ts == arity +              then mapM_ go ts +              else throwError (TypeArityError sr n arity (length ts)) +      | otherwise = throwError (RefError sr n) +    go (S.TUnion ts) = traverse_ go ts +    go (S.TyVar n) +      | n `elem` argNames = return () +      | otherwise = throwError (RefError sr n) + +checkTypedefs :: Map Name S.AliasDef -> [S.TypeDef] -> TM [T.TypeDef] +checkTypedefs aliases origdefs = do +    let aliasArity = Map.map (\(S.AliasDef _ args _) -> length args) aliases +        typeArity = Map.fromList [(n, length args) +                                 | S.TypeDef (n, _) args _ <- origdefs] + +    let dups = Set.fromList (map (\(S.TypeDef (n, _) _ _) -> n) origdefs) +                    Set.\\ Map.keysSet typeArity +    when (not (Set.null dups)) $ +        throwError (DupTypeError (Set.findMin dups)) + +    let aliasdefs = [S.TypeDef name args typ +                    | S.AliasDef name args typ <- Map.elems aliases] + +    mapM_ (checkArity (aliasArity <> typeArity)) (aliasdefs ++ origdefs) +    mapM (convertTypeDef aliases) origdefs diff --git a/typecheck/CC/Typecheck/Types.hs b/typecheck/CC/Typecheck/Types.hs new file mode 100644 index 0000000..3f3c471 --- /dev/null +++ b/typecheck/CC/Typecheck/Types.hs @@ -0,0 +1,103 @@ +module CC.Typecheck.Types where + +import Control.Monad.State.Strict +import Control.Monad.Except +import qualified Data.Map.Strict as Map +import Data.Map.Strict (Map) +import Data.Maybe (fromMaybe) +import qualified Data.Set as Set +import Data.Set (Set) + +import qualified CC.AST.Source as S +import qualified CC.AST.Typed as T +import CC.Pretty +import CC.Types + + +data TCError = TypeError SourceRange T.Type T.Type +             | RefError SourceRange Name +             | TypeArityError SourceRange Name Int Int +             | DupTypeError Name +  deriving (Show) + +instance Pretty TCError where +    pretty (TypeError sr real expect) = +        "Type error: Expression at " ++ pretty sr ++ +            " has type " ++ pretty real ++ +            ", but should have type " ++ pretty expect +    pretty (RefError sr name) = +        "Reference error: Variable '" ++ name ++ "' out of scope at " ++ pretty sr +    pretty (TypeArityError sr name wanted got) = +        "Type error: Type '" ++ name ++ "' has arity " ++ show wanted ++ +            " but gets " ++ show got ++ " type arguments at " ++ pretty sr +    pretty (DupTypeError name) = +        "Duplicate types: Type '" ++ name ++ "' defined multiple times" + +type TM a = ExceptT TCError (State Int) a + +genId :: TM Int +genId = state (\idval -> (idval, idval + 1)) + +genTyVar :: TM T.Type +genTyVar = T.TyVar T.Instantiable <$> genId + +runTM :: TM a -> Either TCError a +runTM m = evalState (runExceptT m) 1 + + +convertType :: Map Name S.AliasDef -> SourceRange -> S.Type -> TM T.Type +convertType aliases sr = fmap snd . convertType' aliases mempty sr + +convertTypeDef :: Map Name S.AliasDef -> S.TypeDef -> TM T.TypeDef +convertTypeDef aliases (S.TypeDef (name, sr) args ty) = do +    (mapping, ty') <- convertType' aliases (Set.fromList (map fst args)) sr ty +    let args' = [mapping Map.! n | (n, _) <- args] +    return (T.TypeDef name args' ty') + +convertType' :: Map Name S.AliasDef -> Set Name -> SourceRange -> S.Type -> TM (Map Name Int, T.Type) +convertType' aliases extraVars sr origtype = do +    rewritten <- rewrite origtype +    let frees = Set.toList (extraVars <> freeVars rewritten) +    nums <- traverse (const genId) frees +    let mapping = Map.fromList (zip frees nums) +    return (mapping, convert mapping rewritten) +  where +    rewrite :: S.Type -> TM S.Type +    rewrite (S.TFun t1 t2) = S.TFun <$> rewrite t1 <*> rewrite t2 +    rewrite S.TInt = return S.TInt +    rewrite (S.TTup ts) = S.TTup <$> mapM rewrite ts +    rewrite (S.TNamed n ts) +      | Just (S.AliasDef _ args typ) <- Map.lookup n aliases = +          if length args == length ts +              then rewrite (subst (Map.fromList (zip (map fst args) ts)) typ) +              else throwError (TypeArityError sr n (length args) (length ts)) +      | otherwise = +          S.TNamed n <$> mapM rewrite ts +    rewrite (S.TUnion ts) = S.TUnion . Set.fromList <$> mapM rewrite (Set.toList ts) +    rewrite (S.TyVar n) = return (S.TyVar n) + +    -- Substitute type variables +    subst :: Map Name S.Type -> S.Type -> S.Type +    subst mp (S.TFun t1 t2) = S.TFun (subst mp t1) (subst mp t2) +    subst _  S.TInt = S.TInt +    subst mp (S.TTup ts) = S.TTup (map (subst mp) ts) +    subst mp (S.TNamed n ts) = S.TNamed n (map (subst mp) ts) +    subst mp (S.TUnion ts) = S.TUnion (Set.map (subst mp) ts) +    subst mp orig@(S.TyVar n) = fromMaybe orig (Map.lookup n mp) + +    freeVars :: S.Type -> Set Name +    freeVars (S.TFun t1 t2) = freeVars t1 <> freeVars t2 +    freeVars S.TInt = mempty +    freeVars (S.TTup ts) = Set.unions (map freeVars ts) +    freeVars (S.TNamed _ ts) = Set.unions (map freeVars ts) +    freeVars (S.TUnion ts) = Set.unions (map freeVars (Set.toList ts)) +    freeVars (S.TyVar n) = Set.singleton n + +    convert :: Map Name Int -> S.Type -> T.Type +    convert mp (S.TFun t1 t2) = T.TFun (convert mp t1) (convert mp t2) +    convert _  S.TInt = T.TInt +    convert mp (S.TTup ts) = T.TTup (map (convert mp) ts) +    convert mp (S.TNamed n ts) = T.TNamed n (map (convert mp) ts) +    convert mp (S.TUnion ts) = T.TUnion (Set.map (convert mp) ts) +    -- TODO: Should this be Rigid? I really don't know how this works. +    convert mp (S.TyVar n) = T.TyVar T.Rigid (mp Map.! n)  | 
