aboutsummaryrefslogtreecommitdiff
path: root/src/HSVIS/Typecheck.hs
blob: b1ffbb986378178fa254255ba409e0c176e817bc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module HSVIS.Typecheck where

import Control.Monad
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer.CPS
import Data.Foldable (toList)
import qualified Data.List.NonEmpty as NE

import Data.Bag
import HSVIS.AST
import HSVIS.Diagnostic


-- | Typecheck a whole program.
--
-- Runs 'tcProgram' inside 'TCM' with the file path and source text as the
-- reader environment and an 'Int' state initialised to 1 (presumably a
-- fresh-name supply — TODO confirm). This collects diagnostics and
-- constraints.
--
-- The constraints are then passed to 'solveConstrs', and the resulting
-- substitution is applied to the checked program.
--
-- Returns the diagnostics from checking followed by those from solving,
-- together with the substituted program.
typecheck :: FilePath -> String -> Program Range -> ([Diagnostic], Program Type)
typecheck fp source prog =
  (toList (checkDiags <> solveDiags), substProg substitution checkedProg)
  where
    -- Peel the monad transformer stack from the outside in.
    tcAction = runTCM (tcProgram prog)
    stateAction = runReaderT tcAction (fp, source)
    (checkedProg, (checkDiags, constraints)) = runWriter (evalStateT stateAction 1)
    (solveDiags, substitution) = solveConstrs constraints

-- | A constraint emitted while typechecking, to be solved afterwards by
-- 'solveConstrs'.
data Constr
  = CEq Type Type Range  -- ^ The two types must be equal because of the expression at this range

-- | The typechecking monad. It is a newtype over:
--
--   * a reader holding the file path and the source text;
--   * an 'Int' state (presumably a fresh-name supply — TODO confirm);
--   * a writer accumulating diagnostics and 'Constr' constraints.
--
-- Only 'Functor', 'Applicative' and 'Monad' are derived. The underlying
-- layers must be reached by unwrapping into the 'TCM' constructor.
newtype TCM a = TCM
  { runTCM :: ReaderT (FilePath, String) (StateT Int (Writer (Bag Diagnostic, Bag Constr))) a
  }
  deriving newtype (Functor, Applicative, Monad)

-- | Record a diagnostic with the given message at the given range.
--
-- The diagnostic is appended to the diagnostics half of the 'TCM' writer
-- output. Checking does not stop.
--
-- The accompanying source line is taken from the source text in the
-- reader environment, using the start line of the range. An out-of-range
-- line yields an empty string instead of crashing.
raise :: Range -> String -> TCM ()
raise rng@(Range (Pos y _) _) msg = do
  -- 'TCM' does not derive 'MonadReader', so query the wrapped 'ReaderT'
  -- directly, just as 'tell' below is lifted into the wrapped stack.
  (fp, source) <- TCM ask
  -- NOTE(review): assumes 'y' is a 0-based line index — confirm against Pos.
  let srcline = case drop y (lines source) of
                  l : _ -> l
                  [] -> ""
  TCM $ lift $ lift $ tell (pure (Diagnostic fp rng [] srcline msg), mempty)

-- | Typecheck a program.
--
-- Each function definition is checked with 'tcFunDef' in traversal order.
-- The data definitions are passed through unchanged.
tcProgram :: Program Range -> TCM (Program TCType)
tcProgram (Program datadefs fundefs) = do
  checkedFuns <- mapM tcFunDef fundefs
  return (Program datadefs checkedFuns)

-- | Typecheck a single function definition.
--
-- Currently this only checks that all equations take the same number of
-- argument patterns. If they do not, a diagnostic is raised whose range is
-- the '<>'-combination of all equation ranges.
--
-- The rest of the definition is still unimplemented (typed hole).
tcFunDef :: FunDef Range -> TCM (FunDef TCType)
tcFunDef (FunDef name msig eqs) = do
  unless (allEq (fmap (length . funeqPats) eqs)) $
    raise (sconcatne (fmap funeqRange eqs)) "Function equations have differing numbers of arguments"

  _

-- | True when every element of the container equals the first one.
-- Empty and singleton containers trivially qualify.
allEq :: (Eq a, Foldable t) => t a -> Bool
allEq = sameAsHead . toList
  where
    sameAsHead [] = True
    sameAsHead (headElem : others) = all (== headElem) others

-- | The annotation (first field) of a function equation. During
-- typechecking this is its source 'Range'.
funeqRange :: FunEq t -> t
funeqRange eq = case eq of
  FunEq ann _ _ _ -> ann

-- | The argument patterns (third field) of a function equation.
funeqPats :: FunEq t -> [Pattern t]
funeqPats eq = case eq of
  FunEq _ _ patterns _ -> patterns

-- | Combine all elements of a non-empty list with '<>'.
--
-- The fold is left-associated: @((x0 <> x1) <> x2) <> ...@. It is total
-- because the input is never empty.
sconcatne :: Semigroup a => NE.NonEmpty a -> a
sconcatne = foldl1 (<>)