aboutsummaryrefslogtreecommitdiff
path: root/src/HSVIS/Typecheck.hs
blob: de9d7dbbe7a15ffce88b411a6066e87dcd801cfb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE EmptyDataDeriving #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TupleSections #-}
module HSVIS.Typecheck where

import Control.Monad
import Data.Bifunctor (first, second)
import Data.Foldable (toList)
import Data.Map.Strict (Map)
import Data.Monoid (First(..))
import qualified Data.Map.Strict as Map

import Data.Bag
import Data.List.NonEmpty.Util
import HSVIS.AST
import HSVIS.Parser
import HSVIS.Diagnostic
import HSVIS.Pretty


-- | Stage marker for the AST as it is produced by the checker in this module
-- (see 'tcProgram'); it has no constructors and is only used as a type index.
data StageTC

-- Annotation stored in each kind of AST node at this stage.
-- NOTE(review): the 'X' family is not defined in this file (presumably it
-- lives in HSVIS.AST) -- confirm there how each node carries the annotation.
type instance X DataDef StageTC = ()
type instance X FunDef  StageTC = CType
type instance X FunEq   StageTC = CType
type instance X Kind    StageTC = ()
type instance X Type    StageTC = CKind
type instance X Pattern StageTC = CType
type instance X RHS     StageTC = CType
type instance X Expr    StageTC = CType

-- Extra constructors available at this stage: numbered placeholder variables
-- for types and kinds. They are created by 'genUniVar' and 'genKUniVar' with
-- ids taken from 'genId'.
data instance E Type StageTC = TUniVar Int deriving (Show)
data instance E Kind StageTC = KUniVar Int deriving (Show)

-- Shorthands for the AST types instantiated at 'StageTC' ("C" prefix).
type CProgram = Program StageTC
type CDataDef = DataDef StageTC
type CFunDef  = FunDef  StageTC
type CFunEq   = FunEq   StageTC
type CKind    = Kind    StageTC
type CType    = Type    StageTC
type CPattern = Pattern StageTC
type CRHS     = RHS     StageTC
type CExpr    = Expr    StageTC

-- | Stage marker for the fully typed AST; it has no constructors and is only
-- used as a type index.
data StageTyped

-- Annotation stored in each kind of AST node at this stage.
-- NOTE(review): the 'X' family is not defined in this file (presumably it
-- lives in HSVIS.AST) -- confirm there how each node carries the annotation.
type instance X DataDef StageTyped = TType
type instance X FunDef  StageTyped = TType
type instance X FunEq   StageTyped = TType
type instance X Kind    StageTyped = ()
type instance X Type    StageTyped = TKind
type instance X Pattern StageTyped = TType
type instance X RHS     StageTyped = TType
type instance X Expr    StageTyped = TType

-- No extra constructors at this stage: both extension types are empty, so
-- the placeholder variables of 'StageTC' cannot be represented here.
data instance E Type StageTyped deriving (Show)
data instance E Kind StageTyped deriving (Show)

-- Shorthands for the AST types instantiated at 'StageTyped' ("T" prefix).
type TProgram = Program StageTyped
type TDataDef = DataDef StageTyped
type TFunDef  = FunDef  StageTyped
type TFunEq   = FunEq   StageTyped
type TKind    = Kind    StageTyped
type TType    = Type    StageTyped
type TPattern = Pattern StageTyped
type TRHS     = RHS     StageTyped
type TExpr    = Expr    StageTyped


-- | Typecheck a parsed program.
--
-- Runs 'tcProgram' inside 'TCM' (file path and contents as context, id counter
-- starting at 1, both environments empty), passes the constraints it emitted
-- to @solveConstrs@, and applies the resulting substitution to the checked
-- program with @substProg@. The returned list holds the diagnostics of the
-- checking run followed by those of the solver.
--
-- NOTE(review): @solveConstrs@ and @substProg@ are not defined in the part of
-- the file visible here, and the result type @Program TType@ may have been
-- meant to be 'TProgram' -- confirm.
typecheck :: FilePath -> String -> PProgram -> ([Diagnostic], Program TType)
typecheck fp source prog =
  (toList (checkDiags <> solveDiags), substProg subst checked)
  where
    (checkDiags, constrs, _, _, checked) =
      runTCM (tcProgram prog) (fp, source) 1 (Env mempty mempty)
    (solveDiags, subst) = solveConstrs constrs

-- | Constraints emitted while checking (via 'emit'): 'CEq' relates two types,
-- 'CEqK' relates two kinds.
data Constr
  -- Equality constraints: "left" must be equal to "right" because of the thing
  -- at the given range. "left" is the expected thing; "right" is the observed
  -- thing.
  = CEq CType CType Range
  | CEqK CKind CKind Range
  deriving (Show)

-- | Checker environment. The first map gives the kind of each type-level name
-- (read by 'getKind'); the second gives the type of each value-level name
-- (read by 'getType').
data Env = Env (Map Name CKind) (Map Name CType)
  deriving (Show)

-- | The typechecker monad, written out as a plain function: it reads the file
-- context, threads an id counter and an 'Env' as state, and accumulates
-- diagnostics and constraints as output next to the result.
newtype TCM a = TCM {
  runTCM :: (FilePath, String)  -- ^ reader context: file and file contents
         -> Int  -- ^ state: next id to generate
         -> Env  -- ^ state: type and value environment
         -> (Bag Diagnostic  -- ^ writer: diagnostics
            ,Bag Constr      -- ^ writer: constraints
            -- updated id counter, updated environment, and the result
            ,Int, Env, a)
  }

-- | Transforms only the result; diagnostics, constraints, id counter and
-- environment are passed through unchanged.
instance Functor TCM where
  fmap f (TCM g) = TCM run
    where
      run ctx i env = (ds, cs, i', env', f x)
        where
          (ds, cs, i', env', x) = g ctx i env

-- | 'pure' produces no output and leaves the counter and environment as they
-- are; '<*>' is derived from the 'Monad' instance.
instance Applicative TCM where
  pure x = TCM inert
    where
      inert _ i env = (mempty, mempty, i, env, x)
  (<*>) = ap

-- | Sequencing: the second computation starts from the counter and
-- environment left by the first, and the outputs of both are appended in
-- order.
instance Monad TCM where
  m >>= k = TCM step
    where
      step ctx i0 env0 = (dsA <> dsB, csA <> csB, i2, env2, y)
        where
          (dsA, csA, i1, env1, x) = runTCM m ctx i0 env0
          (dsB, csB, i2, env2, y) = runTCM (k x) ctx i1 env1

-- | Report a diagnostic with the given message at the given range.
--
-- The diagnostic carries the file path, the range, and the text of the source
-- line on which the range starts (the start line number is used as a
-- zero-based index into @lines source@). If that index does not refer to an
-- existing line -- for example when the range starts just past the end of the
-- file -- an empty line is attached instead of crashing the checker.
raise :: Range -> String -> TCM ()
raise rng@(Range (Pos y _) _) msg = TCM $ \(fp, source) i env ->
  (pure (Diagnostic fp rng [] (sourceLine source) msg), mempty, i, env, ())
  where
    -- Total replacement for @lines src !! y@.
    sourceLine src
      | y < 0 = ""
      | otherwise = case drop y (lines src) of
          l : _ -> l
          [] -> ""

-- | Add a single constraint to the constraint output; nothing else changes.
emit :: Constr -> TCM ()
emit c = TCM go
  where
    go _ i env = (mempty, pure c, i, env, ())

-- | Run a computation and split off some of the constraints it emitted.
--
-- The constraints are divided with 'bagPartition' using the given function;
-- the first part is returned to the caller together with the computation's
-- result, the second part stays in the constraint output. Diagnostics, the
-- id counter and the environment are those of the inner computation.
collectConstraints :: (Constr -> Maybe b) -> TCM a -> TCM (Bag b, a)
collectConstraints predicate action = TCM run
  where
    run ctx i env = (ds, rest, i', env', (picked, x))
      where
        (ds, cs, i', env', x) = runTCM action ctx i env
        (picked, rest) = bagPartition predicate cs

-- | Read the current environment without changing any state.
getFullEnv :: TCM Env
getFullEnv = TCM (\_ i env -> (mempty, mempty, i, env, env))

-- | Replace the current environment wholesale; the id counter is kept.
putFullEnv :: Env -> TCM ()
putFullEnv env = TCM (\_ i _ -> (mempty, mempty, i, env, ()))

-- | Produce a fresh id.
--
-- Returns the current value of the id counter and advances the counter by
-- one, so that consecutive calls yield distinct ids. (Without the increment
-- every variable made by 'genUniVar' / 'genKUniVar' would share one id.)
genId :: TCM Int
genId = TCM $ \_ i env -> (mempty, mempty, i + 1, env, i)

-- | Look up a name in the kind map of the environment; 'Nothing' if absent.
getKind :: Name -> TCM (Maybe CKind)
getKind name = getFullEnv >>= \(Env tenv _) -> return (Map.lookup name tenv)

-- | Look up a name in the type map of the environment; 'Nothing' if absent.
getType :: Name -> TCM (Maybe CType)
getType name = getFullEnv >>= \(Env _ venv) -> return (Map.lookup name venv)

-- | Apply a function to the kind map of the environment, leaving the type
-- map untouched.
modifyTEnv :: (Map Name CKind -> Map Name CKind) -> TCM ()
modifyTEnv f =
  getFullEnv >>= \(Env tenv venv) ->
    putFullEnv (Env (f tenv) venv)

-- | Apply a function to the type map of the environment, leaving the kind
-- map untouched.
modifyVEnv :: (Map Name CType -> Map Name CType) -> TCM ()
modifyVEnv f =
  getFullEnv >>= \(Env tenv venv) ->
    putFullEnv (Env tenv (f venv))

-- | Run a computation and afterwards reset the kind map to what it was
-- before, discarding any kind bindings the computation added or changed.
-- The type map is not restored.
scopeTEnv :: TCM a -> TCM a
scopeTEnv m = do
  Env saved _ <- getFullEnv
  m <* modifyTEnv (const saved)

-- | Run a computation and afterwards reset the type map to what it was
-- before, discarding any value bindings the computation added or changed.
-- The kind map is not restored.
scopeVEnv :: TCM a -> TCM a
scopeVEnv m = do
  Env _ saved <- getFullEnv
  m <* modifyVEnv (const saved)

-- | Make a kind variable ('KUniVar') numbered with an id from 'genId'.
genKUniVar :: TCM CKind
genKUniVar = do
  i <- genId
  return (KExt () (KUniVar i))

-- | Make a type variable ('TUniVar') of the given kind, numbered with an id
-- from 'genId'.
genUniVar :: CKind -> TCM CType
genUniVar k = do
  i <- genId
  return (TExt k (TUniVar i))

-- | Like 'getKind', but never fails: if the name is unknown, a
-- "not in scope" diagnostic is raised at the given range and a new kind
-- variable is returned in its place.
getKind' :: Range -> Name -> TCM CKind
getKind' rng name = getKind name >>= maybe notInScope return
  where
    notInScope = do
      raise rng ("Type not in scope: " ++ pretty name)
      genKUniVar

-- | Like 'getType', but never fails: if the name is unknown, a
-- "not in scope" diagnostic is raised at the given range and a new type
-- variable of kind @KType ()@ is returned in its place.
getType' :: Range -> Name -> TCM CType
getType' rng name = getType name >>= maybe notInScope return
  where
    notInScope = do
      raise rng ("Variable not in scope: " ++ pretty name)
      genUniVar (KType ())

-- | Check a whole program.
--
-- Data definitions go first: every definition is registered with
-- 'prepareDataDef' before any of them is checked with 'tcDataDef'. The kind
-- constraints emitted during that phase are taken out of the constraint
-- output and handed to 'solveKindVars'. Function definitions are checked
-- afterwards, in order.
tcProgram :: PProgram -> TCM CProgram
tcProgram (Program ddefs fdefs) = do
  (kconstrs, ddefs') <- collectConstraints isCEqK checkDataDefs
  solveKindVars kconstrs
  Program ddefs' <$> mapM tcFunDef fdefs
  where
    checkDataDefs = mapM_ prepareDataDef ddefs >> mapM tcDataDef ddefs

-- | Register a data type's name in the kind map before its constructors are
-- checked. The recorded kind takes one fresh kind variable per type
-- parameter and ends in @KType ()@.
prepareDataDef :: PDataDef -> TCM ()
prepareDataDef (DataDef _ name params _) = do
  parkinds <- replicateM (length params) genKUniVar
  modifyTEnv (Map.insert name (foldr (KFun ()) (KType ()) parkinds))

-- Assumes that the kind of the name itself has already been registered with
-- the correct arity (this is done by prepareDataDef).
-- | Check one data definition.
--
-- The kind registered for the data type's name is split into parameter kinds
-- and a result; a mismatch with the number of declared parameters, or a
-- result that is not a plain type kind, is treated as an internal error.
-- The constructor field types are then kind-checked with the parameters
-- bound to their kinds; those bindings are dropped again afterwards.
tcDataDef :: PDataDef -> TCM CDataDef
tcDataDef (DataDef rng name params cons) = do
  kd <- getKind' rng name
  let (pkinds, kret) = splitKind kd
      pnames = map snd params

  -- sanity checking; would be nicer to store these in prepareDataDef already
  unless (length pkinds == length params) $
    error "tcDataDef: Invalid param kind list length"
  case kret of
    Right () -> return ()
    Left _ -> error "tcDataDef: Invalid ret kind"

  cons' <- scopeTEnv $ do
    -- Left-biased union: the parameters shadow existing names.
    modifyTEnv (Map.union (Map.fromList (zip pnames pkinds)))
    forM cons $ \(cname, fields) -> do
      fields' <- mapM kcType fields
      return (cname, fields')
  return (DataDef () name (zip pkinds pnames) cons')

-- | Kind-check a parsed type, annotating every node with its kind.
--
-- No kind is checked on the spot: each requirement is recorded as a 'CEqK'
-- constraint. Following the convention documented on 'Constr', the kind
-- demanded by the context is passed first (expected) and the kind found for
-- the subterm second (observed).
--
-- * Application: the head must have the kind of a function from the argument
--   kinds to a fresh result kind, which becomes the kind of the application.
-- * Tuple, list and function types: every component must have kind
--   @KType ()@, and so does the whole type. These constraints are attached
--   to the range of the component concerned.
-- * Constructors and variables: annotated with the kind found by 'getKind''
--   (which reports unknown names itself).
kcType :: PType -> TCM CType
kcType = \case
  TApp rng t ts -> do
    t' <- kcType t
    ts' <- mapM kcType ts
    retk <- genKUniVar
    -- The kind the head needs for this application to be well-kinded.
    let expected = foldr (KFun ()) retk (map extOf ts')
    emit $ CEqK expected (extOf t') rng
    return (TApp retk t' ts')

  TTup _ ts -> do
    ts' <- mapM kcType ts
    forM_ (zip (map extOf ts) ts') $ \(trng, ct) ->
      emit $ CEqK (KType ()) (extOf ct) trng
    return (TTup (KType ()) ts')

  TList _ t -> do
    t' <- kcType t
    emit $ CEqK (KType ()) (extOf t') (extOf t)
    return (TList (KType ()) t')

  TFun _ t1 t2 -> do
    t1' <- kcType t1
    t2' <- kcType t2
    emit $ CEqK (KType ()) (extOf t1') (extOf t1)
    emit $ CEqK (KType ()) (extOf t2') (extOf t2)
    return (TFun (KType ()) t1' t2')

  TCon rng n -> TCon <$> getKind' rng n <*> pure n

  TVar rng n -> TVar <$> getKind' rng n <*> pure n

-- | Check one function definition.
--
-- Raises a diagnostic over the combined range of the equations when they do
-- not all have the same number of patterns. The function's type is the
-- kind-checked signature if one was written, otherwise a fresh type
-- variable; every equation is checked against it with 'tcFunEq'.
tcFunDef :: PFunDef -> TCM CFunDef
tcFunDef (FunDef _ name msig eqs) = do
  let arities = fmap (length . funeqPats) eqs
  unless (allEq arities) $
    raise (sconcatne (fmap extOf eqs)) "Function equations have differing numbers of arguments"

  typ <- case msig of
    TypeSigExt NoTypeSig -> genUniVar (KType ())
    TypeSig sig -> kcType sig

  eqs' <- mapM (tcFunEq typ) eqs

  return (FunDef typ name (TypeSig typ) eqs')

-- | Check a single function equation. 'tcFunDef' calls this with the type
-- chosen for the function as the first argument.
--
-- TODO: not implemented yet. The body is a typed hole, so the module does not
-- compile until it is filled in.
tcFunEq :: CType -> PFunEq -> TCM CFunEq
tcFunEq = _

-- | Process the kind equality constraints collected from the data
-- definitions. Each entry is (expected kind, observed kind, range).
--
-- For every constraint the two kinds are compared structurally with
-- @reduce@; if a mismatch is found, a "Kind mismatch" diagnostic is raised at
-- the constraint's range, naming the first pair of sub-kinds that differ.
--
-- TODO: unfinished. The last statement of the per-constraint block is a typed
-- hole, so the variable bindings gathered in @collected@ are not used yet and
-- the module does not compile until the hole is filled in.
solveKindVars :: Bag (CKind, CKind, Range) -> TCM ()
solveKindVars =
  mapM_ $ \(a, b, rng) -> do
    -- subst: variable bindings implied by this constraint;
    -- merr: the first mismatching pair of sub-kinds, if any.
    let (subst, First merr) = reduce a b
    forM_ merr $ \(erra, errb) ->
      raise rng $
        "Kind mismatch:\n\
        \- Expected: " ++ pretty a ++ "\n\
        \- Observed: " ++ pretty b ++ "\n\
        \because '" ++ pretty erra ++ "' and '" ++ pretty errb ++ "' don't match"
    -- Group the bindings by variable id: each id maps to all kinds this
    -- constraint equates it with.
    let collected :: [(Int, Bag CKind)]
        collected = Map.assocs $ Map.fromListWith (<>) (fmap (second pure) (toList subst))
    _
  where
    -- Structural comparison of two kinds. Equal shapes recurse into their
    -- components; a kind variable on either side yields a binding of that
    -- variable to the other kind; any other combination is recorded as a
    -- mismatch. Results are combined with (<>), so of several mismatches the
    -- 'First' component keeps the first one.
    reduce :: CKind -> CKind -> (Bag (Int, CKind), First (CKind, CKind))
    reduce (KType ()) (KType ()) = mempty
    reduce (KFun () a b) (KFun () c d) = reduce a c <> reduce b d
    reduce (KExt () (KUniVar i)) k = (pure (i, k), mempty)
    reduce k (KExt () (KUniVar i)) = (pure (i, k), mempty)
    reduce k1 k2 = (mempty, pure (k1, k2))

-- | Whether all elements of a container are equal to its first element.
-- An empty container counts as all-equal.
allEq :: (Eq a, Foldable t) => t a -> Bool
allEq l = go (toList l)
  where
    go [] = True
    go (x : xs) = all (== x) xs

-- | The argument patterns of a function equation (its third field).
funeqPats :: FunEq t -> [Pattern t]
funeqPats eq = case eq of
  FunEq _ _ pats _ -> pats

-- | Split a kind into its argument kinds and its final result.
--
-- Follows the right-hand side of nested 'KFun's, collecting the left-hand
-- sides in order. The result is @Right@ the annotation of the closing 'KType',
-- or @Left@ the extension value if the chain ends in a 'KExt'.
splitKind :: Kind s -> ([Kind s], Either (E Kind s) (X Kind s))
splitKind = \case
  KType x -> ([], Right x)
  KExt _ e -> ([], Left e)
  KFun _ k1 k2 ->
    let (args, res) = splitKind k2
    in (k1 : args, res)

-- | Select kind equality constraints: the fields of a 'CEqK' as a triple,
-- 'Nothing' for any other constraint.
isCEqK :: Constr -> Maybe (CKind, CKind, Range)
isCEqK = \case
  CEqK k1 k2 rng -> Just (k1, k2, rng)
  _ -> Nothing