diff options
Diffstat (limited to 'src/HSVIS/Typecheck.hs')
-rw-r--r-- | src/HSVIS/Typecheck.hs | 63 |
1 files changed, 63 insertions, 0 deletions
diff --git a/src/HSVIS/Typecheck.hs b/src/HSVIS/Typecheck.hs new file mode 100644 index 0000000..b1ffbb9 --- /dev/null +++ b/src/HSVIS/Typecheck.hs @@ -0,0 +1,63 @@ +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +module HSVIS.Typecheck where + +import Control.Monad +import Control.Monad.Reader +import Control.Monad.State +import Control.Monad.Writer.CPS +import Data.Foldable (toList) +import qualified Data.List.NonEmpty as NE + +import Data.Bag +import HSVIS.AST +import HSVIS.Diagnostic + + +typecheck :: FilePath -> String -> Program Range -> ([Diagnostic], Program Type) +typecheck fp source prog = + let (progtc, (ds1, cs)) = + runWriter + . flip evalStateT 1 + . flip runReaderT (fp, source) + . runTCM + $ tcProgram prog + (ds2, sub) = solveConstrs cs + in (toList (ds1 <> ds2), substProg sub progtc) + +data Constr = + CEq Type Type Range -- ^ These types must be equal because of the expression here + +newtype TCM a = TCM { runTCM :: ReaderT (FilePath, String) (StateT Int (Writer (Bag Diagnostic, Bag Constr))) a } + deriving newtype (Functor, Applicative, Monad) + +raise :: Range -> String -> TCM () +raise rng@(Range (Pos y _) _) msg = do + (fp, source) <- ask + TCM $ lift $ lift $ tell (pure (Diagnostic fp rng [] (source !! y) msg), mempty) + +tcProgram :: Program Range -> TCM (Program TCType) +tcProgram (Program ddefs fdefs) = Program ddefs <$> traverse tcFunDef fdefs + +tcFunDef :: FunDef Range -> TCM (FunDef TCType) +tcFunDef (FunDef name msig eqs) = do + when (not $ allEq (fmap (length . funeqPats) eqs)) $ + raise (sconcatne (fmap funeqRange eqs)) "Function equations have differing numbers of arguments" + + _ + +allEq :: (Eq a, Foldable t) => t a -> Bool +allEq l = case toList l of + [] -> True + x : xs -> all (== x) xs + +funeqRange :: FunEq t -> t +funeqRange (FunEq rng _ _ _) = rng + +funeqPats :: FunEq t -> [Pattern t] +funeqPats (FunEq _ _ pats _) = pats + +sconcatne :: Semigroup a => NE.NonEmpty a -> a +sconcatne = \(x NE.:| xs) -> go x xs + where go a [] = a + go a (x : xs) = go (a <> x) xs |