{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE EmptyDataDeriving #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TupleSections #-}
module HSVIS.Typecheck where

import Control.Monad
import Data.Bifunctor (first)
import Data.Foldable (toList)
import Data.Map.Strict (Map)
import Data.Maybe (fromMaybe)
import Data.Monoid (Ap(..))
import qualified Data.Map.Strict as Map
import Data.Set (Set)
import qualified Data.Set as Set

import Debug.Trace

import Data.Bag
import Data.List.NonEmpty.Util
import HSVIS.AST
import HSVIS.Parser
import HSVIS.Diagnostic
import HSVIS.Pretty
import HSVIS.Typecheck.Solve


-- | AST stage used while typechecking: types and kinds may still contain
-- unification variables ('TUniVar', 'KUniVar').
data StageTC
type instance X DataDef StageTC = ()
type instance X FunDef StageTC = CType
type instance X FunEq StageTC = CType
type instance X Kind StageTC = ()
type instance X Type StageTC = CKind
type instance X Pattern StageTC = CType
type instance X RHS StageTC = CType
type instance X Expr StageTC = CType
data instance E Type StageTC = TUniVar Int deriving (Show)
data instance E Kind StageTC = KUniVar Int deriving (Show, Eq, Ord)
data instance E TypeSig StageTC deriving (Show)

type CProgram = Program StageTC
type CDataDef = DataDef StageTC
type CFunDef = FunDef StageTC
type CFunEq = FunEq StageTC
type CKind = Kind StageTC
type CType = Type StageTC
type CPattern = Pattern StageTC
type CRHS = RHS StageTC
type CExpr = Expr StageTC

-- | Final, fully typed AST stage: the extension types for 'Type', 'Kind' and
-- 'TypeSig' are empty, so no unification variables can remain.
data StageTyped
type instance X DataDef StageTyped = TType
type instance X FunDef StageTyped = TType
type instance X FunEq StageTyped = TType
type instance X Kind StageTyped = ()
type instance X Type StageTyped = TKind
type instance X Pattern StageTyped = TType
type instance X RHS StageTyped = TType
type instance X Expr StageTyped = TType
data instance E Type StageTyped deriving (Show)
data instance E Kind StageTyped deriving (Show)
data instance E TypeSig StageTyped deriving (Show)

type TProgram = Program StageTyped
type TDataDef = DataDef StageTyped
type TFunDef = FunDef StageTyped
type TFunEq = FunEq StageTyped
type TKind = Kind StageTyped
type TType = Type StageTyped
type TPattern = Pattern StageTyped
type TRHS = RHS StageTyped
type TExpr = Expr StageTyped

-- | Kind unification variables print as @?k<n>@.
instance Pretty (E Kind StageTC) where
  prettysPrec _ (KUniVar n) = showString ("?k" ++ show n)


-- | Typecheck a parsed program.
--
-- Runs 'tcProgram' with fresh ids starting at 1 and an empty environment,
-- solves the constraints it emitted with 'solveConstrs', and substitutes the
-- solution into the intermediate program. Diagnostics from both phases are
-- returned together.
typecheck :: FilePath -> String -> PProgram -> ([Diagnostic], TProgram)
typecheck fp source prog =
  let (ds1, cs, _, _, progtc) =
        runTCM (tcProgram prog) (fp, source) 1 (Env mempty mempty)
      (ds2, sub) = solveConstrs cs
  in (toList (ds1 <> ds2), substProg sub progtc)

data Constr
  -- Equality constraints: "left" must be equal to "right" because of the thing
  -- at the given range. "left" is the expected thing; "right" is the observed
  -- thing.
  = CEq CType CType Range
  | CEqK CKind CKind Range
  deriving (Show)

-- | Typechecking environment: kinds of type-level names, and types of
-- value-level names.
data Env = Env (Map Name CKind) (Map Name CType)
  deriving (Show)

-- | The typechecking monad: a hand-rolled reader/writer/state combination.
newtype TCM a = TCM {
  runTCM :: (FilePath, String)  -- ^ reader context: file and file contents
         -> Int  -- ^ state: next id to generate
         -> Env  -- ^ state: type and value environment
         -> (Bag Diagnostic  -- ^ writer: diagnostics
            ,Bag Constr      -- ^ writer: constraints
            ,Int, Env, a) }

instance Functor TCM where
  fmap f (TCM g) = TCM $ \ctx i env ->
    let (ds, cs, i', env', x) = g ctx i env
    in (ds, cs, i', env', f x)

instance Applicative TCM where
  pure x = TCM $ \_ i env -> (mempty, mempty, i, env, x)
  (<*>) = ap

instance Monad TCM where
  TCM f >>= g = TCM $ \ctx i1 env1 ->
    let (ds2, cs2, i2, env2, x) = f ctx i1 env1
        (ds3, cs3, i3, env3, y) = runTCM (g x) ctx i2 env2
    in (ds2 <> ds3, cs2 <> cs3, i3, env3, y)

-- | Record an error diagnostic at the given range.
--
-- The diagnostic includes the source line at index @y@ (the line component of
-- the range's start position). If that index is outside the file (e.g. a
-- position at end-of-file), an empty line is used instead of crashing.
raise :: Range -> String -> TCM ()
raise rng@(Range (Pos y _) _) msg = TCM $ \(fp, source) i env ->
  (pure (Diagnostic fp rng [] (sourceLine y source) msg), mempty, i, env, ())
  where
    -- Total replacement for @lines src !! n@, which would throw on an
    -- out-of-range (or negative) index.
    sourceLine n src
      | n < 0 = ""
      | otherwise = case drop n (lines src) of
          l : _ -> l
          [] -> ""

-- | Emit a single constraint.
emit :: Constr -> TCM ()
emit c = TCM $ \_ i env -> (mempty, pure c, i, env, ())

-- | Run an action and split the constraints it emits using 'bagPartition'
-- with the given selector: the selected results are returned alongside the
-- action's result, and the remaining constraints are re-emitted as usual.
collectConstraints :: (Constr -> Maybe b) -> TCM a -> TCM (Bag b, a)
collectConstraints predicate (TCM f) = TCM $ \ctx i env ->
  let (ds, cs, i', env', x) = f ctx i env
      (yes, no) = bagPartition predicate cs
  in (ds, no, i', env', (yes, x))

getFullEnv :: TCM Env
getFullEnv = TCM $ \_ i env -> (mempty, mempty, i, env, env)

putFullEnv :: Env -> TCM ()
putFullEnv env = TCM $ \_ i _ -> (mempty, mempty, i, env, ())

-- | Return the next fresh id (the state counter).
genId :: TCM Int
genId = TCM $ \_ i env -> (mempty, mempty, i, env, i)

getKind :: Name -> TCM (Maybe CKind)
getKind name = do
  Env tenv _ <- getFullEnv
  return (Map.lookup name tenv)

getType :: Name -> TCM (Maybe CType)
getType name = do
  Env _ venv <- getFullEnv
  return (Map.lookup name venv)

modifyTEnv :: (Map Name CKind -> Map Name CKind) -> TCM ()
modifyTEnv f = do
  Env tenv venv <- getFullEnv
  putFullEnv (Env (f tenv) venv)

modifyVEnv :: (Map Name CType -> Map Name CType) -> TCM ()
modifyVEnv f = do
  Env tenv venv <- getFullEnv
  putFullEnv (Env tenv (f venv))

-- | Run an action, then restore the type (kind) environment to what it was
-- before. Changes to the value environment are kept.
scopeTEnv :: TCM a -> TCM a
scopeTEnv m = do
  Env origtenv _ <- getFullEnv
  res <- m
  modifyTEnv (\_ -> origtenv)
  return res

-- | Run an action, then restore the value environment to what it was before.
-- Changes to the type environment are kept.
scopeVEnv :: TCM a -> TCM a
scopeVEnv m = do
  Env _ origvenv <- getFullEnv
  res <- m
  modifyVEnv (\_ -> origvenv)
  return res

-- | A fresh kind unification variable.
genKUniVar :: TCM CKind
genKUniVar = KExt () . KUniVar <$> genId

-- | A fresh type unification variable of the given kind.
genUniVar :: CKind -> TCM CType
genUniVar k = TExt k . TUniVar <$> genId

-- | Look up the kind of a type-level name. If it is not in scope, a
-- diagnostic is raised and a fresh kind unification variable is returned so
-- checking can continue.
getKind' :: Range -> Name -> TCM CKind
getKind' rng name = getKind name >>= \case
  Nothing -> do
    raise rng $ "Type not in scope: " ++ pretty name
    genKUniVar
  Just k -> return k

-- | Look up the type of a value-level name. If it is not in scope, a
-- diagnostic is raised and a fresh unification variable of kind Type is
-- returned so checking can continue.
getType' :: Range -> Name -> TCM CType
getType' rng name = getType name >>= \case
  Nothing -> do
    raise rng $ "Variable not in scope: " ++ pretty name
    genUniVar (KType ())
  Just t -> return t

-- | Check a whole program: first kind-check all data definitions (after
-- registering all of their names, so they may refer to each other), then
-- check the function definitions.
tcProgram :: PProgram -> TCM CProgram
tcProgram (Program ddefs fdefs) = do
  -- Kind equality constraints from the data definitions are pulled out here
  -- instead of being propagated to 'typecheck'.
  (kconstrs, ddefs') <- collectConstraints isCEqK $ do
    mapM_ prepareDataDef ddefs
    mapM tcDataDef ddefs
  -- NOTE(review): 'solveKindVars' currently returns () and only traces, so
  -- the kind unification variables in ddefs' are not yet substituted.
  solveKindVars kconstrs

  fdefs' <- mapM tcFunDef fdefs

  return (Program ddefs' fdefs')

-- | Register a data type's name in the type environment with kind
-- @k1 -> ... -> kn -> Type@, using one fresh kind unification variable per
-- type parameter.
prepareDataDef :: PDataDef -> TCM ()
prepareDataDef (DataDef _ name params _) = do
  parkinds <- mapM (const genKUniVar) params
  let k = foldr (KFun ()) (KType ()) parkinds
  modifyTEnv (Map.insert name k)

-- Assumes that the kind of the name itself has already been registered with
-- the correct arity (this is done by prepareDataDef).
tcDataDef :: PDataDef -> TCM CDataDef
tcDataDef (DataDef rng name params cons) = do
  kd <- getKind' rng name
  let (pkinds, kret) = splitKind kd

  -- sanity checking; would be nicer to store these in prepareDataDef already
  when (length pkinds /= length params) $ error "tcDataDef: Invalid param kind list length"
  case kret of
    Right () -> return ()
    _ -> error "tcDataDef: Invalid ret kind"

  -- The type parameters are only in scope while checking the constructors;
  -- the left-biased union lets them shadow outer names.
  cons' <- scopeTEnv $ do
    modifyTEnv (Map.fromList (zip (map snd params) pkinds) <>)
    mapM (\(cname, ty) -> (cname,) <$> mapM kcType ty) cons
  return (DataDef () name (zip pkinds (map snd params)) cons')

-- | Kind-check a type, annotating every node with its kind and emitting
-- 'CEqK' constraints (expected kind first, observed kind second, per
-- 'Constr').
kcType :: PType -> TCM CType
kcType = \case
  TApp rng t ts -> do
    t' <- kcType t
    ts' <- mapM kcType ts
    retk <- genKUniVar
    -- The head must have kind  k1 -> ... -> kn -> retk  for the argument
    -- kinds k1..kn.
    let expected = foldr (KFun ()) retk (map extOf ts')
    emit $ CEqK (extOf t') expected rng
    return (TApp retk t' ts')

  TTup _ ts -> do
    ts' <- mapM kcType ts
    forM_ (zip (map extOf ts) ts') $ \(trng, ct) ->
      emit $ CEqK (extOf ct) (KType ()) trng
    return (TTup (KType ()) ts')

  TList _ t -> do
    t' <- kcType t
    emit $ CEqK (extOf t') (KType ()) (extOf t)
    return (TList (KType ()) t')

  TFun _ t1 t2 -> do
    t1' <- kcType t1
    t2' <- kcType t2
    emit $ CEqK (extOf t1') (KType ()) (extOf t1)
    emit $ CEqK (extOf t2') (KType ()) (extOf t2)
    return (TFun (KType ()) t1' t2')

  TCon rng n -> TCon <$> getKind' rng n <*> pure n

  TVar rng n -> TVar <$> getKind' rng n <*> pure n

-- | Check a function definition: all equations must take the same number of
-- patterns; the signature (if any) is kind-checked, otherwise a fresh type
-- unification variable stands for the function's type.
tcFunDef :: PFunDef -> TCM CFunDef
tcFunDef (FunDef _ name msig eqs) = do
  unless (allEq (fmap (length . funeqPats) eqs)) $
    raise (sconcatne (fmap extOf eqs))
          "Function equations have differing numbers of arguments"

  typ <- case msig of
    TypeSig sig -> kcType sig
    TypeSigExt NoTypeSig -> genUniVar (KType ())

  eqs' <- mapM (tcFunEq typ) eqs

  return (FunDef typ name (TypeSig typ) eqs')

tcFunEq :: CType -> PFunEq -> TCM CFunEq
tcFunEq = error "tcFunEq"

-- | A state monad transformer over two maps: a map from variables to bags of
-- terms (the "state"), and a map from variables to single terms (the "log").
newtype SolveM v t m a = SolveM (Map v (Bag t) -> Map v t -> m (a, Map v (Bag t), Map v t))

instance Monad m => Functor (SolveM v t m) where
  fmap f (SolveM g) = SolveM $ \m r -> do
    (x, m', r') <- g m r
    return (f x, m', r')

instance Monad m => Applicative (SolveM v t m) where
  pure x = SolveM $ \m r -> return (x, m, r)
  (<*>) = ap

instance Monad m => Monad (SolveM v t m) where
  SolveM f >>= g = SolveM $ \m r -> do
    (x, m1, r1) <- f m r
    let SolveM h = g x
    h m1 r1

solvemStateGet :: Monad m => SolveM v t m (Map v (Bag t))
solvemStateGet = SolveM $ \m r -> return (m, m, r)

solvemStateUpdate :: Monad m => (Map v (Bag t) -> Map v (Bag t)) -> SolveM v t m ()
solvemStateUpdate f = SolveM $ \m r -> return ((), f m, r)

solvemLogUpdate :: Monad m => (Map v t -> Map v t) -> SolveM v t m ()
solvemLogUpdate f = SolveM $ \m r -> return ((), m, f r)

solvemStateVars :: Monad m => SolveM v t m [v]
solvemStateVars = Map.keys <$> solvemStateGet

-- | The bag stored for a variable, or an empty bag if there is none.
solvemStateRHS :: (Ord v, Monad m) => v -> SolveM v t m (Bag t)
solvemStateRHS v = fromMaybe mempty . Map.lookup v <$> solvemStateGet

solvemStateSet :: (Ord v, Monad m) => v -> Bag t -> SolveM v t m ()
solvemStateSet v b = solvemStateUpdate (Map.insert v b)

solvemLogEq :: (Ord v, Monad m) => v -> t -> SolveM v t m ()
solvemLogEq v t = solvemLogUpdate (Map.insert v t)

-- | Solve kind equality constraints.
--
-- NOTE(review): work in progress. This currently only prints the constraints
-- and the result of 'solveConstraints' via 'traceShowM' (debug output); the
-- solution is not returned or applied.
solveKindVars :: Bag (CKind, CKind, Range) -> TCM ()
solveKindVars cs = do
  traceShowM cs
  traceShowM $ solveConstraints
    reduce
    (foldMap pure . kindUniVars)
    (\v repl -> substKind (Map.singleton v repl))
    (\case
        KExt () (KUniVar v) -> Just v
        _ -> Nothing)
    kindSize
    (map (\(a, b, _) -> (a, b)) (toList cs))
  where
    reduce :: CKind -> CKind -> (Bag (Int, CKind), Bag (CKind, CKind))
    -- unification variables produce constraints on a unification variable
    reduce (KExt () (KUniVar i)) (KExt () (KUniVar j)) | i == j = mempty
    reduce (KExt () (KUniVar i)) k = (pure (i, k), mempty)
    reduce k (KExt () (KUniVar i)) = (pure (i, k), mempty)
    -- if lhs and rhs have equal prefixes, recurse
    reduce (KType ()) (KType ()) = mempty
    reduce (KFun () a b) (KFun () c d) = reduce a c <> reduce b d
    -- otherwise, this is a kind mismatch
    reduce k1 k2 = (mempty, pure (k1, k2))

    -- Number of nodes in a kind.
    kindSize :: CKind -> Int
    kindSize KType{} = 1
    kindSize (KFun () a b) = 1 + kindSize a + kindSize b
    kindSize (KExt () KUniVar{}) = 1

solveConstrs :: Bag Constr -> (Bag Diagnostic, Map Name TType)
solveConstrs = error "solveConstrs"

substProg :: Map Name TType -> CProgram -> TProgram
substProg = error "substProg"

-- | Replace kind unification variables found in the map; variables not in
-- the map are left as-is. The replacement itself is not substituted again.
substKind :: Map Int CKind -> CKind -> CKind
substKind _ k@KType{} = k
substKind m (KFun () k1 k2) = KFun () (substKind m k1) (substKind m k2)
substKind m k@(KExt () (KUniVar v)) = fromMaybe k (Map.lookup v m)

-- | All kind unification variables occurring in a kind.
kindUniVars :: CKind -> Set Int
kindUniVars = \case
  KType{} -> mempty
  KFun () a b -> kindUniVars a <> kindUniVars b
  KExt () (KUniVar v) -> Set.singleton v

-- | Whether all elements are equal (vacuously true when empty).
allEq :: (Eq a, Foldable t) => t a -> Bool
allEq l = case toList l of
  [] -> True
  x : xs -> all (== x) xs

funeqPats :: FunEq t -> [Pattern t]
funeqPats (FunEq _ _ pats _) = pats

-- | Split a kind @k1 -> ... -> kn -> r@ into its argument kinds and its
-- result: 'Right' for 'KType', 'Left' for an extension (e.g. a unification
-- variable).
splitKind :: Kind s -> ([Kind s], Either (E Kind s) (X Kind s))
splitKind (KType x) = ([], Right x)
splitKind (KFun _ k1 k2) = first (k1:) (splitKind k2)
splitKind (KExt _ e) = ([], Left e)

-- | Select kind equality constraints.
isCEqK :: Constr -> Maybe (CKind, CKind, Range)
isCEqK (CEqK k1 k2 rng) = Just (k1, k2, rng)
isCEqK _ = Nothing

foldMapM :: (Applicative f, Monoid m, Foldable t) => (a -> f m) -> t a -> f m
foldMapM f = getAp . foldMap (Ap . f)