{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE EmptyDataDeriving #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE GADTs #-}
module HSVIS.Typecheck (
  StageTyped,
  typecheck,
  -- * Typed AST synonyms
  TProgram, TDataDef, TFunDef, TFunEq, TKind, TType, TPattern, TRHS, TExpr,
) where

import Control.Monad
import Data.Bifunctor (first, second)
import Data.Foldable (toList)
import Data.List (find)
import Data.Map.Strict (Map)
import Data.Maybe (fromMaybe)
import Data.Monoid (Ap(..))
import qualified Data.Map.Strict as Map
import Data.Set (Set)
import qualified Data.Set as Set

import Debug.Trace

import Data.Bag
import Data.List.NonEmpty.Util
import HSVIS.AST
import HSVIS.Parser
import HSVIS.Diagnostic
import HSVIS.Pretty
import HSVIS.Typecheck.Solve


-- | AST stage used while type checking is in progress: annotations are
-- kinds/types that may still contain unification variables.
data StageTC

type instance X DataDef StageTC = ()
type instance X FunDef  StageTC = CType
type instance X FunEq   StageTC = CType
type instance X Kind    StageTC = ()
type instance X Type    StageTC = CKind
type instance X Pattern StageTC = CType
type instance X RHS     StageTC = CType
type instance X Expr    StageTC = CType

-- | Type unification variable, identified by an id from 'genId'.
data instance E Type StageTC = TUniVar Int
  deriving (Show)

-- | Kind unification variable, identified by an id from 'genId'.
data instance E Kind StageTC = KUniVar Int
  deriving (Show, Eq, Ord)

data instance E TypeSig StageTC
  deriving (Show)

type CProgram = Program StageTC
type CDataDef = DataDef StageTC
type CFunDef  = FunDef  StageTC
type CFunEq   = FunEq   StageTC
type CKind    = Kind    StageTC
type CType    = Type    StageTC
type CPattern = Pattern StageTC
type CRHS     = RHS     StageTC
type CExpr    = Expr    StageTC

-- | AST stage of a fully checked program: the extension types below have no
-- constructors, so no unification variables can occur.
data StageTyped

type instance X DataDef StageTyped = TType
type instance X FunDef  StageTyped = TType
type instance X FunEq   StageTyped = TType
type instance X Kind    StageTyped = ()
type instance X Type    StageTyped = TKind
type instance X Pattern StageTyped = TType
type instance X RHS     StageTyped = TType
type instance X Expr    StageTyped = TType

data instance E Type StageTyped
  deriving (Show)

data instance E Kind StageTyped
  deriving (Show)

data instance E TypeSig StageTyped
  deriving (Show)

type TProgram = Program StageTyped
type TDataDef = DataDef StageTyped
type TFunDef  = FunDef  StageTyped
type TFunEq   = FunEq   StageTyped
type TKind    = Kind    StageTyped
type TType    = Type    StageTyped
type TPattern = Pattern StageTyped
type TRHS     = RHS     StageTyped
type TExpr    = Expr    StageTyped

-- | Type unification variables are rendered as @?t@ followed by their id.
instance Pretty (E Type StageTC) where
  prettysPrec _ (TUniVar n) = showString ("?t" ++ show n)

-- | Kind unification variables are rendered as @?k@ followed by their id.
instance Pretty (E Kind StageTC) where
  prettysPrec _ (KUniVar n) = showString ("?k" ++ show n)


-- | Type check a parsed program.
--
-- The file path and the source text are only used for building diagnostics.
-- Constraint generation ('tcProgram') runs with the id counter starting at 1
-- and an empty environment; the collected constraints are then handed to
-- 'solveConstrs' and the resulting substitutions are applied by 'doneProg'.
--
-- NOTE(review): 'solveConstrs' and 'doneProg' are currently @error@ stubs in
-- this module, so forcing the result fails until they are implemented.
typecheck :: FilePath -> String -> PProgram -> ([Diagnostic], TProgram)
typecheck fp source prog =
  let (ds1, cs, _, _, progtc) =
        runTCM (tcProgram prog) (fp, source) 1 (Env mempty mempty)
      (ds2, subK, subT) = solveConstrs cs
  in (toList (ds1 <> ds2), doneProg subK subT progtc)

data Constr
  -- Equality constraints: "left" must be equal to "right" because of the thing
  -- at the given range. "left" is the expected thing; "right" is the observed
  -- thing.
  = CEq CType CType Range
  | CEqK CKind CKind Range
  deriving (Show)

-- | The type environment (type names to kinds) and the value environment
-- (value names to types), in that order.
data Env = Env (Map Name CKind) (Map Name CType)
  deriving (Show)

-- | The type checker monad: a hand-rolled combination of a reader, two
-- writers and two pieces of state.
newtype TCM a = TCM {
  runTCM :: (FilePath, String)  -- ^ reader context: file and file contents
         -> Int  -- ^ state: next id to generate
         -> Env  -- ^ state: type and value environment
         -> (Bag Diagnostic  -- ^ writer: diagnostics
            ,Bag Constr  -- ^ writer: constraints
            ,Int, Env, a)
  }

instance Functor TCM where
  fmap f (TCM g) = TCM $ \ctx i env ->
    let (ds, cs, i', env', x) = g ctx i env
    in (ds, cs, i', env', f x)

instance Applicative TCM where
  pure x = TCM $ \_ i env -> (mempty, mempty, i, env, x)
  (<*>) = ap

instance Monad TCM where
  -- Threads the id counter and environment from the first action into the
  -- second, and appends the outputs of both writers in order.
  TCM f >>= g = TCM $ \ctx i1 env1 ->
    let (ds2, cs2, i2, env2, x) = f ctx i1 env1
        (ds3, cs3, i3, env3, y) = runTCM (g x) ctx i2 env2
    in (ds2 <> ds3, cs2 <> cs3, i3, env3, y)

-- | Record a diagnostic with the given severity and message at the given
-- range.
--
-- The source line attached to the diagnostic is looked up by using the line
-- component of the range's start position as a 0-based index into the lines
-- of the file contents. If there is no such line (for example for an empty
-- file, or a range starting past the last line), an empty line is attached
-- instead of crashing.
raise :: Severity -> Range -> String -> TCM ()
raise sev rng@(Range (Pos y _) _) msg = TCM $ \(fp, source) i env ->
  let -- Total replacement for @lines source !! y@.
      srcLine = case drop y (lines source) of
        ln : _ -> ln
        [] -> ""
  in (pure (Diagnostic sev fp rng [] srcLine msg), mempty, i, env, ())

-- | Add a constraint to the constraint writer.
emit :: Constr -> TCM ()
emit c = TCM $ \_ i env -> (mempty, pure c, i, env, ())

-- | Run an action and split the constraints it emitted using the given
-- selector. The first component of 'bagPartition''s result is returned
-- alongside the action's result; only the second component stays in the
-- writer output. Presumably the first component holds the constraints for
-- which the selector returned 'Just' -- TODO confirm against "Data.Bag".
collectConstraints :: (Constr -> Maybe b) -> TCM a -> TCM (Bag b, a)
collectConstraints predicate (TCM f) = TCM $ \ctx i env ->
  let (ds, cs, i', env', x) = f ctx i env
      (yes, no) = bagPartition predicate cs
  in (ds, no, i', env', (yes, x))

-- | Read the current environment.
getFullEnv :: TCM Env
getFullEnv = TCM $ \_ i env -> (mempty, mempty, i, env, env)

-- | Replace the current environment.
putFullEnv :: Env -> TCM ()
putFullEnv env = TCM $ \_ i _ -> (mempty, mempty, i, env, ())

-- | Return the next unused id and advance the counter.
genId :: TCM Int
genId = TCM $ \_ i env -> (mempty, mempty, i + 1, env, i)

-- | Look up the kind of a type name in the type environment.
getKind :: Name -> TCM (Maybe CKind)
getKind name = do
  Env tenv _ <- getFullEnv
  return (Map.lookup name tenv)

-- | Look up the type of a value name in the value environment.
getType :: Name -> TCM (Maybe CType)
getType name = do
  Env _ venv <- getFullEnv
  return (Map.lookup name venv)

-- | Apply a function to the type environment, leaving the value environment
-- untouched.
modifyTEnv :: (Map Name CKind -> Map Name CKind) -> TCM ()
modifyTEnv f = do
  Env tenv venv <- getFullEnv
  putFullEnv (Env (f tenv) venv)

-- | Apply a function to the value environment, leaving the type environment
-- untouched.
modifyVEnv :: (Map Name CType -> Map Name CType) -> TCM ()
modifyVEnv f = do
  Env tenv venv <- getFullEnv
  putFullEnv (Env tenv (f venv))

-- | Run an action and afterwards restore the type environment to what it was
-- before. Changes to the value environment are kept.
scopeTEnv :: TCM a -> TCM a
scopeTEnv m = do
  Env origtenv _ <- getFullEnv
  res <- m
  modifyTEnv (\_ -> origtenv)
  return res

-- | Run an action and afterwards restore the value environment to what it was
-- before. Changes to the type environment are kept.
scopeVEnv :: TCM a -> TCM a
scopeVEnv m = do
  Env _ origvenv <- getFullEnv
  res <- m
  modifyVEnv (\_ -> origvenv)
  return res

-- | Generate a fresh kind unification variable.
genKUniVar :: TCM CKind
genKUniVar = KExt () . KUniVar <$> genId

-- | Generate a fresh type unification variable of the given kind.
genUniVar :: CKind -> TCM CType
genUniVar k = TExt k . TUniVar <$> genId

-- | Like 'getKind', but if the name is not in scope, raise an error at the
-- given range and continue with a fresh kind unification variable.
getKind' :: Range -> Name -> TCM CKind
getKind' rng name = getKind name >>= \case
  Nothing -> do
    raise SError rng $ "Type not in scope: " ++ pretty name
    genKUniVar
  Just k -> return k

-- | Like 'getType', but if the name is not in scope, raise an error at the
-- given range and continue with a fresh type unification variable of kind
-- Type.
getType' :: Range -> Name -> TCM CType
getType' rng name = getType name >>= \case
  Nothing -> do
    raise SError rng $ "Variable not in scope: " ++ pretty name
    genUniVar (KType ())
  Just t -> return t

-- | Generate constraints for a whole program.
--
-- Data definitions are handled first: all data types are registered in the
-- type environment ('prepareDataDef'), their definitions are kind-checked,
-- and the kind constraints this produces are solved on the spot. After that
-- the function definitions are checked.
tcProgram :: PProgram -> TCM CProgram
tcProgram (Program ddefs1 fdefs1) = do
  (kconstrs, ddefs2) <- collectConstraints isCEqK $ do
    mapM_ prepareDataDef ddefs1
    mapM tcDataDef ddefs1

  kinduvars <- solveKindVars kconstrs

  let ddefs3 = map (substDdef kinduvars mempty) ddefs2

  -- The type environment still maps the data types to the kinds containing
  -- the unification variables generated by prepareDataDef. Apply the solution
  -- there as well, so that the function definitions below are checked against
  -- the same kinds that were just substituted into the data definitions.
  modifyTEnv (Map.map (substKind kinduvars))

  -- NOTE(review): debugging output left in; prints every data definition.
  traceM (unlines (map pretty ddefs3))

  fdefs2 <- mapM tcFunDef fdefs1

  return (Program ddefs3 fdefs2)

-- | Register a data type in the type environment with a kind of the right
-- arity: one fresh kind unification variable per type parameter, returning
-- Type.
prepareDataDef :: PDataDef -> TCM ()
prepareDataDef (DataDef _ name params _) = do
  parkinds <- mapM (\_ -> genKUniVar) params
  let k = foldr (KFun ()) (KType ()) parkinds
  modifyTEnv (Map.insert name k)

-- Assumes that the kind of the name itself has already been registered with
-- the correct arity (this is done by prepareDataDef).
--
-- Kind-checks the field types of every constructor of a data definition, with
-- the type parameters brought into scope at the parameter kinds taken from
-- the registered kind of the data type. Every field must have kind Type.
tcDataDef :: PDataDef -> TCM CDataDef
tcDataDef (DataDef rng name params cons) = do
  kd <- getKind' rng name
  let (pkinds, kret) = splitKind kd

  -- sanity checking; would be nicer to store these in prepareDataDef already
  when (length pkinds /= length params) $
    error "tcDataDef: Invalid param kind list length"
  case kret of
    Right () -> return ()
    _ -> error "tcDataDef: Invalid ret kind"

  -- The type parameters are only in scope within this definition.
  cons' <- scopeTEnv $ do
    modifyTEnv (Map.fromList (zip (map snd params) pkinds) <>)
    mapM (\(cname, fieldtys) -> (cname,) <$> mapM (kcType (Just (KType ()))) fieldtys)
         cons

  return (DataDef () name (zip pkinds (map snd params)) cons')

-- | Turn an optional expected kind into a kind: the expected kind if there is
-- one, otherwise a fresh kind unification variable.
promoteDown :: Maybe CKind -> TCM CKind
promoteDown Nothing = genKUniVar
promoteDown (Just k) = return k

-- | If there is an expected kind, emit the constraint that it equals the
-- given (observed) kind, blaming the given range. Does nothing otherwise.
downEqK :: Range -> Maybe CKind -> CKind -> TCM ()
downEqK _ Nothing _ = return ()
downEqK rng (Just k1) k2 = emit $ CEqK k1 k2 rng

-- | Given (maybe) the expected kind of this type, and a type, check it for
-- kind-correctness.
--
-- The result is the type annotated with kinds; all requirements are emitted
-- as 'CEqK' constraints rather than checked immediately.
--
-- NOTE(review): in the tuple, list and function cases the components are
-- checked with expected kind Type and are then constrained to Type a second
-- time by the explicit 'emit'. This looks redundant and may lead to duplicate
-- diagnostics; confirm against the solver before removing.
kcType :: Maybe CKind -> PType -> TCM CType
kcType mdown = \case
  TApp rng t ts -> do
    t' <- kcType Nothing t
    ts' <- mapM (kcType Nothing) ts
    retk <- promoteDown mdown
    -- The head must be a function from the argument kinds to the result kind.
    let expected = foldr (KFun ()) retk (map extOf ts')
    emit $ CEqK (extOf t') expected rng
    return (TApp retk t' ts')

  TTup rng ts -> do
    ts' <- mapM (kcType (Just (KType ()))) ts
    forM_ (zip (map extOf ts) ts') $ \(trng, ct) ->
      emit $ CEqK (extOf ct) (KType ()) trng
    downEqK rng mdown (KType ())
    return (TTup (KType ()) ts')

  TList rng t -> do
    t' <- kcType (Just (KType ())) t
    emit $ CEqK (extOf t') (KType ()) (extOf t)
    downEqK rng mdown (KType ())
    return (TList (KType ()) t')

  TFun rng t1 t2 -> do
    t1' <- kcType (Just (KType ())) t1
    t2' <- kcType (Just (KType ())) t2
    emit $ CEqK (extOf t1') (KType ()) (extOf t1)
    emit $ CEqK (extOf t2') (KType ()) (extOf t2)
    downEqK rng mdown (KType ())
    return (TFun (KType ()) t1' t2')

  TCon rng n -> do
    k <- getKind' rng n
    downEqK rng mdown k
    return (TCon k n)

  TVar rng n -> do
    k <- getKind' rng n
    downEqK rng mdown k
    return (TVar k n)

-- | Check a function definition.
--
-- Raises an error if the equations do not all have the same number of
-- argument patterns. The type of the function is its kind-checked signature
-- if it has one, and a fresh unification variable of kind Type otherwise;
-- every equation is checked against that type.
tcFunDef :: PFunDef -> TCM CFunDef
tcFunDef (FunDef rng name msig eqs) = do
  unless (allEq (fmap (length . funeqPats) eqs)) $
    raise SError rng "Function equations have differing numbers of arguments"

  typ <- case msig of
    TypeSig sig -> kcType (Just (KType ())) sig
    TypeSigExt NoTypeSig -> genUniVar (KType ())

  eqs' <- mapM (tcFunEq typ) eqs

  return (FunDef typ name (TypeSig typ) eqs')

-- | Check a single function equation against the expected type of the
-- function. Not implemented yet.
tcFunEq :: CType -> PFunEq -> TCM CFunEq
tcFunEq down (FunEq rng name pats rhs) = error "tcFunEq"

-- | Solve a collection of kind equality constraints, producing an assignment
-- for kind unification variables.
--
-- Kind mismatches and recursive kinds reported by 'solveConstraints' are
-- raised as errors. Unification variables that occur in the assigned kinds
-- but have no assignment themselves are defaulted to Type.
solveKindVars :: Bag (CKind, CKind, Range) -> TCM (Map Int CKind)
solveKindVars cs = do
  let (asg, errs) =
        solveConstraints
          reduce
          (foldMap pure . kindUniVars)
          substKind
          (\case { KExt () (KUniVar v) -> Just v; _ -> Nothing })
          kindSize
          (toList cs)

  forM_ errs $ \case
    UEUnequal k1 k2 rng ->
      raise SError rng $
        "Kind mismatch:\n- " ++ pretty k1 ++ "\n- " ++ pretty k2
    UERecursive uvar k rng ->
      raise SError rng $
        "Kind cannot be recursive: " ++ pretty (KExt () (KUniVar uvar)) ++ " = " ++ pretty k

  -- default unconstrained kind variables to Type
  let unconstrKUVars = foldMap kindUniVars (Map.elems asg) Set.\\ Map.keysSet asg
      defaults = Map.fromList (map (,KType ()) (toList unconstrKUVars))
      asg' = Map.map (substKind defaults) asg <> defaults

  return asg'
  where
    -- Decompose one equality into assignments to unification variables (first
    -- component) and irreducible mismatches (second component).
    reduce :: CKind -> CKind -> Range -> (Bag (Int, CKind, Range), Bag (CKind, CKind, Range))
    reduce lhs rhs rng = case (lhs, rhs) of
      -- unification variables produce constraints on a unification variable
      (KExt () (KUniVar i), KExt () (KUniVar j)) | i == j -> mempty
      (KExt () (KUniVar i), k                  ) -> (pure (i, k, rng), mempty)
      (k                  , KExt () (KUniVar i)) -> (pure (i, k, rng), mempty)

      -- if lhs and rhs have equal prefixes, recurse
      (KType ()   , KType ()   ) -> mempty
      (KFun () a b, KFun () c d) -> reduce a c rng <> reduce b d rng

      -- otherwise, this is a kind mismatch
      (k1, k2) -> (mempty, pure (k1, k2, rng))

    -- Size measure on kinds that is passed to the solver.
    kindSize :: CKind -> Int
    kindSize KType{} = 1
    kindSize (KFun () a b) = 1 + kindSize a + kindSize b
    kindSize (KExt () KUniVar{}) = 2

-- | Solve all constraints, producing diagnostics and the final kind and type
-- substitutions. Not implemented yet.
solveConstrs :: Bag Constr -> (Bag Diagnostic, Map Int TKind, Map Int TType)
solveConstrs = error "solveConstrs"

-- | Apply kind and type substitutions to a whole program. Not implemented
-- yet.
substProg :: Map Int CKind  -- ^ Kind variable instantiations
          -> Map Int CType  -- ^ Type variable instantiations
          -> CProgram
          -> CProgram
substProg = error "substProg"

-- | Apply kind and type substitutions to a data definition: to the kinds of
-- its parameters and to the field types of its constructors.
substDdef :: Map Int CKind -> Map Int CType -> CDataDef -> CDataDef
substDdef mk mt (DataDef () name pars cons) =
  DataDef () name
          (map (first (substKind mk)) pars)
          (map (second (map (substType mk mt))) cons)

-- | Apply kind and type substitutions to a type. The kind substitution is
-- applied to the kind annotation of every node; type unification variables
-- that have an instantiation are replaced by it.
substType :: Map Int CKind -> Map Int CType -> CType -> CType
substType mk mt = \case
  TApp k t ts -> TApp (substKind mk k) (substType mk mt t) (map (substType mk mt) ts)
  TTup k ts -> TTup (substKind mk k) (map (substType mk mt) ts)
  TList k t -> TList (substKind mk k) (substType mk mt t)
  TFun k t1 t2 -> TFun (substKind mk k) (substType mk mt t1) (substType mk mt t2)
  TCon k n -> TCon (substKind mk k) n
  TVar k n -> TVar (substKind mk k) n
  -- A variable without an instantiation stays, but its kind annotation still
  -- gets the kind substitution, just like every other node.
  TExt k (TUniVar v) ->
    fromMaybe (TExt (substKind mk k) (TUniVar v)) (Map.lookup v mt)

-- | Apply a kind substitution to a kind. Unification variables without an
-- instantiation are left as they are.
substKind :: Map Int CKind -> CKind -> CKind
substKind m = \case
  KType () -> KType ()
  KFun () k1 k2 -> KFun () (substKind m k1) (substKind m k2)
  k@(KExt () (KUniVar v)) -> fromMaybe k (Map.lookup v m)

-- | Apply the final substitutions to a program, converting it to the typed
-- stage. Not implemented yet.
doneProg :: Map Int TKind  -- ^ Kind variable instantiations
         -> Map Int TType  -- ^ Type variable instantiations
         -> CProgram
         -> TProgram
doneProg = error "doneProg"

-- | The ids of all kind unification variables occurring in a kind.
kindUniVars :: CKind -> Set Int
kindUniVars = \case
  KType{} -> mempty
  KFun () a b -> kindUniVars a <> kindUniVars b
  KExt () (KUniVar v) -> Set.singleton v

-- | Whether all elements of the container are equal. True for an empty
-- container.
allEq :: (Eq a, Foldable t) => t a -> Bool
allEq l = case toList l of
  [] -> True
  x : xs -> all (== x) xs

-- | The argument patterns of a function equation.
funeqPats :: FunEq t -> [Pattern t]
funeqPats (FunEq _ _ pats _) = pats

-- | Split a kind into the argument kinds of its function arrows and its final
-- result: 'Right' with the annotation if the result is Type, 'Left' with the
-- extension value otherwise.
splitKind :: Kind s -> ([Kind s], Either (E Kind s) (X Kind s))
splitKind (KType x) = ([], Right x)
splitKind (KFun _ k1 k2) = first (k1:) (splitKind k2)
splitKind (KExt _ e) = ([], Left e)

-- | Select the kind equality constraints, returning their components.
isCEqK :: Constr -> Maybe (CKind, CKind, Range)
isCEqK (CEqK k1 k2 rng) = Just (k1, k2, rng)
isCEqK _ = Nothing

-- | Map each element to an action producing a monoidal value, run the
-- actions, and combine the results.
foldMapM :: (Applicative f, Monoid m, Foldable t) => (a -> f m) -> t a -> f m
foldMapM f = getAp . foldMap (Ap . f)