{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE EmptyDataDeriving #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TupleSections #-}
-- | Kind checking and (work-in-progress) type checking of parsed programs.
--
-- Checking runs inside 'TCM', which collects diagnostics and constraints
-- while threading a fresh-id counter and an environment.
module HSVIS.Typecheck where

import Control.Monad
import Data.Bifunctor (first, second)
import Data.Foldable (toList)
import Data.Map.Strict (Map)
import Data.Monoid (First(..))
import qualified Data.Map.Strict as Map

import Data.Bag
import Data.List.NonEmpty.Util
import HSVIS.AST
import HSVIS.Parser
import HSVIS.Diagnostic
import HSVIS.Pretty


-- | Stage index for the AST built by the checking functions in this module.
-- Its extension constructors add unification variables to kinds and types.
data StageTC

type instance X DataDef StageTC = ()
type instance X FunDef  StageTC = CType
type instance X FunEq   StageTC = CType
type instance X Kind    StageTC = ()
type instance X Type    StageTC = CKind
type instance X Pattern StageTC = CType
type instance X RHS     StageTC = CType
type instance X Expr    StageTC = CType

-- | Type-level extension for 'StageTC': a type unification variable,
-- identified by an id obtained from 'genId'.
data instance E Type StageTC = TUniVar Int
  deriving (Show)

-- | Kind-level extension for 'StageTC': a kind unification variable,
-- identified by an id obtained from 'genId'.
data instance E Kind StageTC = KUniVar Int
  deriving (Show)

type CProgram = Program StageTC
type CDataDef = DataDef StageTC
type CFunDef  = FunDef  StageTC
type CFunEq   = FunEq   StageTC
type CKind    = Kind    StageTC
type CType    = Type    StageTC
type CPattern = Pattern StageTC
type CRHS     = RHS     StageTC
type CExpr    = Expr    StageTC


-- | Stage index whose extension types are uninhabited, so kinds and types at
-- this stage cannot contain unification variables.
data StageTyped

type instance X DataDef StageTyped = TType
type instance X FunDef  StageTyped = TType
type instance X FunEq   StageTyped = TType
type instance X Kind    StageTyped = ()
type instance X Type    StageTyped = TKind
type instance X Pattern StageTyped = TType
type instance X RHS     StageTyped = TType
type instance X Expr    StageTyped = TType

data instance E Type StageTyped
  deriving (Show)

data instance E Kind StageTyped
  deriving (Show)

type TProgram = Program StageTyped
type TDataDef = DataDef StageTyped
type TFunDef  = FunDef  StageTyped
type TFunEq   = FunEq   StageTyped
type TKind    = Kind    StageTyped
type TType    = Type    StageTyped
type TPattern = Pattern StageTyped
type TRHS     = RHS     StageTyped
type TExpr    = Expr    StageTyped


-- | Check a parsed program.
--
-- Runs 'tcProgram' with the id counter starting at 1 and an empty
-- environment, passes the constraints it leaves behind to @solveConstrs@,
-- and applies the resulting substitution with @substProg@. The returned
-- diagnostics are those of the checking pass followed by those of the solver.
--
-- The file path and source text are only used for building diagnostics.
--
-- NOTE(review): @solveConstrs@ and @substProg@ are not defined in the visible
-- part of this module. The result type is written @Program TType@; 'TProgram'
-- (= @Program StageTyped@) may have been intended -- confirm against
-- @substProg@.
typecheck :: FilePath -> String -> PProgram -> ([Diagnostic], Program TType)
typecheck fp source prog =
  let (ds1, cs, _, _, progtc) =
        runTCM (tcProgram prog) (fp, source) 1 (Env mempty mempty)
      (ds2, sub) = solveConstrs cs
  in (toList (ds1 <> ds2), substProg sub progtc)

data Constr
  -- Equality constraints: "left" must be equal to "right" because of the thing
  -- at the given range. "left" is the expected thing; "right" is the observed
  -- thing.
  = CEq CType CType Range
  | CEqK CKind CKind Range
  deriving (Show)

-- | Checking environment: kinds of the type names in scope, and types of the
-- value names in scope.
data Env = Env (Map Name CKind) (Map Name CType)
  deriving (Show)

-- | The type checker monad: a reader for the file being checked, state for
-- the fresh-id counter and the environment, and writers for diagnostics and
-- constraints.
newtype TCM a = TCM
  { runTCM :: (FilePath, String)  -- ^ reader context: file and file contents
           -> Int                 -- ^ state: next id to generate
           -> Env                 -- ^ state: type and value environment
           -> (Bag Diagnostic     -- ^ writer: diagnostics
              ,Bag Constr         -- ^ writer: constraints
              ,Int, Env, a)
  }

instance Functor TCM where
  fmap f (TCM g) = TCM $ \ctx i env ->
    let (ds, cs, i', env', x) = g ctx i env
    in (ds, cs, i', env', f x)

instance Applicative TCM where
  pure x = TCM $ \_ i env -> (mempty, mempty, i, env, x)
  (<*>) = ap

instance Monad TCM where
  -- Thread counter and environment left-to-right; append both writer outputs.
  TCM f >>= g = TCM $ \ctx i1 env1 ->
    let (ds2, cs2, i2, env2, x) = f ctx i1 env1
        (ds3, cs3, i3, env3, y) = runTCM (g x) ctx i2 env2
    in (ds2 <> ds3, cs2 <> cs3, i3, env3, y)

-- | Record a diagnostic with the given message at the given range.
--
-- The diagnostic carries the source line selected by the first component of
-- the range's start position, which is used directly as an index into
-- @lines source@ (presumably a 0-based line number -- confirm against the
-- parser). If that index does not select a line (e.g. a position past the end
-- of the file), an empty source line is used instead of crashing.
raise :: Range -> String -> TCM ()
raise rng@(Range (Pos y _) _) msg = TCM $ \(fp, source) i env ->
  let srcLine = case drop y (lines source) of
        l : _ -> l
        []    -> ""
  in (pure (Diagnostic fp rng [] srcLine msg), mempty, i, env, ())

-- | Emit a single constraint.
emit :: Constr -> TCM ()
emit c = TCM $ \_ i env -> (mempty, pure c, i, env, ())

-- | Run a computation and split the constraints it emitted using
-- 'bagPartition' with the given function: the first component of the
-- partition is returned alongside the computation's result, while the second
-- component stays in the writer output.
collectConstraints :: (Constr -> Maybe b) -> TCM a -> TCM (Bag b, a)
collectConstraints predicate (TCM f) = TCM $ \ctx i env ->
  let (ds, cs, i', env', x) = f ctx i env
      (yes, no) = bagPartition predicate cs
  in (ds, no, i', env', (yes, x))

-- | Read the whole environment.
getFullEnv :: TCM Env
getFullEnv = TCM $ \_ i env -> (mempty, mempty, i, env, env)

-- | Replace the whole environment.
putFullEnv :: Env -> TCM ()
putFullEnv env = TCM $ \_ i _ -> (mempty, mempty, i, env, ())

-- | Generate a fresh id: return the current counter value and advance the
-- counter, so that successive calls yield distinct ids.
genId :: TCM Int
genId = TCM $ \_ i env -> (mempty, mempty, i + 1, env, i)

-- | Look up the kind of a type name in the environment.
getKind :: Name -> TCM (Maybe CKind)
getKind name = do
  Env tenv _ <- getFullEnv
  return (Map.lookup name tenv)

-- | Look up the type of a value name in the environment.
getType :: Name -> TCM (Maybe CType)
getType name = do
  Env _ venv <- getFullEnv
  return (Map.lookup name venv)

-- | Modify the type-name half of the environment.
modifyTEnv :: (Map Name CKind -> Map Name CKind) -> TCM ()
modifyTEnv f = do
  Env tenv venv <- getFullEnv
  putFullEnv (Env (f tenv) venv)

-- | Modify the value-name half of the environment.
modifyVEnv :: (Map Name CType -> Map Name CType) -> TCM ()
modifyVEnv f = do
  Env tenv venv <- getFullEnv
  putFullEnv (Env tenv (f venv))

-- | Run a computation, then restore the type-name environment to what it was
-- beforehand. Changes to the value-name environment are kept.
scopeTEnv :: TCM a -> TCM a
scopeTEnv m = do
  Env origtenv _ <- getFullEnv
  res <- m
  modifyTEnv (\_ -> origtenv)
  return res

-- | Run a computation, then restore the value-name environment to what it was
-- beforehand. Changes to the type-name environment are kept.
scopeVEnv :: TCM a -> TCM a
scopeVEnv m = do
  Env _ origvenv <- getFullEnv
  res <- m
  modifyVEnv (\_ -> origvenv)
  return res

-- | Generate a fresh kind unification variable.
genKUniVar :: TCM CKind
genKUniVar = KExt () . KUniVar <$> genId

-- | Generate a fresh type unification variable of the given kind.
genUniVar :: CKind -> TCM CType
genUniVar k = TExt k . TUniVar <$> genId

-- | Like 'getKind', but if the name is not in scope, report that at the given
-- range and continue with a fresh kind unification variable.
getKind' :: Range -> Name -> TCM CKind
getKind' rng name = getKind name >>= \case
  Nothing -> do
    raise rng $ "Type not in scope: " ++ pretty name
    genKUniVar
  Just k -> return k

-- | Like 'getType', but if the name is not in scope, report that at the given
-- range and continue with a fresh type unification variable of kind @KType@.
getType' :: Range -> Name -> TCM CType
getType' rng name = getType name >>= \case
  Nothing -> do
    raise rng $ "Variable not in scope: " ++ pretty name
    genUniVar (KType ())
  Just k -> return k

-- | Check a whole program: first all data definitions, handing the kind
-- equality constraints they emit to 'solveKindVars', then all function
-- definitions.
tcProgram :: PProgram -> TCM CProgram
tcProgram (Program ddefs fdefs) = do
  (kconstrs, ddefs') <- collectConstraints isCEqK $ do
    -- Register every data type name first, then check the definitions.
    mapM_ prepareDataDef ddefs
    mapM tcDataDef ddefs

  solveKindVars kconstrs

  fdefs' <- mapM tcFunDef fdefs

  return (Program ddefs' fdefs')

-- | Register a data type name in the environment with a kind that has one
-- fresh kind variable per parameter and result kind @KType@.
prepareDataDef :: PDataDef -> TCM ()
prepareDataDef (DataDef _ name params _) = do
  parkinds <- mapM (\_ -> genKUniVar) params
  let k = foldr (KFun ()) (KType ()) parkinds
  modifyTEnv (Map.insert name k)

-- Assumes that the kind of the name itself has already been registered with
-- the correct arity (this is done by prepareDataDef).
tcDataDef :: PDataDef -> TCM CDataDef
tcDataDef (DataDef rng name params cons) = do
  kd <- getKind' rng name
  let (pkinds, kret) = splitKind kd

  -- sanity checking; would be nicer to store these in prepareDataDef already
  when (length pkinds /= length params) $
    error "tcDataDef: Invalid param kind list length"
  case kret of
    Right () -> return ()
    _ -> error "tcDataDef: Invalid ret kind"

  -- The type parameters are only in scope while checking the constructor
  -- field types.
  cons' <- scopeTEnv $ do
    modifyTEnv (Map.fromList (zip (map snd params) pkinds) <>)
    mapM (\(cname, ty) -> (cname,) <$> mapM kcType ty) cons
  return (DataDef () name (zip pkinds (map snd params)) cons')

-- | Kind-check a parsed type: annotate every node with its kind and emit the
-- kind equality constraints ('CEqK') that must hold for it to be well-kinded.
kcType :: PType -> TCM CType
kcType = \case
  TApp rng t ts -> do
    t' <- kcType t
    ts' <- mapM kcType ts
    retk <- genKUniVar
    -- The head must be a function from the argument kinds to the result kind.
    let expected = foldr (KFun ()) retk (map extOf ts')
    emit $ CEqK (extOf t') expected rng
    return (TApp retk t' ts')

  TTup _ ts -> do
    ts' <- mapM kcType ts
    -- Each component must have kind KType; blame the component's own range.
    forM_ (zip (map extOf ts) ts') $ \(trng, ct) ->
      emit $ CEqK (extOf ct) (KType ()) trng
    return (TTup (KType ()) ts')

  TList _ t -> do
    t' <- kcType t
    emit $ CEqK (extOf t') (KType ()) (extOf t)
    return (TList (KType ()) t')

  TFun _ t1 t2 -> do
    t1' <- kcType t1
    t2' <- kcType t2
    emit $ CEqK (extOf t1') (KType ()) (extOf t1)
    emit $ CEqK (extOf t2') (KType ()) (extOf t2)
    return (TFun (KType ()) t1' t2')

  TCon rng n -> TCon <$> getKind' rng n <*> pure n

  TVar rng n -> TVar <$> getKind' rng n <*> pure n

-- | Check a function definition.
--
-- Reports a diagnostic if the equations do not all have the same number of
-- patterns. The function's type is its kind-checked signature if it has one,
-- and a fresh unification variable of kind @KType@ otherwise; every equation
-- is checked against that type with 'tcFunEq'.
tcFunDef :: PFunDef -> TCM CFunDef
tcFunDef (FunDef _ name msig eqs) = do
  unless (allEq (fmap (length . funeqPats) eqs)) $
    raise (sconcatne (fmap extOf eqs)) "Function equations have differing numbers of arguments"

  typ <- case msig of
    TypeSig sig -> kcType sig
    TypeSigExt NoTypeSig -> genUniVar (KType ())

  eqs' <- mapM (tcFunEq typ) eqs

  return (FunDef typ name (TypeSig typ) eqs')

-- | Check one function equation against the given type.
--
-- TODO: not implemented yet; the body is a typed hole.
tcFunEq :: CType -> PFunEq -> TCM CFunEq
tcFunEq = _

-- | Process kind equality constraints (expected kind, observed kind, range).
--
-- For each constraint, the two kinds are compared structurally with @reduce@;
-- a mismatch is reported at the constraint's range, and the bindings found
-- for kind unification variables are grouped per variable id.
--
-- TODO: unfinished; the grouped bindings are not used yet and the final
-- statement is a typed hole.
solveKindVars :: Bag (CKind, CKind, Range) -> TCM ()
solveKindVars =
  mapM_ $ \(a, b, rng) -> do
    let (subst, First merr) = reduce a b
    forM_ merr $ \(erra, errb) ->
      raise rng $
        "Kind mismatch:\n\
        \- Expected: " ++ pretty a ++ "\n\
        \- Observed: " ++ pretty b ++ "\n\
        \because '" ++ pretty erra ++ "' and '" ++ pretty errb ++ "' don't match"
    let collected :: [(Int, Bag CKind)]
        collected = Map.assocs $ Map.fromListWith (<>) (fmap (second pure) (toList subst))
    _
  where
    -- Compare two kinds structurally. Returns the (variable id, kind) pairs
    -- found where one side is a unification variable, and the first pair of
    -- subterms (if any) that cannot be equal.
    reduce :: CKind -> CKind -> (Bag (Int, CKind), First (CKind, CKind))
    reduce (KType ()) (KType ()) = mempty
    reduce (KFun () a b) (KFun () c d) = reduce a c <> reduce b d
    reduce (KExt () (KUniVar i)) k = (pure (i, k), mempty)
    reduce k (KExt () (KUniVar i)) = (pure (i, k), mempty)
    reduce k1 k2 = (mempty, pure (k1, k2))

-- | Whether all elements of the container are equal. True for an empty
-- container.
allEq :: (Eq a, Foldable t) => t a -> Bool
allEq l = case toList l of
  [] -> True
  x : xs -> all (== x) xs

-- | The argument patterns of a function equation.
funeqPats :: FunEq t -> [Pattern t]
funeqPats (FunEq _ _ pats _) = pats

-- | Split a kind into the argument kinds along its chain of 'KFun's and its
-- final result: 'Right' with the annotation of a 'KType', or 'Left' with the
-- payload of a 'KExt'.
splitKind :: Kind s -> ([Kind s], Either (E Kind s) (X Kind s))
splitKind (KType x) = ([], Right x)
splitKind (KFun _ k1 k2) = first (k1:) (splitKind k2)
splitKind (KExt _ e) = ([], Left e)

-- | Select the kind equality constraints, returning their components.
isCEqK :: Constr -> Maybe (CKind, CKind, Range)
isCEqK (CEqK k1 k2 rng) = Just (k1, k2, rng)
isCEqK _ = Nothing