{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
-- | Type checking entry point and the monad it runs in.
module HSVIS.Typecheck where

import Control.Monad
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer.CPS
import Data.Foldable (toList)
import qualified Data.List.NonEmpty as NE

import Data.Bag
import HSVIS.AST
import HSVIS.Diagnostic


-- | Type-check a program.
--
-- Runs 'tcProgram' inside 'TCM' with @(fp, source)@ as the reader
-- environment and @1@ as the initial 'Int' state, which yields the checked
-- program together with the emitted diagnostics and constraints. The
-- constraints are then handed to @solveConstrs@, and the substitution it
-- returns is applied to the checked program with @substProg@. The result
-- pairs the diagnostics of both phases (generation first, solving second)
-- with the substituted program.
--
-- NOTE(review): @solveConstrs@ and @substProg@ are not defined in the part of
-- the module visible here. The signature promises @Program Type@ whereas
-- 'tcProgram' produces @Program TCType@; presumably @substProg@ performs that
-- conversion -- confirm.
typecheck :: FilePath -> String -> Program Range -> ([Diagnostic], Program Type)
typecheck fp source prog = (toList (ds1 <> ds2), substProg sub progtc)
  where
    -- Peel the monad stack from the outside in: reader, then state, then writer.
    (progtc, (ds1, cs)) =
      runWriter (evalStateT (runReaderT (runTCM (tcProgram prog)) (fp, source)) 1)
    (ds2, sub) = solveConstrs cs

data Constr
  = CEq Type Type Range
    -- ^ These types must be equal because of the expression here

-- | The type checker monad: a reader over @(file path, source text)@, an
-- 'Int' state, and a writer that accumulates a bag of diagnostics and a bag
-- of constraints.
newtype TCM a = TCM
  { runTCM ::
      ReaderT (FilePath, String)
        (StateT Int (Writer (Bag Diagnostic, Bag Constr)))
        a
  }
  deriving newtype (Functor, Applicative, Monad)

-- | Emit one diagnostic for the given range with the given message. No
-- constraint is emitted alongside it.
--
-- NOTE(review): @source !! y@ is partial (it fails when @y@ is out of bounds)
-- and indexes the /characters/ of the source 'String'. If @y@ is a line number
-- and 'Diagnostic' expects the text of that line, this should probably index
-- @lines source@ instead -- confirm against the 'Diagnostic' definition.
raise :: Range -> String -> TCM ()
raise rng@(Range (Pos y _) _) msg = do
  (fp, source) <- ask
  let diag = Diagnostic fp rng [] (source !! y) msg
  TCM (lift (lift (tell (pure diag, mempty))))

-- | Type-check every function definition of the program. The data
-- definitions are carried over unchanged.
tcProgram :: Program Range -> TCM (Program TCType)
tcProgram (Program ddefs fdefs) = do
  checkedFuns <- traverse tcFunDef fdefs
  pure (Program ddefs checkedFuns)

-- | Type-check a single function definition.
--
-- Currently this only reports a diagnostic, spanning the combined ranges of
-- all equations, when the equations do not all have the same number of
-- argument patterns. The rest of the definition is still a typed hole.
tcFunDef :: FunDef Range -> TCM (FunDef TCType)
tcFunDef (FunDef name msig eqs) = do
  unless (allEq arities) $
    raise wholeRange "Function equations have differing numbers of arguments"
  _
  where
    -- Number of argument patterns of each equation.
    arities = fmap (length . funeqPats) eqs
    -- Ranges of all equations, combined with '<>'.
    wholeRange = sconcatne (fmap funeqRange eqs)

-- | 'True' iff every element of the container equals the first one. An empty
-- container yields 'True'.
allEq :: (Eq a, Foldable t) => t a -> Bool
allEq = sameAsFirst . toList
  where
    sameAsFirst [] = True
    sameAsFirst (firstElem : rest) = all (== firstElem) rest

-- | The first field of a 'FunEq' (bound as the equation's range in the
-- original code).
funeqRange :: FunEq t -> t
funeqRange (FunEq range _ _ _) = range

-- | The third field of a 'FunEq': its list of patterns.
funeqPats :: FunEq t -> [Pattern t]
funeqPats (FunEq _ _ patterns _) = patterns

-- | Combine all elements of a non-empty list with '<>', nesting to the left:
-- @((a <> b) <> c) <> ...@.
sconcatne :: Semigroup a => NE.NonEmpty a -> a
sconcatne (firstElem NE.:| rest) = combine firstElem rest
  where
    combine acc remaining = case remaining of
      [] -> acc
      next : others -> combine (acc <> next) others