diff options
| -rw-r--r-- | ast/CC/AST/Source.hs | 37 | ||||
| -rw-r--r-- | ast/CC/AST/Typed.hs | 48 | ||||
| -rw-r--r-- | ast/CC/Context.hs | 2 | ||||
| -rw-r--r-- | backend/CC/Backend/Dumb.hs | 24 | ||||
| -rw-r--r-- | compcomp.cabal | 3 | ||||
| -rw-r--r-- | parser/CC/Parser.hs | 97 | ||||
| -rw-r--r-- | typecheck/CC/Typecheck.hs | 180 | ||||
| -rw-r--r-- | typecheck/CC/Typecheck/Typedefs.hs | 50 | ||||
| -rw-r--r-- | typecheck/CC/Typecheck/Types.hs | 103 | 
9 files changed, 434 insertions, 110 deletions
diff --git a/ast/CC/AST/Source.hs b/ast/CC/AST/Source.hs index e648759..e64e058 100644 --- a/ast/CC/AST/Source.hs +++ b/ast/CC/AST/Source.hs @@ -3,6 +3,8 @@ module CC.AST.Source(      module CC.Types  ) where +import qualified Data.Set as Set +import Data.Set (Set)  import Data.List  import CC.Pretty @@ -12,19 +14,32 @@ import CC.Types  data Program = Program [Decl]    deriving (Show, Read) -data Decl = Def Def  -- import? +data Decl = DeclFunc FuncDef +          | DeclType TypeDef +          | DeclAlias AliasDef    deriving (Show, Read) -data Def = Function (Maybe Type) -                    (Name, SourceRange) -                    [(Name, SourceRange)] -                    Expr +data FuncDef = +    FuncDef (Maybe Type) +            (Name, SourceRange) +            [(Name, SourceRange)] +            Expr +  deriving (Show, Read) + +-- Named type with named arguments +data TypeDef = TypeDef (Name, SourceRange) [(Name, SourceRange)] Type +  deriving (Show, Read) + +data AliasDef = AliasDef (Name, SourceRange) [(Name, SourceRange)] Type    deriving (Show, Read)  data Type = TFun Type Type            | TInt            | TTup [Type] -  deriving (Show, Read) +          | TNamed Name [Type]  -- named type with type arguments +          | TUnion (Set Type) +          | TyVar Name +  deriving (Eq, Ord, Show, Read)  data Expr = Lam SourceRange [(Name, SourceRange)] Expr            | Let SourceRange (Name, SourceRange) Expr Expr @@ -32,15 +47,24 @@ data Expr = Lam SourceRange [(Name, SourceRange)] Expr            | Int SourceRange Int            | Tup SourceRange [Expr]            | Var SourceRange Name +          | Constr SourceRange Name  -- type constructor            | Annot SourceRange Expr Type    deriving (Show, Read) +instance Semigroup Program where Program p1 <> Program p2 = Program (p1 <> p2) +instance Monoid    Program where mempty                   = Program mempty +  instance Pretty Type where      prettyPrec _ TInt = "Int"      prettyPrec p (TFun a b) =          precParens p 2 (prettyPrec 3 a ++ " -> " ++ prettyPrec 2 b)      prettyPrec _ (TTup ts) =          "(" ++ intercalate ", " (map pretty ts) ++ ")" +    prettyPrec _ (TNamed n ts) = +        n ++ "[" ++ intercalate ", " (map pretty ts) ++ "]" +    prettyPrec _ (TUnion ts) = +        "{" ++ intercalate " | " (map pretty (Set.toList ts)) ++ "}" +    prettyPrec _ (TyVar n) = "<" ++ n ++ ">"  instance HasRange Expr where      range (Lam sr _ _) = sr @@ -49,4 +73,5 @@ instance HasRange Expr where      range (Int sr _) = sr      range (Tup sr _) = sr      range (Var sr _) = sr +    range (Constr sr _) = sr      range (Annot sr _ _) = sr diff --git a/ast/CC/AST/Typed.hs b/ast/CC/AST/Typed.hs index b12b30a..cf67575 100644 --- a/ast/CC/AST/Typed.hs +++ b/ast/CC/AST/Typed.hs @@ -3,23 +3,36 @@ module CC.AST.Typed(      module CC.Types  ) where +import qualified Data.Map.Strict as Map +import Data.Map.Strict (Map) +import qualified Data.Set as Set +import Data.Set (Set)  import Data.List  import CC.Pretty  import CC.Types -data Program = Program [Def] +data Program = Program [Def] (Map Name TypeDef)    deriving (Show, Read)  data Def = Def Name Expr    deriving (Show, Read) +-- Named type with type parameters +data TypeDef = TypeDef Name [Int] Type +  deriving (Show, Read) +  data Type = TFun Type Type            | TInt            | TTup [Type] -          | TyVar Int -  deriving (Show, Read) +          | TNamed Name [Type]  -- named type with type arguments +          | TUnion (Set Type) +          | TyVar Rigidity Int +  deriving (Eq, Ord, Show, Read) + +data Rigidity = Rigid | Instantiable +  deriving (Eq, Ord, Show, Read)  data TypeScheme = TypeScheme [Int] Type    deriving (Show, Read) @@ -30,6 +43,7 @@ data Expr = Lam Type Occ Expr            | Int Int            | Tup [Expr]            | Var Occ +          | Constr Type Name  -- Type is 'argument -> TNamed'    deriving (Show, Read)  data Occ = Occ Name Type @@ -42,6 +56,7 @@ exprType (Call typ _ _) = typ  exprType (Int _) = TInt  exprType (Tup es) = TTup (map exprType es)  exprType (Var (Occ _ typ)) = typ +exprType (Constr typ _) = typ  instance Pretty Type where      prettyPrec _ TInt = "Int" @@ -49,13 +64,18 @@ instance Pretty Type where          precParens p 2 (prettyPrec 3 a ++ " -> " ++ prettyPrec 2 b)      prettyPrec _ (TTup ts) =          "(" ++ intercalate ", " (map pretty ts) ++ ")" -    prettyPrec _ (TyVar i) = 't' : show i +    prettyPrec _ (TNamed n ts) = +        n ++ "[" ++ intercalate ", " (map pretty ts) ++ "]" +    prettyPrec _ (TUnion ts) = +        "{ " ++ intercalate " | " (map pretty (Set.toList ts)) ++ " }" +    prettyPrec _ (TyVar Rigid i) = 't' : show i ++ "R" +    prettyPrec _ (TyVar Instantiable i) = 't' : show i  instance Pretty TypeScheme where      prettyPrec p (TypeScheme bnds ty) =          precParens p 2 -            ("forall " ++ intercalate " " (map (pretty . TyVar) bnds) ++ ". " ++ -                prettyPrec 2 ty) +            ("forall " ++ intercalate " " (map (pretty . TyVar Instantiable) bnds) ++ +                ". " ++ prettyPrec 2 ty)  instance Pretty Expr where      prettyPrec p (Lam ty (Occ n t) e) = @@ -75,10 +95,22 @@ instance Pretty Expr where      prettyPrec _ (Tup es) = "(" ++ intercalate ", " (map pretty es) ++ ")"      prettyPrec p (Var (Occ n t)) =          precParens p 2 $ -            show n ++ " :: " ++ pretty t +            n ++ " :: " ++ pretty t +    prettyPrec p (Constr t n) = +        precParens p 2 $ +            n ++ " :: " ++ pretty t  instance Pretty Def where      pretty (Def n e) = n ++ " = " ++ pretty e +instance Pretty TypeDef where +    pretty (TypeDef n [] t) = +        "type " ++ n ++ " = " ++ pretty t +    pretty (TypeDef n vs t) = +        "type " ++ n ++ " " ++ intercalate " " (map (pretty . TyVar Instantiable) vs) ++ +            " = " ++ pretty t +  instance Pretty Program where -    pretty (Program defs) = concatMap ((++ "\n") . pretty) defs +    pretty (Program defs tdefs) = +        concatMap (++ "\n") $ +            map pretty (Map.elems tdefs) ++ map pretty defs diff --git a/ast/CC/Context.hs b/ast/CC/Context.hs index 68378d7..acfb614 100644 --- a/ast/CC/Context.hs +++ b/ast/CC/Context.hs @@ -9,4 +9,4 @@ import CC.AST.Typed  data Context = Context FilePath Builtins  -- | Information about builtins supported by the enabled backend -data Builtins = Builtins (Map Name Type) +data Builtins = Builtins (Map Name Type) String diff --git a/backend/CC/Backend/Dumb.hs b/backend/CC/Backend/Dumb.hs index 3210dab..822fb84 100644 --- a/backend/CC/Backend/Dumb.hs +++ b/backend/CC/Backend/Dumb.hs @@ -7,10 +7,20 @@ import CC.Context  builtins :: Builtins -builtins = Builtins . Map.fromList $ -    [ ("print", TFun TInt (TTup [])) -    , ("fst", TFun (TTup [TyVar 1, TyVar 2]) (TyVar 1)) -    , ("snd", TFun (TTup [TyVar 1, TyVar 2]) (TyVar 2)) -    , ("_add", TFun TInt (TFun TInt TInt)) -    , ("_sub", TFun TInt (TFun TInt TInt)) -    , ("_mul", TFun TInt (TFun TInt TInt)) ] +builtins = +    let values = Map.fromList +            [ ("print", TInt ==> TTup []) +            , ("fst", TTup [t1, t2] ==> t1) +            , ("snd", TTup [t1, t2] ==> t2) +            , ("_add", TInt ==> TInt ==> TInt) +            , ("_sub", TInt ==> TInt ==> TInt) +            , ("_mul", TInt ==> TInt ==> TInt) ] +        prelude = "type Nil = ()\n\ +                  \type Cons a = (a, List a)\n\ +                  \alias List a = { Nil | Cons a }\n" +    in Builtins values prelude +  where +    t1 = TyVar Instantiable 1 +    t2 = TyVar Instantiable 2 +    infixr ==> +    (==>) = TFun diff --git a/compcomp.cabal b/compcomp.cabal index 9e941aa..4b06f4b 100644 --- a/compcomp.cabal +++ b/compcomp.cabal @@ -20,7 +20,7 @@ executable            compcomp  library               cc-parser      import:           deps      hs-source-dirs:   parser -    build-depends:    parsec, cc-ast, cc-utils +    build-depends:    containers, parsec, cc-ast, cc-utils      exposed-modules:  CC.Parser  library               cc-typecheck @@ -28,6 +28,7 @@ library               cc-typecheck      hs-source-dirs:   typecheck      build-depends:    containers, mtl, cc-ast, cc-utils      exposed-modules:  CC.Typecheck +    other-modules:    CC.Typecheck.Typedefs, CC.Typecheck.Types  library               cc-backend-dumb      import:           deps diff --git a/parser/CC/Parser.hs b/parser/CC/Parser.hs index 2d2c4b7..66bc6cf 100644 --- a/parser/CC/Parser.hs +++ b/parser/CC/Parser.hs @@ -1,6 +1,7 @@  module CC.Parser(runPass, parseProgram) where  import Control.Monad +import qualified Data.Set as Set  import Text.Parsec hiding (SourcePos, getPosition, token)  import qualified Text.Parsec @@ -12,7 +13,10 @@ import CC.Pretty  type Parser a = Parsec String () a  runPass :: Context -> RawString -> Either (PrettyShow ParseError) Program -runPass (Context path _) (RawString src) = fmapLeft PrettyShow (parseProgram path src) +runPass (Context path (Builtins _ prelude)) (RawString src) = do +    prog1 <- fmapLeft PrettyShow (parseProgram "<prelude>" prelude) +    prog2 <- fmapLeft PrettyShow (parseProgram path src) +    return (prog1 <> prog2)    where fmapLeft f (Left x) = Left (f x)          fmapLeft _ (Right x) = Right x @@ -27,32 +31,63 @@ pProgram = do      return prog  pDecl :: Parser Decl -pDecl = Def <$> pDef +pDecl = choice +    [ DeclType <$> pDeclType +    , DeclAlias <$> pDeclAlias  +    , DeclFunc <$> pDeclFunc ] -pDef :: Parser Def -pDef = do +pDeclFunc :: Parser FuncDef +pDeclFunc = do      func <- try $ do          emptyLines -        name <- pName0 <?> "declaration head name" +        name <- pName0 LowerCase <?> "declaration head name"          return name      mtyp <- optionMaybe $ do          symbol "::"          typ <- pType          whitespace >> void newline          emptyLines -        func' <- fst <$> pName0 +        (func', _) <- pName0 LowerCase          guard (fst func == func')          return typ -    args <- many pName +    args <- many (pName LowerCase)      symbol "="      expr <- pExpr -    return (Function mtyp func args expr) +    return (FuncDef mtyp func args expr) + +pDeclType :: Parser TypeDef +pDeclType = (\(n, a, t) -> TypeDef n a t) <$> pTypedefLike "type" + +pDeclAlias :: Parser AliasDef +pDeclAlias = (\(n, a, t) -> AliasDef n a t) <$> pTypedefLike "alias" + +pTypedefLike :: String -> Parser ((Name, SourceRange), [(Name, SourceRange)], Type) +pTypedefLike keyword = do +    try (emptyLines >> string keyword >> whitespace1) +    name <- pName0 UpperCase +    args <- many (pName LowerCase) +    symbol "=" +    ty <- pType +    return (name, args, ty)  pType :: Parser Type -pType = chainr1 pTypeAtom (symbol "->" >> return TFun) +pType = chainr1 pTypeTerm (symbol "->" >> return TFun) + +pTypeTerm :: Parser Type +pTypeTerm = pTypeAtom <|> pTypeCall  pTypeAtom :: Parser Type -pTypeAtom = (wordToken "Int" >> return TInt) <|> pParenType +pTypeAtom = choice +    [ wordToken "Int" >> return TInt +    , TyVar . fst <$> pName LowerCase +    , pParenType +    , pUnionType ] + +pTypeCall :: Parser Type +pTypeCall = do +    (constr, _) <- pName UpperCase +    args <- many pTypeAtom +    return (TNamed constr args)  pParenType :: Parser Type  pParenType = do @@ -63,6 +98,19 @@ pParenType = do          [ty] -> return ty          _ -> return (TTup tys) +pUnionType :: Parser Type +pUnionType = do +    token "{" +    tys <- pType `sepBy` token "|" +    token "}" +    case tys of +        [] -> unexpected "empty union type" +        [ty] -> return ty +        _ -> let tyset = Set.fromList tys +             in if Set.size tyset == length tys +                    then return (TUnion tyset) +                    else unexpected "duplicate types in union" +  pExpr :: Parser Expr  pExpr = label (pLam <|> pLet <|> pCall) "expression"    where @@ -84,7 +132,7 @@ pExpr = label (pLam <|> pLet <|> pCall) "expression"              p <- getPosition              void (char '\\')              return p -        names <- many1 pName +        names <- many1 (pName LowerCase)          symbol "->"          body <- pExpr          p2 <- getPosition @@ -100,7 +148,7 @@ pExpr = label (pLam <|> pLet <|> pCall) "expression"        where          afterKeyword p1 = do              whitespace1 -            lhs <- pName0 +            lhs <- pName0 LowerCase              symbol "="              rhs <- pExpr              let fullRange rest = mergeRange (SourceRange p1 p1) (range rest) @@ -118,7 +166,8 @@ pExpr = label (pLam <|> pLet <|> pCall) "expression"  pExprAtom :: Parser Expr  pExprAtom =      choice [ uncurry (flip Int) <$> pInt -           , uncurry (flip Var) <$> pName +           , uncurry (flip Var) <$> pName LowerCase +           , uncurry (flip Constr) <$> pName UpperCase             , pParenExpr ]  pParenExpr :: Parser Expr @@ -141,11 +190,13 @@ pInt = try (whitespace >> pInt0)          p2 <- getPosition          return (num, SourceRange p1 p2) -pName0 :: Parser (Name, SourceRange) -pName0 = do +data Case = LowerCase | UpperCase + +pName0 :: Case -> Parser (Name, SourceRange) +pName0 vcase = do      p1 <- getPosition      s <- try $ do -        c <- pWordFirstChar +        c <- pWordFirstChar vcase          cs <- many pWordMidChar          let s = c : cs          guard (s `notElem` ["let", "in"]) @@ -154,14 +205,18 @@ pName0 = do      notFollowedBy pWordMidChar      return (s, SourceRange p1 p2) -pWordFirstChar :: Parser Char -pWordFirstChar = letter <|> oneOf "_$#!" +pWordFirstChar :: Case -> Parser Char +pWordFirstChar LowerCase = lower <|> oneOf wordSymbols +pWordFirstChar UpperCase = upper <|> oneOf wordSymbols  pWordMidChar :: Parser Char -pWordMidChar = alphaNum <|> oneOf "_$#!" +pWordMidChar = alphaNum <|> oneOf wordSymbols + +wordSymbols :: [Char] +wordSymbols = "_$#!" -pName :: Parser (Name, SourceRange) -pName = try (whitespace >> pName0) +pName :: Case -> Parser (Name, SourceRange) +pName vcase = try (whitespace >> pName0 vcase)  symbol :: String -> Parser ()  symbol s = token s >> (eof <|> void space <|> void (oneOf "(){}[]")) diff --git a/typecheck/CC/Typecheck.hs b/typecheck/CC/Typecheck.hs index f61103e..824a714 100644 --- a/typecheck/CC/Typecheck.hs +++ b/typecheck/CC/Typecheck.hs @@ -1,4 +1,5 @@  {-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE TupleSections #-}  {-# LANGUAGE TypeSynonymInstances #-}  module CC.Typecheck(runPass) where @@ -10,59 +11,47 @@ import Data.Maybe  import qualified Data.Set as Set  import Data.Set (Set) +import Debug.Trace +  import qualified CC.AST.Source as S  import qualified CC.AST.Typed as T  import CC.Context -import CC.Pretty  import CC.Types +import CC.Typecheck.Typedefs +import CC.Typecheck.Types  -- Inspiration: https://github.com/kritzcreek/fby19 -data TCError = TypeError SourceRange T.Type T.Type -             | RefError SourceRange Name +data Env = +    Env (Map Name T.TypeScheme)  -- Definitions in scope +        (Map Name T.TypeDef)     -- Type definitions +        (Map Name S.AliasDef)    -- Type aliases    deriving (Show) -instance Pretty TCError where -    pretty (TypeError sr real expect) = -        "Type error: Expression at " ++ pretty sr ++ -            " has type " ++ pretty real ++ -            ", but should have type " ++ pretty expect -    pretty (RefError sr name) = -        "Reference error: Variable '" ++ name ++ "' out of scope at " ++ pretty sr - -type TM a = ExceptT TCError (State Int) a - -genId :: TM Int -genId = state (\idval -> (idval, idval + 1)) - -genTyVar :: TM T.Type -genTyVar = T.TyVar <$> genId - -runTM :: TM a -> Either TCError a -runTM m = evalState (runExceptT m) 1 - - -newtype Env = Env (Map Name T.TypeScheme) -  newtype Subst = Subst (Map Int T.Type)  class FreeTypeVars a where -    freeTypeVars :: a -> Set Int +    -- Free instantiable type variables +    freeInstTypeVars :: a -> Set Int  instance FreeTypeVars T.Type where -    freeTypeVars (T.TFun t1 t2) = freeTypeVars t1 <> freeTypeVars t2 -    freeTypeVars T.TInt = mempty -    freeTypeVars (T.TTup ts) = Set.unions (map freeTypeVars ts) -    freeTypeVars (T.TyVar var) = Set.singleton var +    freeInstTypeVars (T.TFun t1 t2) = freeInstTypeVars t1 <> freeInstTypeVars t2 +    freeInstTypeVars T.TInt = mempty +    freeInstTypeVars (T.TTup ts) = Set.unions (map freeInstTypeVars ts) +    freeInstTypeVars (T.TNamed _ ts) = Set.unions (map freeInstTypeVars ts) +    freeInstTypeVars (T.TUnion ts) = Set.unions (map freeInstTypeVars (Set.toList ts)) +    freeInstTypeVars (T.TyVar T.Instantiable var) = Set.singleton var +    freeInstTypeVars (T.TyVar T.Rigid _) = mempty  instance FreeTypeVars T.TypeScheme where -    freeTypeVars (T.TypeScheme bnds ty) = foldr Set.delete (freeTypeVars ty) bnds +    freeInstTypeVars (T.TypeScheme bnds ty) = +        foldr Set.delete (freeInstTypeVars ty) bnds  instance FreeTypeVars Env where -    freeTypeVars (Env mp) = foldMap freeTypeVars (Map.elems mp) +    freeInstTypeVars (Env mp _ _) = foldMap freeInstTypeVars (Map.elems mp)  infixr >>! @@ -74,14 +63,20 @@ instance Substitute T.Type where          T.TFun t1 t2 -> T.TFun (theta >>! t1) (theta >>! t2)          T.TInt -> T.TInt          T.TTup ts -> T.TTup (map (theta >>!) ts) -        T.TyVar i -> fromMaybe ty (Map.lookup i mp) +        T.TNamed n ts -> T.TNamed n (map (theta >>!) ts) +        T.TUnion ts -> T.TUnion (Set.map (theta >>!) ts) +        T.TyVar T.Instantiable i -> fromMaybe ty (Map.lookup i mp) +        T.TyVar T.Rigid i +          | i `Map.member` mp -> error "Attempt to substitute a rigid type variable" +          | otherwise -> ty  instance Substitute T.TypeScheme where      Subst mp >>! T.TypeScheme bnds ty =          T.TypeScheme bnds (Subst (foldr Map.delete mp bnds) >>! ty)  instance Substitute Env where -    theta >>! Env mp = Env (Map.map (theta >>!) mp) +    theta >>! Env mp tdefs aliases = +        Env (Map.map (theta >>!) mp) tdefs aliases  -- TODO: make this instance unnecessary  instance Substitute T.Expr where @@ -94,6 +89,7 @@ instance Substitute T.Expr where      _     >>! expr@(T.Int _) = expr      theta >>! T.Tup es = T.Tup (map (theta >>!) es)      theta >>! T.Var (T.Occ name ty) = T.Var (T.Occ name (theta >>! ty)) +    theta >>! T.Constr ty n = T.Constr (theta >>! ty) n  instance Semigroup Subst where @@ -103,20 +99,43 @@ instance Monoid Subst where      mempty = Subst mempty  emptyEnv :: Env -emptyEnv = Env mempty +emptyEnv = Env mempty mempty mempty + +envAddDef :: Name -> T.TypeScheme -> Env -> Env +envAddDef name sty (Env mp tmp aliases) +  | name `Map.member` mp = error "envAddDef on name already in environment" +  | otherwise = +      Env (Map.insert name sty mp) tmp aliases + +envFindDef :: Name -> Env -> Maybe T.TypeScheme +envFindDef name (Env mp _ _) = Map.lookup name mp + +envAddTypes :: Map Name T.TypeDef -> Env -> Env +envAddTypes l (Env mp tdefs aliases) = +    let combined = l <> tdefs +    in if Map.size combined == Map.size l + Map.size tdefs +           then Env mp combined aliases +           else error "envAddTypes on duplicate type names" -envAdd :: Name -> T.TypeScheme -> Env -> Env -envAdd name sty (Env mp) = Env (Map.insert name sty mp) +envFindType :: Name -> Env -> Maybe T.TypeDef +envFindType name (Env _ tdefs _) = Map.lookup name tdefs -envFind :: Name -> Env -> Maybe T.TypeScheme -envFind name (Env mp) = Map.lookup name mp +envAddAliases :: Map Name S.AliasDef -> Env -> Env +envAddAliases l (Env mp tdefs aliases) = +    let combined = l <> aliases +    in if Map.size combined == Map.size l + Map.size aliases +           then Env mp tdefs combined +           else error "envAddAliaes on duplicate type names" + +envAliases :: Env -> Map Name S.AliasDef +envAliases (Env _ _ aliases) = aliases  substVar :: Int -> T.Type -> Subst  substVar var ty = Subst (Map.singleton var ty)  generalise :: Env -> T.Type -> T.TypeScheme  generalise env ty = -    T.TypeScheme (Set.toList (freeTypeVars ty Set.\\ freeTypeVars env)) ty +    T.TypeScheme (Set.toList (freeInstTypeVars ty Set.\\ freeInstTypeVars env)) ty  instantiate :: T.TypeScheme -> TM T.Type  instantiate (T.TypeScheme bnds ty) = do @@ -124,6 +143,9 @@ instantiate (T.TypeScheme bnds ty) = do      let theta = Subst (Map.fromList (zip bnds vars))      return (theta >>! ty) +freshenFrees :: Env -> T.Type -> TM T.Type +freshenFrees env = instantiate . generalise env +  data UnifyContext = UnifyContext SourceRange T.Type T.Type  unify :: SourceRange -> T.Type -> T.Type -> TM Subst @@ -131,25 +153,22 @@ unify sr t1 t2 = unify' (UnifyContext sr t1 t2) t1 t2  unify' :: UnifyContext -> T.Type -> T.Type -> TM Subst  unify' _   T.TInt T.TInt = return mempty -unify' ctx (T.TFun t1 t2) (T.TFun u1 u2) = (<>) <$> unify' ctx t1 u1 <*> unify' ctx t2 u2 +unify' ctx (T.TFun t1 t2) (T.TFun u1 u2) = +    (<>) <$> unify' ctx t1 u1 <*> unify' ctx t2 u2  unify' ctx (T.TTup ts) (T.TTup us)    | length ts == length us = mconcat <$> zipWithM (unify' ctx) ts us -unify' _   (T.TyVar var) ty = return (substVar var ty) -unify' _   ty (T.TyVar var) = return (substVar var ty) +unify' _   (T.TyVar T.Instantiable var) ty = return (substVar var ty) +unify' _   ty (T.TyVar T.Instantiable var) = return (substVar var ty) +-- TODO: fix unify  unify' (UnifyContext sr t1 t2) _ _ = throwError (TypeError sr t1 t2) -convertType :: S.Type -> T.Type -convertType (S.TFun t1 t2) = T.TFun (convertType t1) (convertType t2) -convertType S.TInt = T.TInt -convertType (S.TTup ts) = T.TTup (map convertType ts) -  infer :: Env -> S.Expr -> TM (Subst, T.Expr)  infer env expr = case expr of      S.Lam _ [] body -> infer env body      S.Lam sr args@(_:_:_) body -> infer env (foldr (S.Lam sr . pure) body args)      S.Lam _ [(arg, _)] body -> do          argVar <- genTyVar -        let augEnv = envAdd arg (T.TypeScheme [] argVar) env +        let augEnv = envAddDef arg (T.TypeScheme [] argVar) env          (theta, body') <- infer augEnv body          let argType = theta >>! argVar          return (theta, T.Lam (T.TFun argType (T.exprType body')) @@ -157,7 +176,7 @@ infer env expr = case expr of      S.Let _ (name, _) rhs body -> do          (theta1, rhs') <- infer env rhs          let varType = T.exprType rhs' -        let augEnv = envAdd name (T.TypeScheme [] varType) env +        let augEnv = envAddDef name (T.TypeScheme [] varType) env          (theta2, body') <- infer augEnv body          return (theta2 <> theta1, T.Let (T.Occ name varType) rhs' body')      S.Call sr func arg -> do @@ -173,14 +192,22 @@ infer env expr = case expr of      S.Int _ val -> return (mempty, T.Int val)      S.Tup _ es -> fmap T.Tup <$> inferList env es      S.Var sr name -      | Just sty <- envFind name env -> do +      | Just sty <- envFindDef name env -> do            ty <- instantiate sty            return (mempty, T.Var (T.Occ name ty))        | otherwise ->            throwError (RefError sr name) +    S.Constr sr name -> case envFindType name env of +        Just (T.TypeDef typname params typ) -> do +            restyp <- freshenFrees emptyEnv +                          (T.TNamed typname (map (T.TyVar T.Instantiable) params)) +            return (mempty, T.Constr (T.TFun typ restyp) name) +        _ -> +            throwError (RefError sr name)      S.Annot sr subex ty -> do          (theta1, subex') <- infer env subex -        theta2 <- unify sr (T.exprType subex') (convertType ty) +        ty' <- convertType (envAliases env) sr ty +        theta2 <- unify sr (T.exprType subex') ty'          return (theta2 <> theta1, theta2 >>! subex')  -- TODO: quadratic complexity  inferList :: Env -> [S.Expr] -> TM (Subst, [T.Expr]) @@ -192,23 +219,44 @@ inferList env (expr : exprs) = do  runPass :: Context -> S.Program -> Either TCError T.Program -runPass (Context _ (Builtins builtins)) prog = -    let env = Env (Map.map (generalise emptyEnv) builtins) +runPass (Context _ (Builtins builtins _)) prog = +    let env = Env (Map.map (generalise emptyEnv) builtins) mempty mempty      in runTM (typeCheck env prog)  typeCheck :: Env -> S.Program -> TM T.Program -typeCheck startEnv (S.Program decls) = -    let defs = [(name, ty) -               | S.Def (S.Function (Just ty) (name, _) _ _) <- decls] -        env = foldl (\env' (name, ty) -> envAdd name (generalise env' (convertType ty)) env') -                    startEnv defs -    in T.Program <$> mapM (typeCheckDef env . (\(S.Def def) -> def)) decls +typeCheck startEnv (S.Program decls) = do +    traceM (show decls) + +    let aliasdefs = [(n, def) +                    | S.DeclAlias def@(S.AliasDef (n, _) _ _) <- decls] +        env1 = envAddAliases (Map.fromList aliasdefs) startEnv + +    typedefs' <- checkTypedefs (envAliases env1) [def | S.DeclType def <- decls] +    let typedefsMap = Map.fromList [(n, def) | def@(T.TypeDef n _ _) <- typedefs'] + +    let funcdefs = [def | S.DeclFunc def <- decls] +    typedfuncs <- sequence +        [(name,) <$> convertType (envAliases env1) sr ty +        | S.FuncDef (Just ty) (name, sr) _ _ <- funcdefs] + +    let env2 = envAddTypes typedefsMap env1 + +    traceM (show typedefsMap) + +    let env = foldl (\env' (name, ty) -> +                        envAddDef name (generalise env' ty) env') +                    env2 typedfuncs + +    traceM (show env) + +    funcdefs' <- mapM (typeCheckFunc env) funcdefs +    return (T.Program funcdefs' typedefsMap) -typeCheckDef :: Env -> S.Def -> TM T.Def -typeCheckDef env (S.Function mannot (name, sr) args@(_:_) body) = -    typeCheckDef env (S.Function mannot (name, sr) [] (S.Lam sr args body)) -typeCheckDef env (S.Function (Just annot) (name, sr) [] body) = -    typeCheckDef env (S.Function Nothing (name, sr) [] (S.Annot sr body annot)) -typeCheckDef env (S.Function Nothing (name, _) [] body) = do +typeCheckFunc :: Env -> S.FuncDef -> TM T.Def +typeCheckFunc env (S.FuncDef mannot (name, sr) args@(_:_) body) = +    typeCheckFunc env (S.FuncDef mannot (name, sr) [] (S.Lam sr args body)) +typeCheckFunc env (S.FuncDef (Just annot) (name, sr) [] body) = +    typeCheckFunc env (S.FuncDef Nothing (name, sr) [] (S.Annot sr body annot)) +typeCheckFunc env (S.FuncDef Nothing (name, _) [] body) = do      (_, body') <- infer env body      return (T.Def name body') diff --git a/typecheck/CC/Typecheck/Typedefs.hs b/typecheck/CC/Typecheck/Typedefs.hs new file mode 100644 index 0000000..ad9bdd8 --- /dev/null +++ b/typecheck/CC/Typecheck/Typedefs.hs @@ -0,0 +1,50 @@ +module CC.Typecheck.Typedefs(checkTypedefs) where + +import Control.Monad.Except +import Data.Foldable (traverse_) +import qualified Data.Map.Strict as Map +import Data.Map.Strict (Map) +import qualified Data.Set as Set + +import qualified CC.AST.Source as S +import qualified CC.AST.Typed as T +import CC.Typecheck.Types +import CC.Types + + +checkArity :: Map Name Int -> S.TypeDef -> TM () +checkArity typeArity (S.TypeDef (_, sr) args ty) = go ty +  where +    argNames = map fst args  -- probably a small list + +    go :: S.Type -> TM () +    go (S.TFun t1 t2) = go t1 >> go t2 +    go S.TInt = return () +    go (S.TTup ts) = mapM_ go ts +    go (S.TNamed n ts) +      | Just arity <- Map.lookup n typeArity = +          if length ts == arity +              then mapM_ go ts +              else throwError (TypeArityError sr n arity (length ts)) +      | otherwise = throwError (RefError sr n) +    go (S.TUnion ts) = traverse_ go ts +    go (S.TyVar n) +      | n `elem` argNames = return () +      | otherwise = throwError (RefError sr n) + +checkTypedefs :: Map Name S.AliasDef -> [S.TypeDef] -> TM [T.TypeDef] +checkTypedefs aliases origdefs = do +    let aliasArity = Map.map (\(S.AliasDef _ args _) -> length args) aliases +        typeArity = Map.fromList [(n, length args) +                                 | S.TypeDef (n, _) args _ <- origdefs] + +    let dups = Set.fromList (map (\(S.TypeDef (n, _) _ _) -> n) origdefs) +                    Set.\\ Map.keysSet typeArity +    when (not (Set.null dups)) $ +        throwError (DupTypeError (Set.findMin dups)) + +    let aliasdefs = [S.TypeDef name args typ +                    | S.AliasDef name args typ <- Map.elems aliases] + +    mapM_ (checkArity (aliasArity <> typeArity)) (aliasdefs ++ origdefs) +    mapM (convertTypeDef aliases) origdefs diff --git a/typecheck/CC/Typecheck/Types.hs b/typecheck/CC/Typecheck/Types.hs new file mode 100644 index 0000000..3f3c471 --- /dev/null +++ b/typecheck/CC/Typecheck/Types.hs @@ -0,0 +1,103 @@ +module CC.Typecheck.Types where + +import Control.Monad.State.Strict +import Control.Monad.Except +import qualified Data.Map.Strict as Map +import Data.Map.Strict (Map) +import Data.Maybe (fromMaybe) +import qualified Data.Set as Set +import Data.Set (Set) + +import qualified CC.AST.Source as S +import qualified CC.AST.Typed as T +import CC.Pretty +import CC.Types + + +data TCError = TypeError SourceRange T.Type T.Type +             | RefError SourceRange Name +             | TypeArityError SourceRange Name Int Int +             | DupTypeError Name +  deriving (Show) + +instance Pretty TCError where +    pretty (TypeError sr real expect) = +        "Type error: Expression at " ++ pretty sr ++ +            " has type " ++ pretty real ++ +            ", but should have type " ++ pretty expect +    pretty (RefError sr name) = +        "Reference error: Variable '" ++ name ++ "' out of scope at " ++ pretty sr +    pretty (TypeArityError sr name wanted got) = +        "Type error: Type '" ++ name ++ "' has arity " ++ show wanted ++ +            " but gets " ++ show got ++ " type arguments at " ++ pretty sr +    pretty (DupTypeError name) = +        "Duplicate types: Type '" ++ name ++ "' defined multiple times" + +type TM a = ExceptT TCError (State Int) a + +genId :: TM Int +genId = state (\idval -> (idval, idval + 1)) + +genTyVar :: TM T.Type +genTyVar = T.TyVar T.Instantiable <$> genId + +runTM :: TM a -> Either TCError a +runTM m = evalState (runExceptT m) 1 + + +convertType :: Map Name S.AliasDef -> SourceRange -> S.Type -> TM T.Type +convertType aliases sr = fmap snd . convertType' aliases mempty sr + +convertTypeDef :: Map Name S.AliasDef -> S.TypeDef -> TM T.TypeDef +convertTypeDef aliases (S.TypeDef (name, sr) args ty) = do +    (mapping, ty') <- convertType' aliases (Set.fromList (map fst args)) sr ty +    let args' = [mapping Map.! n | (n, _) <- args] +    return (T.TypeDef name args' ty') + +convertType' :: Map Name S.AliasDef -> Set Name -> SourceRange -> S.Type -> TM (Map Name Int, T.Type) +convertType' aliases extraVars sr origtype = do +    rewritten <- rewrite origtype +    let frees = Set.toList (extraVars <> freeVars rewritten) +    nums <- traverse (const genId) frees +    let mapping = Map.fromList (zip frees nums) +    return (mapping, convert mapping rewritten) +  where +    rewrite :: S.Type -> TM S.Type +    rewrite (S.TFun t1 t2) = S.TFun <$> rewrite t1 <*> rewrite t2 +    rewrite S.TInt = return S.TInt +    rewrite (S.TTup ts) = S.TTup <$> mapM rewrite ts +    rewrite (S.TNamed n ts) +      | Just (S.AliasDef _ args typ) <- Map.lookup n aliases = +          if length args == length ts +              then rewrite (subst (Map.fromList (zip (map fst args) ts)) typ) +              else throwError (TypeArityError sr n (length args) (length ts)) +      | otherwise = +          S.TNamed n <$> mapM rewrite ts +    rewrite (S.TUnion ts) = S.TUnion . Set.fromList <$> mapM rewrite (Set.toList ts) +    rewrite (S.TyVar n) = return (S.TyVar n) + +    -- Substitute type variables +    subst :: Map Name S.Type -> S.Type -> S.Type +    subst mp (S.TFun t1 t2) = S.TFun (subst mp t1) (subst mp t2) +    subst _  S.TInt = S.TInt +    subst mp (S.TTup ts) = S.TTup (map (subst mp) ts) +    subst mp (S.TNamed n ts) = S.TNamed n (map (subst mp) ts) +    subst mp (S.TUnion ts) = S.TUnion (Set.map (subst mp) ts) +    subst mp orig@(S.TyVar n) = fromMaybe orig (Map.lookup n mp) + +    freeVars :: S.Type -> Set Name +    freeVars (S.TFun t1 t2) = freeVars t1 <> freeVars t2 +    freeVars S.TInt = mempty +    freeVars (S.TTup ts) = Set.unions (map freeVars ts) +    freeVars (S.TNamed _ ts) = Set.unions (map freeVars ts) +    freeVars (S.TUnion ts) = Set.unions (map freeVars (Set.toList ts)) +    freeVars (S.TyVar n) = Set.singleton n + +    convert :: Map Name Int -> S.Type -> T.Type +    convert mp (S.TFun t1 t2) = T.TFun (convert mp t1) (convert mp t2) +    convert _  S.TInt = T.TInt +    convert mp (S.TTup ts) = T.TTup (map (convert mp) ts) +    convert mp (S.TNamed n ts) = T.TNamed n (map (convert mp) ts) +    convert mp (S.TUnion ts) = T.TUnion (Set.map (convert mp) ts) +    -- TODO: Should this be Rigid? I really don't know how this works. +    convert mp (S.TyVar n) = T.TyVar T.Rigid (mp Map.! n)  | 
