{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE EmptyDataDeriving #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}
module HSVIS.Typecheck (
  StageTyped,
  typecheck,
  -- * Typed AST synonyms
  TProgram, TDataDef, TDataField, TFunDef, TFunEq, TKind, TType, TPattern, TRHS, TExpr,
) where

import Control.Monad
import Data.Bifunctor (first, second, bimap)
import Data.Foldable (toList)
import Data.List (inits, foldl1')
import Data.Map.Strict (Map)
import Data.Maybe (fromMaybe)
import qualified Data.Map.Strict as Map
import Data.Tuple (swap)
import Data.Semigroup (First(..))
import Data.Set (Set)
import qualified Data.Set as Set

import Debug.Trace

import Data.Bag
import Data.Map.Monoidal (MMap(..))
import qualified Data.Map.Monoidal as MMap
import HSVIS.AST
import HSVIS.Parser
import HSVIS.Diagnostic
import HSVIS.Pretty
import HSVIS.Typecheck.Solve


-- | AST stage used while typechecking is in progress: nodes carry source
-- ranges plus kinds/types that may still contain unification variables.
data StageTC

type instance X DataDef   StageTC = (Range, CKind)
type instance X DataField StageTC = Range
type instance X FunDef    StageTC = Range
type instance X TypeSig   StageTC = Maybe Range
type instance X FunEq     StageTC = ()
type instance X Kind      StageTC = ()
type instance X Type      StageTC = (Range, CKind)
type instance X Pattern   StageTC = (Range, CType)
type instance X RHS       StageTC = CType
type instance X Expr      StageTC = (Range, CType)

-- | Extra type forms of the checking stage: type unification variables and
-- explicit foralls with an annotated binder kind.
data instance E Type StageTC = TUniVar Int | TForallC (Name, CKind) CType
  deriving (Show, Eq, Ord)

-- | Extra kind form of the checking stage: kind unification variables.
data instance E Kind StageTC = KUniVar Int
  deriving (Show, Eq, Ord)

-- | No extra type-signature forms in the checking stage.
data instance E TypeSig StageTC
  deriving (Show)

type CProgram   = Program   StageTC
type CDataDef   = DataDef   StageTC
type CDataField = DataField StageTC
type CFunDef    = FunDef    StageTC
type CTypeSig   = TypeSig   StageTC
type CFunEq     = FunEq     StageTC
type CKind      = Kind      StageTC
type CType      = Type      StageTC
type CPattern   = Pattern   StageTC
type CRHS       = RHS       StageTC
type CExpr      = Expr      StageTC

-- | AST stage produced by 'typecheck': nodes are annotated with final
-- kinds/types.
data StageTyped

type instance X DataDef   StageTyped = TKind
type instance X DataField StageTyped = Range
type instance X FunDef    StageTyped = ()
type instance X TypeSig   StageTyped = ()
type instance X FunEq     StageTyped = ()
type instance X Kind      StageTyped = ()
type instance X Type      StageTyped = TKind
type instance X Pattern   StageTyped = TType
type instance X RHS       StageTyped = TType
type instance X Expr      StageTyped = TType

data instance E Type StageTyped = TForall (Name, TKind) TType
  deriving (Show)

data instance E Kind StageTyped
  deriving (Show)

data instance E TypeSig StageTyped
  deriving (Show)

type TProgram   = Program   StageTyped
type TDataDef   = DataDef   StageTyped
type TDataField = DataField StageTyped
type TFunDef    = FunDef    StageTyped
type TFunEq     = FunEq     StageTyped
type TKind      = Kind      StageTyped
type TType      = Type      StageTyped
type TPattern   = Pattern   StageTyped
type TRHS       = RHS       StageTyped
type TExpr      = Expr      StageTyped

instance Pretty (E Type StageTC) where
  prettysPrec _ (TUniVar n) = showString ("?t" ++ show n)
  prettysPrec d (TForallC (n, k) t) =
    showParen (d > 0) $
      showString "forall (" . prettys n . showString " :: " . prettys k . showString "). " . prettys t

instance Pretty (E Kind StageTC) where
  prettysPrec _ (KUniVar n) = showString ("?k" ++ show n)


-- | Typecheck a parsed program.
--
-- Takes the file path and the full source text (both are only used to build
-- diagnostics) and returns the collected diagnostics together with the typed
-- program. Calls 'error' if constraints remain after 'tcTop' has run.
typecheck :: FilePath -> String -> PProgram -> ([Diagnostic], TProgram)
typecheck fp source prog =
  let (ds, cs, _, _, resprog) =
        runTCM (tcTop prog) (fp, source) 1 (Env mempty mempty mempty)
  in trace ("[tc] resprog = " ++ show resprog) $
     if not (null cs)
       then error $ "Constraints left after typechecker completion: " ++ show cs
       else (toList ds, resprog)


data Constr
  -- Equality constraints: "left" must be equal to "right" because of the thing
  -- at the given range. "left" is the expected thing; "right" is the observed
  -- thing.
  = CEq CType CType Range
  | CEqK CKind CKind Range
  deriving (Show)

data Env = Env (Map Name CKind)   -- ^ types in scope (including variables)
               (Map Name CType)   -- ^ values in scope (constructors and variables)
               (Map Name InvCon)  -- ^ patterns in scope (inverse constructors)
  deriving (Show)

data InvCon = InvCon (Map Name CKind)  -- ^ universally quantified type variables
                     CType             -- ^ input type of the inverse constructor (result of the constructor)
                     [CType]           -- ^ field types
  deriving (Show)

-- | The typechecker monad: a hand-rolled combination of a reader (file
-- context), state (id supply and environment) and two writers (diagnostics
-- and constraints).
newtype TCM a = TCM {
  runTCM :: (FilePath, String)  -- ^ reader context: file and file contents
         -> Int                 -- ^ state: next id to generate
         -> Env                 -- ^ state: type and value environment
         -> (Bag Diagnostic     -- writer: diagnostics
            ,Bag Constr         -- writer: constraints
            ,Int, Env, a)
  }

instance Functor TCM where
  fmap f (TCM g) = TCM $ \ctx i env ->
    let !(ds, cs, i', env', !x) = g ctx i env
        !res = f x
    in (ds, cs, i', env', res)

instance Applicative TCM where
  pure x = TCM $ \_ i env -> x `seq` (mempty, mempty, i, env, x)
  (<*>) = ap

instance Monad TCM where
  TCM f >>= g = TCM $ \ctx i1 env1 ->
    let !(ds2, cs2, i2, env2, !x) = f ctx i1 env1
        !(ds3, cs3, i3, env3, !y) = runTCM (g x) ctx i2 env2
    in (ds2 <> ds3, cs2 <> cs3, i3, env3, y)

-- | Monads in which a diagnostic can be reported.
class Monad m => MonadRaise m where
  raise :: Severity -> Range -> String -> m ()

instance MonadRaise TCM where
  raise sev rng@(Range (Pos y _) _) msg = TCM $ \(fp, source) i env ->
    -- Total lookup of the offending source line: a range that points past
    -- the end of the file yields an empty line instead of crashing in (!!)
    -- when the diagnostic is rendered.
    let srcLine = case drop y (lines source) of
                    l : _ -> l
                    []    -> ""
    in (pure (Diagnostic sev fp rng [] srcLine msg), mempty, i, env, ())

-- | Record a constraint to be solved later.
emit :: Constr -> TCM ()
emit c = TCM $ \_ i env -> (mempty, pure c, i, env, ())

-- | Run the action and split off the constraints it emitted that are selected
-- by the predicate; those are returned instead of being propagated. The
-- remaining constraints stay in the writer output.
collectConstraints :: (Constr -> Maybe b) -> TCM a -> TCM (Bag b, a)
collectConstraints predicate (TCM f) = TCM $ \ctx i env ->
  let (ds, cs, i', env', x) = f ctx i env
      (yes, no) = bagPartition predicate cs
  in (ds, no, i', env', (yes, x))

getFullEnv :: TCM Env
getFullEnv = TCM $ \_ i env -> (mempty, mempty, i, env, env)

putFullEnv :: Env -> TCM ()
putFullEnv env = TCM $ \_ i _ -> (mempty, mempty, i, env, ())

-- | Return the next unused id and advance the supply.
genId :: TCM Int
genId = TCM $ \_ i env -> (mempty, mempty, i + 1, env, i)

getKind :: Name -> TCM (Maybe CKind)
getKind name = do
  Env tenv _ _ <- getFullEnv
  return (Map.lookup name tenv)

getType :: Name -> TCM (Maybe CType)
getType name = do
  Env _ venv _ <- getFullEnv
  return (Map.lookup name venv)

getInvCon :: Name -> TCM (Maybe InvCon)
getInvCon name = do
  Env _ _ icenv <- getFullEnv
  return (Map.lookup name icenv)

modifyTEnv :: (Map Name CKind -> Map Name CKind) -> TCM ()
modifyTEnv f = do
  Env tenv venv icenv <- getFullEnv
  putFullEnv (Env (f tenv) venv icenv)

modifyVEnv :: (Map Name CType -> Map Name CType) -> TCM ()
modifyVEnv f = do
  Env tenv venv icenv <- getFullEnv
  putFullEnv (Env tenv (f venv) icenv)

modifyICEnv :: (Map Name InvCon -> Map Name InvCon) -> TCM ()
modifyICEnv f = do
  Env tenv venv icenv <- getFullEnv
  putFullEnv (Env tenv venv (f icenv))

-- | Run the action and afterwards restore the type environment to what it was
-- before.
scopeTEnv :: TCM a -> TCM a
scopeTEnv m = do
  Env origtenv _ _ <- getFullEnv
  res <- m
  modifyTEnv (\_ -> origtenv)
  return res

-- | Run the action and afterwards restore the value environment to what it
-- was before.
scopeVEnv :: TCM a -> TCM a
scopeVEnv m = do
  Env _ origvenv _ <- getFullEnv
  res <- m
  modifyVEnv (\_ -> origvenv)
  return res

-- | Fresh kind unification variable.
genKUniVar :: TCM CKind
genKUniVar = KExt () . KUniVar <$> genId

-- | Fresh type unification variable with the given range and kind.
genUniVar :: Range -> CKind -> TCM CType
genUniVar rng k = TExt (rng, k) . TUniVar <$> genId

-- | Like 'getKind', but reports an error for an unknown name and recovers with
-- a fresh kind unification variable.
getKind' :: Range -> Name -> TCM CKind
getKind' rng name = getKind name >>= \case
  Nothing -> do
    raise SError rng $ "Type not in scope: " ++ pretty name
    k <- genKUniVar
    -- insert it now so that all occurrences of this out-of-scope name get the same kind
    modifyTEnv (Map.insert name k)
    return k
  Just k -> return k

-- | Like 'getType', but reports an error for an unknown name and recovers with
-- a fresh type unification variable of kind Type.
getType' :: Range -> Name -> TCM CType
getType' rng name = getType name >>= \case
  Nothing -> do
    raise SError rng $ "Variable not in scope: " ++ pretty name
    t <- genUniVar rng (KType ())
    -- insert it now so that all occurrences of this out-of-scope name get the same type
    modifyVEnv (Map.insert name t)
    return t
  Just k -> return k


-- | Check the whole program: generate constraints, solve them, substitute the
-- solution and finalise into the typed stage.
tcTop :: PProgram -> TCM TProgram
tcTop prog = do
  (cs, prog') <- collectConstraints Just (tcProgram prog)
  (subK, subT) <- solveConstrs cs
  finaliseProg (substProg subK subT prog')

tcProgram :: PProgram -> TCM CProgram
tcProgram (Program ddefs1 fdefs1) = do
  -- kind-check data definitions and collect ensuing kind constraints
  (kconstrs, ddefs2) <- collectConstraints isCEqK $ do
    ks <- mapM prepareDataDef ddefs1
    zipWithM kcDataDef ks ddefs1

  -- solve the kind constraints and finalise data types
  kinduvars <- solveKindVars kconstrs
  let ddefs3 = map (substDdef kinduvars mempty) ddefs2
  modifyTEnv (Map.map (substKind kinduvars))

  -- generate constructor values and inverse constructors for all data types
  forM_ ddefs3 $ \ddef -> do
    modifyVEnv (Map.fromList (generateConstructors ddef) <>)
    modifyICEnv (Map.fromList (generateInvCons ddef) <>)

  traceM (unlines (map pretty ddefs3))

  -- check the function definitions
  fdefs2 <- tcFunDefBlock fdefs1

  return (Program ddefs3 fdefs2)

-- Bring data type name in scope with a kind of the specified arity
prepareDataDef :: PDataDef -> TCM (CKind, [CKind])
prepareDataDef (DataDef _ name params _) = do
  parkinds <- mapM (\_ -> genKUniVar) params
  let k = foldr (KFun ()) (KType ()) parkinds
  modifyTEnv (Map.insert name k)
  return (k, parkinds)

-- Assumes that the kind of the name itself has already been
-- registered with the correct arity (this is done by prepareDataDef).
--
-- Kind-checks a single data definition. Duplicate type parameter names are
-- reported and replaced by generated names; the constructor fields are then
-- checked against kind Type with the parameters in scope.
kcDataDef :: (CKind, [CKind]) -> PDataDef -> TCM CDataDef
kcDataDef (dataKind, paramKinds) (DataDef defrng name params cons) = do
  -- ensure unicity of type param names
  let paramNames = map snd params
      taken = Set.fromList paramNames
      -- infinite supply of replacement names that clash with no declared parameter
      spareNames = [ cand
                   | i <- [1 :: Int ..]
                   , let cand = Name ('t' : show i)
                   , cand `Set.notMember` taken ]
      pickName ((rng, pname), earlier, spare)
        | pname `elem` earlier = do
            raise SError rng $ "Duplicate type parameter: " ++ pretty pname
            return spare
        | otherwise = return pname
  uniqueNames <- mapM pickName (zip3 params (inits paramNames) spareNames)

  -- kind-check the constructors
  checkedCons <- scopeTEnv $ do
    modifyTEnv (Map.fromList (zip uniqueNames paramKinds) <>)
    forM cons $ \(cname, fields) -> do
      checkedFields <- forM fields $ \(DataField () fieldTy) ->
        DataField (extOf fieldTy) <$> kcType KCTMNormal (Just (KType ())) fieldTy
      return (cname, checkedFields)

  let annotatedParams = zip (zip (map fst params) paramKinds) uniqueNames
  return (DataDef (defrng, dataKind) name annotatedParams checkedCons)

-- | One inverse constructor per constructor of the data type, for use when
-- checking constructor patterns.
generateInvCons :: CDataDef -> [(Name, InvCon)]
generateInvCons (DataDef (defrng, k) tname params cons) =
  [ (cname, InvCon quantified scrutTy (map dataFieldType fields))
  | (cname, fields) <- cons ]
  where
    quantified = Map.fromList [(n, k1) | ((_rng, k1), n) <- params]
    -- defrng is a bit imprecise here, but it's fine
    scrutTy = TApp (defrng, KType ()) (TCon (defrng, k) tname) (map (uncurry TVar) params)

-- | The type of each constructor as a value: a function from the field types
-- to the data type, quantified over all type parameters.
generateConstructors :: CDataDef -> [(Name, CType)]
generateConstructors (DataDef (defrng, k) tname params cons) = map conType cons
  where
    -- using defrng is again a bit imprecise here
    star = (defrng, KType ())
    resultTy = TApp star (TCon (defrng, k) tname) (map (uncurry TVar) params)
    quantify body = foldr (\((_, k1), n1) -> TExt star . TForallC (n1, k1)) body params
    conType (cname, fields) =
      (cname, quantify (foldr (TFun star) resultTy (map dataFieldType fields)))

-- | Use the expected kind if there is one, otherwise invent a fresh kind
-- unification variable.
promoteDownK :: Maybe CKind -> TCM CKind
promoteDownK = maybe genKUniVar return

-- | If an expected kind is given, require it to equal the observed kind.
downEqK :: Range -> Maybe CKind -> CKind -> TCM ()
downEqK rng mexpected observed =
  forM_ mexpected $ \expected -> emit (CEqK expected observed rng)

data KCTypeMode ext ret where
  -- | Kind-check a normal type: out-of-scope type variables are reported as errors.
  KCTMNormal :: KCTypeMode () CType
  -- | Kind-check an open type: out-of-scope type variables are returned. This
  -- is used to check function type signatures, which may have an implicit
  -- forall telescope at the head.
  KCTMOpen :: KCTypeMode (MMap Name (First CKind)) (CType, Map Name CKind)

-- | Given (maybe) the expected kind of this type, and a type, check it for
-- kind-correctness.
kcType :: forall ext ret. KCTypeMode ext ret -> Maybe CKind -> PType -> TCM ret
kcType mode mdown t = case mode of
  KCTMNormal -> snd <$> kcType' KCTMNormal mdown t
  KCTMOpen -> do
    (found, t') <- kcType' KCTMOpen mdown t
    return (t', firstKinds found)
  where
    -- keep, per variable, the first kind that was recorded for it
    firstKinds :: MMap Name (First CKind) -> Map Name CKind
    firstKinds (MMap m) = Map.map getFirst m

-- | Given (maybe) the expected kind of this type, and a type, check it for
-- kind-correctness.
kcType' :: forall ext ret. Monoid ext => KCTypeMode ext ret -> Maybe CKind -> PType -> TCM (ext, CType)
kcType' mode mdown ty = case ty of
  TApp rng hd args -> do
    (extHd, hd') <- kcType' mode Nothing hd
    (extArgs, args') <- checkAll Nothing args  -- TODO: give more useful down kinds
    resk <- promoteDownK mdown
    let wanted = foldr (KFun ()) resk (map (snd . extOf) args')
    emit $ CEqK (snd (extOf hd')) wanted rng
    return (extHd <> extArgs, TApp (rng, resk) hd' args')

  TTup rng comps -> do
    (ext, comps') <- checkAll (Just (KType ())) comps
    downEqK rng mdown (KType ())
    return (ext, TTup (rng, KType ()) comps')

  TList rng elt -> do
    (ext, elt') <- kcType' mode (Just (KType ())) elt
    downEqK rng mdown (KType ())
    return (ext, TList (rng, KType ()) elt')

  TFun rng arg res -> do
    (extArg, arg') <- kcType' mode (Just (KType ())) arg
    (extRes, res') <- kcType' mode (Just (KType ())) res
    downEqK rng mdown (KType ())
    return (extArg <> extRes, TFun (rng, KType ()) arg' res')

  TCon rng n -> do
    k <- getKind' rng n
    downEqK rng mdown k
    return (mempty, TCon (rng, k) n)

  TVar rng n -> getKind n >>= \case
    Just k -> do
      downEqK rng mdown k
      -- TODO: need to instantiate top-level foralls in k here
      return (mempty, TVar (rng, k) n)
    Nothing -> do
      k <- promoteDownK mdown
      case mode of
        KCTMNormal -> do
          raise SError rng $ "Type variable out of scope: " ++ pretty n
          return ((), TVar (rng, k) n)
        KCTMOpen ->
          return (MMap.singleton n (pure k), TVar (rng, k) n)

  TExt rng (TForallP n mk1 body) -> do  -- implicit forall
    k1 <- maybe genKUniVar (return . checkKind) mk1
    k2 <- genKUniVar
    downEqK rng mdown k2
    (ext, body') <- scopeTEnv (modifyTEnv (Map.insert n k1) >> kcType' mode (Just k2) body)
    -- 'k2', not 'k1 -> k2', because the forall is implicit
    return (ext, TExt (rng, k2) (TForallC (n, k1) body'))
  where
    -- check several types against the same expected kind, left to right,
    -- combining their 'ext' outputs in that order
    checkAll :: Maybe CKind -> [PType] -> TCM (ext, [CType])
    checkAll k = fmap (first mconcat . unzip) . mapM (kcType' mode k)

-- | Check a group of function definitions together: every definition's
-- signature type is put in the value environment before any of the bodies is
-- checked. The value environment is restored afterwards.
tcFunDefBlock :: [PFunDef] -> TCM [CFunDef]
tcFunDefBlock fdefs = do
  -- Only demanded for definitions without a type signature, so the partial
  -- foldl1' is never forced when 'fdefs' is empty.
  let wholeRange = foldl1' (<>) (map extOf fdefs)

  -- collect types for each of the bound functions, or unification variables if there's no type signature
  bound <- forM fdefs $ \(FunDef _ n sig _) -> (n,) <$> tcTypeSig wholeRange sig

  defs' <- scopeVEnv $ do
    modifyVEnv (Map.fromList [(n, t) | (n, TypeSig _ t) <- bound] <>)
    zipWithM (\def (_, sig) -> tcFunDef sig def) fdefs bound

  -- -- take the actual found types for typechecking the body (and link them
  -- -- to the variables generated above)
  -- let bound2 = map (\(FunDef _ n (TypeSig _ ty) _) -> (n, ty)) defs'
  -- forM_ (zip3 fdefs bound bound2) $ \(fdef, (_, tvar), (_, ty)) ->
  --   emit $ CEq ty tvar (extOf fdef)  -- which is expected/observed? which range? /shrug/

  return defs'

-- | Just the identity, because there's nothing to be checked in our simple kinds
checkKind :: PKind -> CKind
checkKind = \case
  KType () -> KType ()
  KFun () from to -> KFun () (checkKind from) (checkKind to)

-- | The 'Range' argument is the range of the entire function; used in case there
-- is no type signature.
tcTypeSig :: Range -> PTypeSig -> TCM CTypeSig
tcTypeSig funrange = \case
  TypeSig () sig -> do
    (typ, freetvars) <- kcType KCTMOpen (Just (KType ())) sig
    let sigrng = extOf sig
        -- quantify over the type variables that were not in scope
        closed = foldr (\(n, k) -> TExt (sigrng, KType ()) . TForallC (n, k)) typ (Map.assocs freetvars)
    return (TypeSig (Just sigrng) closed)
  TypeSigExt () NoTypeSig ->
    TypeSig Nothing <$> genUniVar funrange (KType ())

-- | Typechecks the function definition, but assumes its signature has already
-- been checked, and is passed separately. Thus the PTypeSig in the PFunDef is
-- ignored.
tcFunDef :: CTypeSig -> PFunDef -> TCM CFunDef
tcFunDef typesig@(TypeSig _ funtyp) (FunDef rng name _ eqs) = do
  -- all equations of one function must take the same number of patterns
  unless (allEq (fmap (length . funeqPats) eqs)) $
    raise SError rng "Function equations have differing numbers of arguments"

  eqs' <- scopeVEnv $ mapM (tcFunEq funtyp) eqs

  return (FunDef rng name typesig eqs')

-- | Check one equation of a function against the function's type: the type is
-- split into one argument type per pattern plus a result type, the patterns
-- are checked (bringing their variables in scope) and then the right-hand
-- side. The value environment is restored afterwards.
--
-- NOTE(review): 'down' can start with 'TForallC' binders (tcTypeSig adds them
-- for free type variables), while unfoldFunTy only looks through 'TFun' --
-- confirm that this is intended.
tcFunEq :: CType -> PFunEq -> TCM CFunEq
tcFunEq down (FunEq rng name pats rhs) = do
  -- getFullEnv >>= \env -> traceM $ "[tcFunEq] Env = " ++ show env
  (argtys, rhsty) <- unfoldFunTy rng (length pats) down
  scopeVEnv $ do
    pats' <- zipWithM tcPattern argtys pats
    rhs' <- tcRHS rhsty rhs
    return (FunEq () name pats' rhs')

-- | Brings the bound variables in scope
--
-- The first argument is the type the scrutinee is expected to have.
tcPattern :: CType -> PPattern -> TCM CPattern
tcPattern down = \case
  PWildcard rng -> return $ PWildcard (rng, down)

  PVar rng n -> modifyVEnv (Map.insert n down) >> return (PVar (rng, down) n)

  PAs rng n p -> do
    modifyVEnv (Map.insert n down)
    p' <- tcPattern down p
    return $ PAs (rng, snd (extOf p')) n p'

  PCon rng n ps ->
    getInvCon n >>= \case
      Just (InvCon tyvars match fields) -> do
        unisub <- mapM (genUniVar rng) tyvars  -- substitution for the universally quantified variables
        let match' = substType mempty mempty unisub match
            fields' = map (substType mempty mempty unisub) fields
            nfields = length fields'
            npats = length ps
        emit $ CEq down match' rng
        -- zipWithM below would silently drop the surplus fields or
        -- sub-patterns, so report an arity mismatch explicitly
        when (nfields /= npats) $
          raise SError rng $
            "Constructor " ++ pretty n ++ " has " ++ show nfields
            ++ " fields, but the pattern has " ++ show npats ++ " arguments"
        PCon (rng, match') n <$> zipWithM tcPattern fields' ps
      Nothing -> do
        raise SError rng $ "Constructor not in scope: " ++ pretty n
        return (PWildcard (rng, down))

  POp rng p1 op p2 ->
    case op of
      OCons -> do
        eltty <- genUniVar rng (KType ())
        let listty = TList (rng, KType ()) eltty
        emit $ CEq down listty rng
        p1' <- tcPattern eltty p1
        p2' <- tcPattern listty p2
        return (POp (rng, listty) p1' OCons p2')
      _ -> do
        raise SError rng $ "Operator is not a constructor: " ++ pretty op
        return (PWildcard (rng, down))

  PList rng ps -> do
    eltty <- genUniVar rng (KType ())
    let listty = TList (rng, KType ()) eltty
    emit $ CEq down listty rng
    PList (rng, listty) <$> mapM (tcPattern eltty) ps

  PTup rng ps -> do
    ts <- mapM (\p -> genUniVar (extOf p) (KType ())) ps
    let typ = TTup (rng, KType ()) ts
    emit $ CEq down typ rng
    PTup (rng, typ) <$> zipWithM tcPattern ts ps

-- | Check a right-hand side against its expected type. Guards are not
-- supported and abort with 'error'.
tcRHS :: CType -> PRHS -> TCM CRHS
tcRHS _ (Guarded _ _) = error "typecheck: Guards not yet supported"
tcRHS down (Plain _ e) = do
  e' <- tcExpr down e
  return $ Plain (snd (extOf e')) e'

-- | Check an expression against its expected type, emitting equality
-- constraints and returning the expression annotated with its type.
tcExpr :: CType -> PExpr -> TCM CExpr
tcExpr down = \case
  ELit rng lit -> do
    let ty = case lit of
               LInt{} -> TCon (rng, KType ()) (Name "Int")
               LFloat{} -> TCon (rng, KType ()) (Name "Double")
               LChar{} -> TCon (rng, KType ()) (Name "Char")
               LString{} -> TList (rng, KType ()) (TCon (rng, KType ()) (Name "Char"))
    emit $ CEq down ty rng
    return $ ELit (rng, ty) lit

  EVar rng n -> do
    ty <- getType' rng n
    ty' <- instantiateTForallsUni rng ty
    emit $ CEq down ty' rng
    return $ EVar (rng, ty') n

  ECon rng n -> do
    ty <- getType' rng n
    ty' <- instantiateTForallsUni rng ty
    emit $ CEq down ty' rng
    -- keep the node a constructor (this used to be rebuilt as an EVar)
    return $ ECon (rng, ty') n

  EList rng es -> do
    eltty <- genUniVar rng (KType ())
    let listty = TList (rng, KType ()) eltty
    emit $ CEq down listty rng
    -- the elements have the element type, not the type of the whole list
    EList (rng, listty) <$> mapM (tcExpr eltty) es

  ETup rng es -> do
    ts <- mapM (\_ -> genUniVar rng (KType ())) es
    let typ = TTup (rng, KType ()) ts
    emit $ CEq down typ rng
    ETup (rng, typ) <$> zipWithM tcExpr ts es

  EApp rng e1 es -> do
    argtys <- mapM (\e -> genUniVar (extOf e) (KType ())) es
    let funty = foldr (TFun (rng, KType ())) down argtys
    -- The application node is annotated with its own (result) type 'down',
    -- like every other expression; 'funty' is the type of the head 'e1'.
    EApp (rng, down) <$> tcExpr funty e1
                     <*> zipWithM tcExpr argtys es

  -- TODO: these types are way too monomorphic and in any case these
  -- ~operators~ functions should not be built-in
  EOp rng e1 op e2 -> do
    let int = \r -> TCon (r, KType ()) (Name "Int")
        bool = \r -> TCon (r, KType ()) (Name "Bool")
        rng1 = extOf e1
        rng2 = extOf e2
    (rty, aty1, aty2) <- case op of
      OAdd -> return (int rng, int rng1, int rng2)
      OSub -> return (int rng, int rng1, int rng2)
      OMul -> return (int rng, int rng1, int rng2)
      ODiv -> return (int rng, int rng1, int rng2)
      OMod -> return (int rng, int rng1, int rng2)
      OEqu -> return (bool rng, int rng1, int rng2)
      OPow -> return (int rng, int rng1, int rng2)
      OCons -> do
        eltty <- genUniVar rng1 (KType ())
        let listty = TList (rng2, KType ()) eltty
        return (listty, eltty, listty)
    emit $ CEq down rty rng
    e1' <- tcExpr aty1 e1
    e2' <- tcExpr aty2 e2
    return (EOp (rng, rty) e1' op e2')

  EIf rng e1 e2 e3 -> do
    e1' <- tcExpr (TCon (rng, KType ()) (Name "Bool")) e1
    e2' <- tcExpr down e2
    e3' <- tcExpr down e3
    return (EIf (rng, down) e1' e2' e3')

  ECase rng e1 alts -> do
    ty <- genUniVar rng (KType ())
    e1' <- tcExpr ty e1
    -- each alternative gets its own scope for the variables its pattern binds
    alts' <- forM alts $ \(pat, rhs) ->
      scopeVEnv $
        (,) <$> tcPattern ty pat <*> tcRHS down rhs
    return $ ECase (rng, down) e1' alts'

  ELet rng defs body -> do
    -- TODO: need to properly scope constraints here
    defs' <- tcFunDefBlock defs
    let bound2 = map (\(FunDef _ n (TypeSig _ ty) _) -> (n, ty)) defs'
    scopeVEnv $ do
      modifyVEnv (Map.fromList bound2 <>)
      body' <- tcExpr down body
      return $ ELet (rng, down) defs' body'

  EError rng -> return $ EError (rng, down)

-- | The 'Range' is the place where the type is being instantiated.
instantiateTForallsUni :: Range -> CType -> TCM CType
instantiateTForallsUni instRng = go mempty
  where
    -- peel off the leading foralls, mapping each bound variable to a fresh
    -- unification variable, then substitute them all in the remaining body
    go sub (TExt _ (TForallC (n, k1) t)) = do
      var <- genUniVar instRng k1
      go (Map.insert n var sub) t
    go sub t = return $ substType mempty mempty sub t

-- | Split a type into @n@ argument types and a result type. Where the type is
-- not syntactically a function type, fresh unification variables are invented
-- for the remaining arguments and the result, and the type is constrained to
-- be equal to the function type built from them.
unfoldFunTy :: Range -> Int -> CType -> TCM ([CType], CType)
unfoldFunTy _ n t | n <= 0 = return ([], t)
unfoldFunTy rng n (TFun _ t1 t2) = do
  (params, core) <- unfoldFunTy rng (n - 1) t2
  return (t1 : params, core)
unfoldFunTy rng n t = do
  vars <- replicateM n (genUniVar rng (KType ()))
  core <- genUniVar rng (KType ())
  let expected = foldr (TFun (rng, KType ())) core vars
  emit $ CEq expected t rng
  return (vars, core)

-- | Solve all constraints: kinds first, then types. The kind solution is
-- applied to the types in the returned type solution.
solveConstrs :: MonadRaise m => Bag Constr -> m (Map Int CKind, Map Int CType)
solveConstrs constrs = do
  let (tcs, kcs) = partitionConstrs constrs
  subK <- solveKindVars kcs
  subT <- solveTypeVars tcs
  let subT' = Map.map (substType subK mempty mempty) subT
  return (subK, subT')

-- | Solve kind equality constraints, reporting mismatches and recursive
-- kinds. Returns an assignment for kind unification variables.
solveKindVars :: MonadRaise m => Bag (CKind, CKind, Range) -> m (Map Int CKind)
solveKindVars cs = do
  let (asg, errs) =
        solveConstraints
          reduce
          (foldMap pure . kindUniVars)
          substKind
          (\case KExt () (KUniVar v) -> Just v
                 _ -> Nothing)
          kindSize
          (toList cs)

  forM_ errs $ \case
    UEUnequal k1 k2 rng ->
      raise SError rng $
        "Kind mismatch:\n- " ++ pretty k1 ++ "\n- " ++ pretty k2
    UERecursive uvar k rng ->
      raise SError rng $
        "Kind cannot be recursive: " ++ pretty (KExt () (KUniVar uvar)) ++ " = " ++ pretty k

  -- default unconstrained kind variables to Type
  let unconstrKUVars = foldMap kindUniVars (Map.elems asg) Set.\\ Map.keysSet asg
      defaults = Map.fromList (map (,KType ()) (toList unconstrKUVars))
      asg' = Map.map (substKind defaults) asg <> defaults

  return asg'
  where
    reduce :: CKind -> CKind -> Range -> (Bag (Int, CKind, Range), Bag (CKind, CKind, Range))
    reduce lhs rhs rng = case (lhs, rhs) of
      -- unification variables produce constraints on a unification variable
      (KExt () (KUniVar i), KExt () (KUniVar j)) | i == j -> mempty
      (KExt () (KUniVar i), k                  ) -> (pure (i, k, rng), mempty)
      (k                  , KExt () (KUniVar i)) -> (pure (i, k, rng), mempty)

      -- if lhs and rhs have equal prefixes, recurse
      (KType ()   , KType ()   ) -> mempty
      (KFun () a b, KFun () c d) -> reduce a c rng <> reduce b d rng

      -- otherwise, this is a kind mismatch
      (k1, k2) -> (mempty, pure (k1, k2, rng))

    kindSize :: CKind -> Int
    kindSize KType{} = 1
    kindSize (KFun () a b) = 1 + kindSize a + kindSize b
    kindSize (KExt () KUniVar{}) = 2

-- | Solve type equality constraints, reporting mismatches and recursive
-- types. Returns an assignment for type unification variables.
solveTypeVars :: MonadRaise m => Bag (CType, CType, Range) -> m (Map Int CType)
solveTypeVars cs = do
  let (asg, errs) =
        solveConstraints
          reduce
          (foldMap pure . typeUniVars)
          (\m -> substType mempty m mempty)
          (\case TExt _ (TUniVar v) -> Just v
                 _ -> Nothing)
          typeSize
          (toList cs)

  forM_ errs $ \case
    UEUnequal t1 t2 rng ->
      raise SError rng $
        "Type mismatch:\n- " ++ pretty t1 ++ "\n- " ++ pretty t2
    UERecursive uvar t rng ->
      raise SError rng $
        "Type cannot be recursive: " ++ pretty (TExt (extOf t) (TUniVar uvar)) ++ " = " ++ pretty t

  return asg
  where
    reduce :: CType -> CType -> Range -> (Bag (Int, CType, Range), Bag (CType, CType, Range))
    reduce lhs rhs rng = case (lhs, rhs) of
      -- unification variables produce constraints on a unification variable
      (TExt _ (TUniVar i), TExt _ (TUniVar j)) | i == j -> mempty
      (TExt _ (TUniVar i), t                 ) -> (pure (i, t, rng), mempty)
      (t                 , TExt _ (TUniVar i)) -> (pure (i, t, rng), mempty)

      -- if lhs and rhs have equal prefixes, recurse
      -- NOTE(review): 'zip' drops surplus arguments when the two applications
      -- have different numbers of arguments -- confirm whether that is intended.
      (TApp _ t ts, TApp _ t' ts') -> reduce t t' rng <> foldMap (\(a, b) -> reduce a b rng) (zip ts ts')
      -- Tuples only unify when they have the same number of components;
      -- otherwise 'zip' would truncate and hide the mismatch. A failing guard
      -- falls through to the mismatch case at the bottom.
      (TTup _ ts, TTup _ ts')
        | length ts == length ts' -> foldMap (\(a, b) -> reduce a b rng) (zip ts ts')
      (TList _ t, TList _ t') -> reduce t t' rng
      (TFun _ t1 t2, TFun _ t1' t2') -> reduce t1 t1' rng <> reduce t2 t2' rng
      (TCon _ n1, TCon _ n2) | n1 == n2 -> mempty
      (TVar _ n1, TVar _ n2) | n1 == n2 -> mempty
      -- TODO: this doesn't check that the types are kind-correct. Did we already check that?
      (TExt (rng1, _) (TForallC (n1, k) t1), TExt _ (TForallC (n2, _) t2)) ->
        reduce t1 (substType mempty mempty (Map.singleton n2 (TVar (rng1, k) n1)) t2) rng

      -- otherwise, this is a kind mismatch
      (k1, k2) -> (mempty, pure (k1, k2, rng))

    typeSize :: CType -> Int
    typeSize (TApp _ t ts) = typeSize t + sum (map typeSize ts)
    typeSize (TTup _ ts) = sum (map typeSize ts)
    typeSize (TList _ t) = 1 + typeSize t
    typeSize (TFun _ t1 t2) = typeSize t1 + typeSize t2
    typeSize (TCon _ _) = 1
    typeSize (TVar _ _) = 1
    typeSize (TExt _ TUniVar{}) = 2
    typeSize (TExt _ (TForallC _ t)) = 1 + typeSize t

-- | Split constraints into type equalities and kind equalities.
partitionConstrs :: Foldable t => t Constr -> (Bag (CType, CType, Range), Bag (CKind, CKind, Range))
partitionConstrs = foldMap $ \case CEq t1 t2 r -> (pure (t1, t2, r), mempty)
                                   CEqK k1 k2 r -> (mempty, pure (k1, k2, r))


--------------------
-- SUBSTITUTION FUNCTIONS
--------------------
-- These take some of:
-- - an instantiation map for kind unification variables (Map Int CKind)
-- - an instantiation map for type unification variables (Map Int CType)
-- - an instantiation map for type variables (Map Name CType)

substProg :: Map Int CKind -> Map Int CType -> CProgram -> CProgram
substProg mk mt (Program ds fs) = Program (map (substDdef mk mt) ds) (map (substFdef mk mt) fs)

substDdef :: Map Int CKind -> Map Int CType -> CDataDef -> CDataDef
substDdef mk mt (DataDef (defrng, k) name pars cons) =
  DataDef (defrng, substKind mk k) name
          (map (first (second (substKind mk))) pars)
          (map (second (map (substDataField mk mt))) cons)

substDataField :: Map Int CKind -> Map Int CType -> CDataField -> CDataField
substDataField mk mt (DataField rng t) = DataField rng (substType mk mt mempty t)

substFdef :: Map Int CKind -> Map Int CType -> CFunDef -> CFunDef
substFdef mk mt (FunDef rng n (TypeSig sigrng sig) eqs) =
  FunDef rng n
         (TypeSig sigrng (substType mk mt mempty sig))
         (fmap (substFunEq mk mt) eqs)

substFunEq :: Map Int CKind -> Map Int CType -> CFunEq -> CFunEq
substFunEq mk mt (FunEq () n ps rhs) =
  FunEq () n
        (map (substPattern mk mt) ps)
        (substRHS mk mt rhs)

substRHS :: Map Int CKind -> Map Int CType -> CRHS -> CRHS
substRHS _ _ (Guarded _ _) = error "typecheck: guards unsupported"
substRHS mk mt (Plain t e) = Plain (substType mk mt mempty t) (substExpr mk mt e)

substPattern :: Map Int CKind -> Map Int CType -> CPattern -> CPattern
substPattern mk mt = go
  where
    go (PWildcard e) = PWildcard (goExt e)
    go (PVar e n) = PVar (goExt e) n
    go (PAs e n p) = PAs (goExt e) n (go p)
    go (PCon e n ps) = PCon (goExt e) n (map go ps)
    go (POp e p1 op p2) = POp (goExt e) (go p1) op (go p2)
    go (PList e ps) = PList (goExt e) (map go ps)
    go (PTup e ps) = PTup (goExt e) (map go ps)

    goExt (rng, t) = (rng, substType mk mt mempty t)

substExpr :: Map Int CKind -> Map Int CType -> CExpr -> CExpr
substExpr mk mt = go
  where
    go (ELit e lit) = ELit (goExt e) lit
    go (EVar e n) = EVar (goExt e) n
    go (ECon e n) = ECon (goExt e) n
    go (EList e es) = EList (goExt e) (map go es)
    go (ETup e es) = ETup (goExt e) (map go es)
    go (EApp e e1 es) = EApp (goExt e) (go e1) (map go es)
    go (EOp e e1 op e2) = EOp (goExt e) (go e1) op (go e2)
    go (EIf e e1 e2 e3) = EIf (goExt e) (go e1) (go e2) (go e3)
    go (ECase e e1 alts) = ECase (goExt e) (go e1) (map (bimap (substPattern mk mt) (substRHS mk mt)) alts)
    go (ELet e defs body) = ELet (goExt e) (map (substFdef mk mt) defs) (go body)
    go (EError e) = EError (goExt e)

    goExt (rng, t) = (rng, substType mk mt mempty t)

-- | Apply the kind-univar, type-univar and type-variable substitutions to a
-- type. A type variable that is rebound by a forall is not substituted inside
-- that forall's body.
substType :: Map Int CKind -> Map Int CType -> Map Name CType -> CType -> CType
substType mk mt mtv = go
  where
    go (TApp (rng, k) t ts) = TApp (rng, goKind k) (go t) (map go ts)
    go (TTup (rng, k) ts) = TTup (rng, goKind k) (map go ts)
    go (TList (rng, k) t) = TList (rng, goKind k) (go t)
    go (TFun (rng, k) t1 t2) = TFun (rng, goKind k) (go t1) (go t2)
    go (TCon (rng, k) n) = TCon (rng, goKind k) n
    go (TVar (rng, k) n) = fromMaybe (TVar (rng, goKind k) n) (Map.lookup n mtv)
    go (TExt (rng, k) (TUniVar v)) = fromMaybe (TExt (rng, goKind k) (TUniVar v)) (Map.lookup v mt)
    go (TExt (rng, k) (TForallC (n, k1) t)) =
      -- the binder shadows any outer substitution for the same name
      TExt (rng, goKind k) (TForallC (n, goKind k1) (substType mk mt (Map.delete n mtv) t))

    goKind = substKind mk

substKind :: Map Int CKind -> CKind -> CKind
substKind m = \case
  KType () -> KType ()
  KFun () k1 k2 -> KFun () (substKind m k1) (substKind m k2)
  k@(KExt () (KUniVar v)) -> fromMaybe k (Map.lookup v m)


--------------------
-- FINALISATION FUNCTIONS
--------------------
-- These report free type unification variables.

finaliseProg :: MonadRaise m => CProgram -> m TProgram
finaliseProg (Program ds fs) = Program <$> traverse finaliseDdef ds <*> traverse finaliseFdef fs

finaliseDdef :: MonadRaise m => CDataDef -> m TDataDef
finaliseDdef (DataDef (rng, k) name pars cons) =
  DataDef <$> finaliseKind rng k
          <*> pure name
          <*> traverse (firstM (finaliseKind rng . snd)) pars
          <*> traverse (secondM (traverse finaliseDataField)) cons

finaliseDataField :: MonadRaise m => CDataField -> m TDataField
finaliseDataField (DataField rng t) = DataField rng <$> finaliseType t

finaliseFdef :: MonadRaise m => CFunDef -> m TFunDef
finaliseFdef (FunDef _ n (TypeSig _ sig) eqs) =
  FunDef () n
    <$> (TypeSig () <$> finaliseType sig)
    <*> traverse finaliseFunEq eqs

finaliseFunEq :: MonadRaise m => CFunEq -> m TFunEq
finaliseFunEq (FunEq () n ps rhs) =
  FunEq () n
    <$> traverse finalisePattern ps
    <*> finaliseRHS rhs

finaliseRHS :: MonadRaise m => CRHS -> m TRHS
finaliseRHS (Guarded _ _) = error "typecheck: guards unsupported"
finaliseRHS (Plain t e) = Plain <$> finaliseType t <*> finaliseExpr e

finalisePattern :: MonadRaise m => CPattern -> m TPattern
finalisePattern = \case
  PWildcard e -> PWildcard <$> goExt e
  PVar e n -> PVar <$> goExt e <*> pure n
  PAs e n p -> PAs <$> goExt e <*> pure n <*> finalisePattern p
  PCon e n ps -> PCon <$> goExt e <*> pure n <*> traverse finalisePattern ps
  POp e p1 op p2 -> POp <$> goExt e <*> finalisePattern p1 <*> pure op <*> finalisePattern p2
  PList e ps -> PList <$> goExt e <*> traverse finalisePattern ps
  PTup e ps -> PTup <$> goExt e <*> traverse finalisePattern ps
  where
    goExt (_, t) = finaliseType t

finaliseExpr :: MonadRaise m => CExpr -> m TExpr
finaliseExpr = \case
  ELit e lit -> ELit <$> goExt e <*> pure lit
  EVar e n -> EVar <$> goExt e <*> pure n
  ECon e n -> ECon <$> goExt e <*> pure n
  EList e es -> EList <$> goExt e <*> traverse finaliseExpr es
  ETup e es -> ETup <$> goExt e <*> traverse finaliseExpr es
  EApp e e1 es -> EApp <$> goExt e <*> finaliseExpr e1 <*> traverse finaliseExpr es
  EOp e e1 op e2 -> EOp <$> goExt e <*> finaliseExpr e1 <*> pure op <*> finaliseExpr e2
  EIf e e1 e2 e3 -> EIf <$> goExt e <*> finaliseExpr e1 <*> finaliseExpr e2 <*> finaliseExpr e3
  ECase e e1 alts -> ECase <$> goExt e <*> finaliseExpr e1 <*> traverse (bimapM finalisePattern finaliseRHS) alts
  ELet e defs body -> ELet <$> goExt e <*> traverse finaliseFdef defs <*> finaliseExpr body
  EError e -> EError <$> goExt e
  where
    goExt (_, t) = finaliseType t

-- | Convert a checking-stage type to a typed-stage type. A leftover type
-- unification variable is reported as ambiguous and replaced by a placeholder
-- type variable.
finaliseType :: MonadRaise m => CType -> m TType
finaliseType toptype = go toptype
  where
    go :: MonadRaise m => CType -> m TType
    go (TApp e t ts) = TApp <$> goExt e <*> go t <*> traverse go ts
    go (TTup e ts) = TTup <$> goExt e <*> traverse go ts
    go (TList e t) = TList <$> goExt e <*> go t
    go (TFun e t1 t2) = TFun <$> goExt e <*> go t1 <*> go t2
    go (TCon e n) = TCon <$> goExt e <*> pure n
    go (TVar e n) = TVar <$> goExt e <*> pure n
    go t@(TExt (rng, k) TUniVar{}) = do
      raise SError rng $ "Ambiguous type unification variable " ++ pretty t ++ " in type: " ++ pretty toptype
      TVar <$> finaliseKind rng k <*> pure (Name "$_error")
    go (TExt e@(rng, _) (TForallC (n, k1) t)) =
      TExt <$> goExt e <*> (TForall <$> ((n,) <$> finaliseKind rng k1) <*> go t)

    goExt (rng, t) = finaliseKind rng t

-- | Convert a checking-stage kind to a typed-stage kind. A leftover kind
-- unification variable is reported as ambiguous (at the given range) and
-- replaced by Type.
finaliseKind :: MonadRaise m => Range -> CKind -> m TKind
finaliseKind toprng topkind = go topkind
  where
    go :: MonadRaise m => CKind -> m TKind
    go (KType ()) = pure (KType ())
    go (KFun () k1 k2) = KFun () <$> go k1 <*> go k2
    go k@(KExt () KUniVar{}) = do
      raise SError toprng $ "Ambiguous kind unification variable " ++ pretty k ++ " in kind: " ++ pretty topkind
      pure $ KType ()

--------------------
-- END OF FINALISATION FUNCTIONS
--------------------

-- | All type unification variables that occur in a type.
typeUniVars :: CType -> Set Int
typeUniVars = \case
  TApp _ t ts -> typeUniVars t <> foldMap typeUniVars ts
  TTup _ ts -> foldMap typeUniVars ts
  TList _ t -> typeUniVars t
  TFun _ t1 t2 -> typeUniVars t1 <> typeUniVars t2
  TCon _ _ -> mempty
  TVar _ _ -> mempty
  TExt _ (TUniVar v) -> Set.singleton v
  TExt _ (TForallC _ t) -> typeUniVars t

-- | All kind unification variables that occur in a kind.
kindUniVars :: CKind -> Set Int
kindUniVars = \case
  KType{} -> mempty
  KFun () a b -> kindUniVars a <> kindUniVars b
  KExt () (KUniVar v) -> Set.singleton v

-- | Whether all elements are equal; True for an empty container.
allEq :: (Eq a, Foldable t) => t a -> Bool
allEq l = case toList l of
  [] -> True
  x : xs -> all (== x) xs

funeqPats :: FunEq s -> [Pattern s]
funeqPats (FunEq _ _ pats _) = pats

dataFieldType :: DataField s -> Type s
dataFieldType (DataField _ t) = t

-- | Select the kind equality constraints.
isCEqK :: Constr -> Maybe (CKind, CKind, Range)
isCEqK (CEqK k1 k2 rng) = Just (k1, k2, rng)
isCEqK _ = Nothing

firstM :: Functor f => (a -> f b) -> (a, c) -> f (b, c)
firstM f (x, y) = (,y) <$> f x

secondM :: Functor f => (b -> f c) -> (a, b) -> f (a, c)
secondM f (x, y) = (x,) <$> f y

bimapM :: Applicative f => (a -> f b) -> (c -> f d) -> (a, c) -> f (b, d)
bimapM f g (x, y) = (,) <$> f x <*> g y