module CC.Typecheck.Types where

import Control.Monad.Except
import Control.Monad.State.Strict
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import Data.Maybe (fromMaybe)
import Data.Set (Set)
import qualified Data.Set as Set

import qualified CC.AST.Source as S
import qualified CC.AST.Typed as T
import CC.Pretty
import CC.Types


-- | Errors reported by the type checker.
data TCError
  -- | Unification failure. Fields: location of the offending expression, the
  -- type it actually has, the type it should have, the two types whose
  -- unification failed, and an optional extra reason.
  = UnifyError SourceRange T.Type T.Type T.Type T.Type (Maybe UnifyReason)
  -- | A variable is referenced at a location where it is out of scope.
  | RefError SourceRange Name
  -- | A type was given the wrong number of type arguments. Fields: location,
  -- type name, declared arity, number of arguments actually supplied.
  | TypeArityError SourceRange Name Int Int
  -- | A type name is defined more than once.
  | DupTypeError Name
  deriving (Show)

-- | Optional extra explanation carried by a 'UnifyError'.
data UnifyReason
  = URNotInUnion
  | URAmbiguousWeakening
  deriving (Show)

instance Pretty TCError where
  pretty (UnifyError sr real expect unifyt1 unifyt2 mreason) =
    concat
      [ "Type error: Expression at ", pretty sr
      , " has type ", pretty real
      , ", but should have type ", pretty expect
      , " (when unifying ", pretty unifyt1, " and ", pretty unifyt2, ")"
      , reasonSuffix
      ]
    where
      -- Only mention a reason when one was supplied.
      reasonSuffix = case mreason of
        Nothing -> ""
        Just reason -> " (reason: " ++ pretty reason ++ ")"
  pretty (RefError sr name) =
    concat ["Reference error: Variable '", name, "' out of scope at ", pretty sr]
  pretty (TypeArityError sr name wanted got) =
    concat
      [ "Type error: Type '", name, "' has arity ", show wanted
      , " but gets ", show got, " type arguments at ", pretty sr
      ]
  pretty (DupTypeError name) =
    concat ["Duplicate types: Type '", name, "' defined multiple times"]

instance Pretty UnifyReason where
  pretty reason = case reason of
    URNotInUnion -> "type not found in union"
    URAmbiguousWeakening -> "type unifies with multiple items in the union"

-- | The type-checking monad. It is a computation that may fail with a
-- 'TCError' and threads an 'Int' counter used to mint fresh identifiers.
type TM a = ExceptT TCError (State Int) a

-- | Hand out the current counter value and advance the counter by one.
genId :: TM Int
genId = state bump
  where
    bump current = (current, current + 1)

-- | A fresh instantiable type variable, numbered with 'genId'.
genTyVar :: TM T.Type
genTyVar = do
  ident <- genId
  pure (T.TyVar T.Instantiable ident)

-- | Run a 'TM' computation with the counter starting at 1. The final
-- counter value is discarded.
runTM :: TM a -> Either TCError a
runTM = flip evalState 1 . runExceptT

-- | Convert a source type into a typed-AST type.
--
-- Aliases are expanded, and every type variable is replaced by a freshly
-- numbered instantiable variable. The 'SourceRange' is only used to locate
-- alias arity errors.
convertType :: Map Name S.AliasDef -> SourceRange -> S.Type -> TM T.Type
convertType aliases sr ty = snd <$> convertType' aliases Set.empty sr ty

-- | Convert a type definition.
--
-- The declared parameters are numbered together with the variables of the
-- body. The resulting 'T.TypeDef' lists the parameters' numbers in
-- declaration order.
--
-- NOTE(review): variables that occur in the body but are not declared
-- parameters also receive numbers, yet they do not appear in the parameter
-- list. Confirm that this is intended or that it is rejected elsewhere.
convertTypeDef :: Map Name S.AliasDef -> S.TypeDef -> TM T.TypeDef
convertTypeDef aliases (S.TypeDef (tyName, nameRange) params body) = do
  let paramNames = map fst params
  (varIds, body') <- convertType' aliases (Set.fromList paramNames) nameRange body
  pure (T.TypeDef tyName (map (varIds Map.!) paramNames) body')

-- | Worker behind 'convertType' and 'convertTypeDef'.
--
-- Steps:
--
-- 1. Expand the type aliases in the given type.
-- 2. Assign a fresh id (from 'genId') to every type variable of the expanded
--    type, plus every name in @extraVars@, in ascending name order.
-- 3. Convert the type, turning each variable into an instantiable
--    'T.TyVar'.
--
-- Returns the name-to-id table together with the converted type.
--
-- Throws 'TypeArityError', located at @sr@, when an alias is applied to the
-- wrong number of arguments.
convertType'
  :: Map Name S.AliasDef
  -> Set Name
  -> SourceRange
  -> S.Type
  -> TM (Map Name Int, T.Type)
convertType' aliases extraVars sr origtype = do
  expanded <- expand origtype
  let allVars = Set.union extraVars (collectVars expanded)
  -- Traversing a Map visits its keys in ascending order, so ids are handed
  -- out in sorted-name order.
  table <- traverse (const genId) (Map.fromSet (const ()) allVars)
  pure (table, translate table expanded)
  where
    -- Replace each alias application by its substituted, recursively
    -- expanded body.
    -- NOTE(review): a self-referential alias would make this loop forever
    -- unless such aliases are rejected before reaching here. Confirm.
    expand :: S.Type -> TM S.Type
    expand ty = case ty of
      S.TFun argTy resTy -> do
        argTy' <- expand argTy
        resTy' <- expand resTy
        pure (S.TFun argTy' resTy')
      S.TInt -> pure S.TInt
      S.TTup elems -> S.TTup <$> traverse expand elems
      S.TNamed n targs -> case Map.lookup n aliases of
        Just (S.AliasDef _ aliasParams aliasBody)
          | length aliasParams == length targs ->
              expand (substitute (Map.fromList (zip (map fst aliasParams) targs)) aliasBody)
          | otherwise ->
              throwError (TypeArityError sr n (length aliasParams) (length targs))
        _ -> S.TNamed n <$> traverse expand targs
      S.TUnion members ->
        S.TUnion . Set.fromList <$> traverse expand (Set.toAscList members)
      S.TyVar n -> pure (S.TyVar n)

    -- Replace type variables bound in @env@ in a single simultaneous pass.
    -- Variables not in @env@ are left untouched.
    substitute :: Map Name S.Type -> S.Type -> S.Type
    substitute env = go
      where
        go ty = case ty of
          S.TFun a b -> S.TFun (go a) (go b)
          S.TInt -> S.TInt
          S.TTup ts -> S.TTup (map go ts)
          S.TNamed n ts -> S.TNamed n (map go ts)
          S.TUnion ts -> S.TUnion (Set.map go ts)
          S.TyVar n -> fromMaybe ty (Map.lookup n env)

    -- Every type-variable name that occurs in a type.
    collectVars :: S.Type -> Set Name
    collectVars ty = case ty of
      S.TFun a b -> collectVars a <> collectVars b
      S.TInt -> Set.empty
      S.TTup ts -> foldMap collectVars ts
      S.TNamed _ ts -> foldMap collectVars ts
      S.TUnion ts -> foldMap collectVars ts
      S.TyVar n -> Set.singleton n

    -- Map an expanded source type onto the typed AST, looking up each
    -- variable's id in the table. The lookup cannot miss, because the table's
    -- keys include 'collectVars' of this same type.
    translate :: Map Name Int -> S.Type -> T.Type
    translate ids = go
      where
        go ty = case ty of
          S.TFun a b -> T.TFun (go a) (go b)
          S.TInt -> T.TInt
          S.TTup ts -> T.TTup (map go ts)
          S.TNamed n ts -> T.TNamed n (map go ts)
          S.TUnion ts -> T.TUnion (Set.map go ts)
          S.TyVar n -> T.TyVar T.Instantiable (ids Map.! n)