{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE EmptyDataDeriving #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}
-- | Type checking of parsed HSVIS programs.
--
-- Checking runs in the 'TCM' monad, which records diagnostics and type
-- equality constraints. Afterwards the constraints are solved, and the
-- resulting substitution is applied to the checked program.
--
-- NOTE: work in progress; several definitions below still contain typed
-- holes (@_@).
module HSVIS.Typecheck where

import Control.Monad
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer.CPS
import Data.Foldable (toList)
import Data.Map.Strict (Map)

import Data.Bag
import Data.List.NonEmpty.Util
import HSVIS.AST
import HSVIS.Parser
import HSVIS.Diagnostic


-- | The stage during type checking. Annotations are types (and kinds) whose
-- types may still contain unification variables ('TUniVar').
data StageTC
type instance X DataDef StageTC = CType
type instance X FunDef StageTC = CType
type instance X FunEq StageTC = CType
type instance X Kind StageTC = ()
type instance X Type StageTC = CKind
type instance X Pattern StageTC = CType
type instance X RHS StageTC = CType
type instance X Expr StageTC = CType

-- | Extension constructor for types during checking: a unification variable,
-- identified by a number.
data instance E Type StageTC = TUniVar Int
  deriving (Show)

type CProgram = Program StageTC
type CDataDef = DataDef StageTC
type CFunDef = FunDef StageTC
type CFunEq = FunEq StageTC
type CKind = Kind StageTC
type CType = Type StageTC
type CPattern = Pattern StageTC
type CRHS = RHS StageTC
type CExpr = Expr StageTC


-- | The stage after type checking. 'E' is empty at this stage, so no
-- extension constructors (in particular no unification variables) remain.
data StageTyped
type instance X DataDef StageTyped = TType
type instance X FunDef StageTyped = TType
type instance X FunEq StageTyped = TType
type instance X Kind StageTyped = ()
type instance X Type StageTyped = TKind
type instance X Pattern StageTyped = TType
type instance X RHS StageTyped = TType
type instance X Expr StageTyped = TType

data instance E t StageTyped
  deriving (Show)

type TProgram = Program StageTyped
type TDataDef = DataDef StageTyped
type TFunDef = FunDef StageTyped
type TFunEq = FunEq StageTyped
type TKind = Kind StageTyped
type TType = Type StageTyped
type TPattern = Pattern StageTyped
type TRHS = RHS StageTyped
type TExpr = Expr StageTyped


-- | Type check a parsed program.
--
-- Arguments: the file path and the full source text (both used only for
-- diagnostics), and the parsed program. Returns all diagnostics, i.e. those
-- raised while checking followed by those from constraint solving, together
-- with the program after the solved substitution is applied.
typecheck :: FilePath -> String -> PProgram -> ([Diagnostic], TProgram)
typecheck fp source prog =
  let (progtc, (ds1, cs)) =
        runWriter
          . flip evalStateT (Env mempty mempty)
          -- unification variable ids start at 1
          . flip evalStateT 1
          . flip runReaderT (fp, source)
          . runTCM
          $ tcProgram prog
      (ds2, sub) = solveConstrs cs
  in (toList (ds1 <> ds2), substProg sub progtc)

data Constr
  = CEq CType CType Range  -- ^ These types must be equal because of the expression here
  deriving (Show)

-- | Checking environment: a map from names to kinds (updated by
-- 'modifyTEnv') and a map from names to types (updated by 'modifyVEnv').
data Env = Env (Map Name CKind) (Map Name CType)
  deriving (Show)

-- | The type checking monad. Layers, outermost first:
--
-- * reader: the file path and full source text, used by 'raise';
-- * state 'Int': the id of the next fresh unification variable;
-- * state 'Env': the current environment;
-- * writer: the diagnostics raised and the constraints generated.
newtype TCM a =
  TCM { runTCM :: ReaderT (FilePath, String)
                    (StateT Int
                      (StateT Env
                        (Writer (Bag Diagnostic, Bag Constr)))) a }
  deriving newtype (Functor, Applicative, Monad)

-- | Record an error diagnostic at the given range. The diagnostic includes
-- the source line on which the range starts.
raise :: Range -> String -> TCM ()
raise rng@(Range (Pos y _) _) msg = do
  -- TCM derives no MonadReader instance, so read the environment through
  -- the newtype constructor (as the other helpers here do).
  (fp, source) <- TCM ask
  -- Select line y of the source, not character y. This assumes 'Pos' line
  -- numbers are zero-based (TODO: confirm in HSVIS.Diagnostic). An
  -- out-of-range line yields "" instead of crashing like (!!) would.
  let srcline = case drop y (lines source) of
        l : _ -> l
        [] -> ""
  TCM $ lift $ lift $ tell (pure (Diagnostic fp rng [] srcline msg), mempty)

-- | Apply a function to the name-to-kind map of the environment.
modifyTEnv :: (Map Name CKind -> Map Name CKind) -> TCM ()
modifyTEnv f = TCM $ lift $ lift $ modify (\(Env a b) -> Env (f a) b)

-- | Apply a function to the name-to-type map of the environment.
modifyVEnv :: (Map Name CType -> Map Name CType) -> TCM ()
modifyVEnv f = TCM $ lift $ lift $ modify (\(Env a b) -> Env a (f b))

-- | Create a fresh unification variable of the given kind.
genUniVar :: CKind -> TCM CType
genUniVar k = TCM $ lift $ state (\i -> (TExt k (TUniVar i), i + 1))

-- | Check all data definitions and function definitions of a program.
tcProgram :: PProgram -> TCM CProgram
tcProgram (Program ddefs fdefs) =
  -- TODO: add preliminary kinds for the data definitions to the environment,
  -- then solve for the kind variables in checkDataDef
  Program <$> traverse checkDataDef ddefs <*> traverse tcFunDef fdefs

-- | Check a data definition. TODO: not yet implemented.
checkDataDef :: PDataDef -> TCM CDataDef
checkDataDef (DataDef rng name params cons) = do
  _

-- | Check a function definition. Raises a diagnostic if its equations do not
-- all have the same number of argument patterns. The function's type comes
-- from its signature if present, and otherwise is a fresh unification
-- variable of kind 'KType'. TODO: the rest is not yet implemented.
tcFunDef :: PFunDef -> TCM CFunDef
tcFunDef (FunDef rng name msig eqs) = do
  unless (allEq (fmap (length . funeqPats) eqs)) $
    -- presumably the range spanning all equations; see 'sconcatne'
    raise (sconcatne (fmap extOf eqs))
          "Function equations have differing numbers of arguments"

  typ <- case msig of
    Just sig -> checkType sig
    Nothing -> genUniVar (KType ())

  _

-- | Check a type expression. TODO: not yet implemented.
checkType :: PType -> TCM CType
checkType = \case
  TApp r t ts -> _
  TTup r ts -> _
  TList r t -> _
  TFun r s t -> _
  TCon r n -> _
  TVar r n -> _

-- | Whether all elements of the container are equal. True for an empty
-- container.
allEq :: (Eq a, Foldable t) => t a -> Bool
allEq l = case toList l of
  [] -> True
  x : xs -> all (== x) xs

-- | The argument patterns of a function equation.
funeqPats :: FunEq t -> [Pattern t]
funeqPats (FunEq _ _ pats _) = pats