diff options
Diffstat (limited to 'typecheck/CC/Typecheck.hs')
-rw-r--r-- | typecheck/CC/Typecheck.hs | 81 |
1 files changed, 67 insertions, 14 deletions
diff --git a/typecheck/CC/Typecheck.hs b/typecheck/CC/Typecheck.hs index 292eaeb..35ffdfe 100644 --- a/typecheck/CC/Typecheck.hs +++ b/typecheck/CC/Typecheck.hs @@ -7,7 +7,7 @@ import Control.Monad.State.Strict import Control.Monad.Except import qualified Data.Map.Strict as Map import Data.Map.Strict (Map) -import Data.Maybe +import Data.Maybe (fromMaybe, catMaybes) import qualified Data.Set as Set import Data.Set (Set) @@ -133,18 +133,18 @@ envAliases (Env _ _ aliases) = aliases substVar :: Int -> T.Type -> Subst substVar var ty = Subst (Map.singleton var ty) +freshenScheme :: T.TypeScheme -> TM T.TypeScheme +freshenScheme (T.TypeScheme bnds ty) = do + vars <- traverse (const genId) bnds + let theta = Subst (Map.fromList (zip bnds (map (T.TyVar T.Instantiable) vars))) + return (T.TypeScheme vars (theta >>! ty)) + generalise :: Env -> T.Type -> T.TypeScheme generalise env ty = T.TypeScheme (Set.toList (freeInstTypeVars ty Set.\\ freeInstTypeVars env)) ty instantiate :: T.TypeScheme -> TM T.Type -instantiate (T.TypeScheme bnds ty) = do - vars <- traverse (const genTyVar) bnds - let theta = Subst (Map.fromList (zip bnds vars)) - return (theta >>! ty) - -freshenFrees :: Env -> T.Type -> TM T.Type -freshenFrees env = instantiate . generalise env +instantiate scheme = (\(T.TypeScheme _ ty') -> ty') <$> freshenScheme scheme replaceRigid :: T.Type -> T.Type replaceRigid (T.TFun t1 t2) = T.TFun (replaceRigid t1) (replaceRigid t2) @@ -154,21 +154,67 @@ replaceRigid (T.TNamed n ts) = T.TNamed n (map replaceRigid ts) replaceRigid (T.TUnion ts) = T.TUnion (Set.map replaceRigid ts) replaceRigid (T.TyVar _ v) = T.TyVar T.Rigid v +checkType :: Env -> SourceRange -> T.Type -> TM () +checkType env sr (T.TFun t1 t2) = checkType env sr t1 >> checkType env sr t2 +checkType _ _ T.TInt = return () +checkType env sr (T.TTup ts) = mapM_ (checkType env sr) ts +checkType env sr (T.TNamed n ts) = do + mapM_ (checkType env sr) ts + case envFindType n env of + Just (T.TypeDef _ args _) + | length ts == length args -> return () + | otherwise -> throwError (TypeArityError sr n (length args) (length ts)) + Nothing -> throwError (RefError sr n) +checkType env sr (T.TUnion ts) = mapM_ (checkType env sr) (Set.toList ts) +checkType _ _ (T.TyVar _ _) = return () + data UnifyContext = UnifyContext SourceRange T.Type T.Type +-- t1 = got type: what did we infer using existing information +-- t2 = wanted type: what should the type equal due to an annotation or language usage +-- Unions are only weakened towards t2: {a} U {a,b} works, but {a,b} U {a} is an error. unify :: SourceRange -> T.Type -> T.Type -> TM Subst unify sr t1 t2 = unify' (UnifyContext sr t1 t2) t1 t2 unify' :: UnifyContext -> T.Type -> T.Type -> TM Subst unify' _ T.TInt T.TInt = return mempty unify' ctx (T.TFun t1 t2) (T.TFun u1 u2) = - (<>) <$> unify' ctx t1 u1 <*> unify' ctx t2 u2 + -- This one is subtle: function arguments are contravariant, so we swap + -- unification direction here. + (<>) <$> unify' ctx t2 u2 <*> unify' ctx u1 t1 unify' ctx (T.TTup ts) (T.TTup us) | length ts == length us = mconcat <$> zipWithM (unify' ctx) ts us unify' _ (T.TyVar T.Instantiable var) ty = return (substVar var ty) unify' _ ty (T.TyVar T.Instantiable var) = return (substVar var ty) --- TODO: fix unify -unify' (UnifyContext sr t1 t2) _ _ = throwError (TypeError sr t1 t2) +unify' ctx (T.TNamed n1 ts) (T.TNamed n2 us) + | n1 == n2, length ts == length us = mconcat <$> zipWithM (unify' ctx) ts us +unify' ctx (T.TUnion topts) (T.TUnion topus) = + -- TODO: this is quadratic in the right union size. I'm not sure whether + -- this is avoidable, but it can probably be improved by partitioning on + -- the name of TNamed's. + mconcat . snd <$> mapAccumLM (\us ty -> matchup ty us) topus (Set.toList topts) + where + -- If a match is found, returns the substitution and the rest of the RHS + -- types; else, throws an error + matchup :: T.Type -> Set T.Type -> TM (Set T.Type, Subst) + matchup ty ts = do + let splits = [(item, uncurry (<>) (Set.split item ts)) | item <- Set.toList ts] + results <- forM splits $ \(item, rest) -> + catchError ((Just . (rest,)) <$> unify' ctx ty item) + (const (return Nothing)) + case catMaybes results of + [] -> let UnifyContext sr topt1 topt2 = ctx + in throwError (UnifyError sr topt1 topt2 ty (T.TUnion ts) + (Just URNotInUnion)) + [result] -> return result + _ -> let UnifyContext sr topt1 topt2 = ctx + in throwError (UnifyError sr topt1 topt2 ty (T.TUnion ts) + (Just URAmbiguousWeakening)) +unify' ctx ty (T.TUnion us) = unify' ctx (T.TUnion (Set.singleton ty)) (T.TUnion us) +unify' ctx (T.TUnion ts) ty + | Set.size ts == 0 = return mempty + | Set.size ts == 1 = unify' ctx (Set.findMin ts) ty +unify' (UnifyContext sr t1 t2) u1 u2 = throwError (UnifyError sr t1 t2 u1 u2 Nothing) infer :: Env -> S.Expr -> TM (Subst, T.Expr) infer env expr = case expr of @@ -207,14 +253,15 @@ infer env expr = case expr of throwError (RefError sr name) S.Constr sr name -> case envFindType name env of Just (T.TypeDef typname params typ) -> do - restyp <- freshenFrees emptyEnv - (T.TNamed typname (map (T.TyVar T.Instantiable) params)) - return (mempty, T.Constr (T.TFun typ restyp) name) + T.TypeScheme params' typ' <- freshenScheme (T.TypeScheme params typ) + let restyp = T.TNamed typname (map (T.TyVar T.Instantiable) params') + return (mempty, T.Constr (T.TFun typ' restyp) name) _ -> throwError (RefError sr name) S.Annot sr subex ty -> do (theta1, subex') <- infer env subex ty' <- convertType (envAliases env) sr ty + checkType env sr ty' -- Make sure the type of the subexpression matches the type with rigid -- variables, then make it instantiable variables instead for the rest -- of the code. @@ -272,3 +319,9 @@ typeCheckFunc env (S.FuncDef (Just annot) (name, sr) [] body) = typeCheckFunc env (S.FuncDef Nothing (name, _) [] body) = do (_, body') <- infer env body return (T.Def name body') + + +mapAccumLM :: Monad m => (a -> b -> m (a, c)) -> a -> [b] -> m (a, [c]) +mapAccumLM _ start [] = return (start, []) +mapAccumLM f start (x:xs) = + f start x >>= \(next, y) -> fmap (y :) <$> mapAccumLM f next xs |