aboutsummaryrefslogtreecommitdiff
path: root/src/HSVIS/Typecheck.hs
blob: f62b09731a4befef2c0b6c32549f00f63aa77a19 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE EmptyDataDeriving #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE GADTs #-}
module HSVIS.Typecheck (
  StageTyped,
  typecheck,
  -- * Typed AST synonyms
  -- TProgram, TDataDef, TFunDef, TFunEq, TKind, TType, TPattern, TRHS, TExpr,
) where

import Control.Monad
import Data.Bifunctor (first, second)
import Data.Foldable (toList)
import Data.List (find, inits)
import Data.Map.Strict (Map)
import Data.Maybe (fromMaybe)
import Data.Monoid (Ap(..))
import qualified Data.Map.Strict as Map
import Data.Tuple (swap)
import Data.Set (Set)
import qualified Data.Set as Set

import Debug.Trace

import Data.Bag
import Data.List.NonEmpty.Util
import HSVIS.AST
import HSVIS.Parser
import HSVIS.Diagnostic
import HSVIS.Pretty
import HSVIS.Typecheck.Solve


-- | Stage index for the AST while type checking is in progress. Kinds and
-- types at this stage may still contain unification variables ('KUniVar',
-- 'TUniVar'), which are only resolved after constraint solving.
data StageTC

-- Per-node annotations at this stage: data definitions and types are
-- annotated with their kind; function definitions, patterns, right-hand sides
-- and expressions with their type; equations and kinds carry nothing.
type instance X DataDef StageTC = CKind
type instance X FunDef  StageTC = CType
type instance X FunEq   StageTC = ()
type instance X Kind    StageTC = ()
type instance X Type    StageTC = CKind
type instance X Pattern StageTC = CType
type instance X RHS     StageTC = CType
type instance X Expr    StageTC = CType

-- Extension constructors: type and kind unification variables, identified by
-- an Int produced by 'genId'. The TypeSig extension is uninhabited here.
data instance E Type StageTC = TUniVar Int deriving (Show)
data instance E Kind StageTC = KUniVar Int deriving (Show, Eq, Ord)
data instance E TypeSig StageTC deriving (Show)

-- Shorthands for the AST node types at this stage.
type CProgram = Program StageTC
type CDataDef = DataDef StageTC
type CFunDef  = FunDef  StageTC
type CFunEq   = FunEq   StageTC
type CKind    = Kind    StageTC
type CType    = Type    StageTC
type CPattern = Pattern StageTC
type CRHS     = RHS     StageTC
type CExpr    = Expr    StageTC

-- | Stage index for the fully typed AST returned by 'typecheck'. All
-- extension types are empty, so no unification variables can remain.
data StageTyped

-- Per-node annotations of the typed AST.
-- NOTE(review): at 'StageTC' a data definition is annotated with its kind,
-- but here with a 'TType'; possibly 'TKind' was intended — confirm.
type instance X DataDef StageTyped = TType
type instance X FunDef  StageTyped = TType
type instance X FunEq   StageTyped = ()
type instance X Kind    StageTyped = ()
type instance X Type    StageTyped = TKind
type instance X Pattern StageTyped = TType
type instance X RHS     StageTyped = TType
type instance X Expr    StageTyped = TType

-- No extension constructors: these are uninhabited.
data instance E Type StageTyped deriving (Show)
data instance E Kind StageTyped deriving (Show)
data instance E TypeSig StageTyped deriving (Show)

-- Shorthands for the AST node types of the typed stage.
type TProgram = Program StageTyped
type TDataDef = DataDef StageTyped
type TFunDef  = FunDef  StageTyped
type TFunEq   = FunEq   StageTyped
type TKind    = Kind    StageTyped
type TType    = Type    StageTyped
type TPattern = Pattern StageTyped
type TRHS     = RHS     StageTyped
type TExpr    = Expr    StageTyped

-- | A type unification variable prints as @?t@ followed by its id.
instance Pretty (E Type StageTC) where
  prettysPrec _ (TUniVar n) = showString "?t" . shows n

-- | A kind unification variable prints as @?k@ followed by its id.
instance Pretty (E Kind StageTC) where
  prettysPrec _ (KUniVar n) = showString "?k" . shows n


-- | Type-check a parsed program read from the given file with the given
-- contents. Returns all diagnostics (from checking and from solving the
-- residual constraints) together with the typed program.
--
-- Checking starts with fresh id 1 and an empty environment.
-- NOTE(review): 'solveConstrs' and 'doneProg' are still stubs, so forcing the
-- full diagnostics list or the program currently throws.
typecheck :: FilePath -> String -> PProgram -> ([Diagnostic], TProgram)
typecheck fp source prog = (toList (tcDiags <> solveDiags), doneProg kindSub typeSub checkedProg)
  where
    initialEnv = Env mempty mempty mempty
    (tcDiags, constrs, _, _, checkedProg) = runTCM (tcProgram prog) (fp, source) 1 initialEnv
    (solveDiags, kindSub, typeSub) = solveConstrs constrs

-- | A constraint emitted during checking, to be solved afterwards.
data Constr
  -- Equality constraints: "left" must be equal to "right" because of the thing
  -- at the given range. "left" is the expected thing; "right" is the observed
  -- thing.
  = CEq CType CType Range   -- ^ type equality
  | CEqK CKind CKind Range  -- ^ kind equality
  deriving (Show)

-- | The typing environment, threaded through 'TCM' as state.
-- NOTE(review): in this module only pattern variables are inserted into the
-- value map; constructors are registered solely as inverse constructors
-- (see 'tcProgram').
data Env = Env (Map Name CKind)   -- ^ types in scope (including variables)
               (Map Name CType)   -- ^ values in scope (constructors and variables)
               (Map Name InvCon)  -- ^ patterns in scope (inverse constructors)
  deriving (Show)

-- | An inverse constructor: the information needed to check a constructor
-- pattern. One is built per constructor by 'generateInvCons'; 'tcPattern'
-- instantiates the quantified variables with fresh unification variables.
data InvCon = InvCon (Map Name CKind)  -- ^ universally quantified type variables
                     CType    -- ^ input type of the inverse constructor (result of the constructor)
                     [CType]  -- ^ field types
  deriving (Show)

-- | The type-checking monad, written out by hand as a function. It combines:
--
-- * a reader of the file path and contents (used by 'raise'),
-- * state consisting of the next fresh id and the environment, and
-- * two writer outputs: diagnostics and constraints.
newtype TCM a = TCM {
  runTCM :: (FilePath, String)  -- ^ reader context: file and file contents
         -> Int  -- ^ state: next id to generate
         -> Env  -- ^ state: type and value environment
         -> (Bag Diagnostic  -- ^ writer: diagnostics
            ,Bag Constr      -- ^ writer: constraints
            ,Int, Env, a)    -- updated state (next id, environment) and result
  }

-- | Map over the result; diagnostics, constraints and state pass through.
instance Functor TCM where
  fmap f m = TCM $ \ctx nextId env ->
    let (diags, constrs, nextId', env', x) = runTCM m ctx nextId env
    in (diags, constrs, nextId', env', f x)

-- | 'pure' emits nothing and leaves the state unchanged; '<*>' runs the
-- function computation first, then the argument computation.
instance Applicative TCM where
  pure x = TCM $ \_ nextId env -> (mempty, mempty, nextId, env, x)
  mf <*> mx = do
    f <- mf
    x <- mx
    return (f x)

-- | Sequencing threads the state (fresh-id counter and environment) from the
-- first computation into the second, and concatenates both writer outputs.
-- Results are bound with lazy 'let' patterns, so nothing is forced until it
-- is demanded.
instance Monad TCM where
  m >>= k = TCM $ \ctx id0 env0 ->
    let (diags1, constrs1, id1, env1, x) = runTCM m ctx id0 env0
        (diags2, constrs2, id2, env2, y) = runTCM (k x) ctx id1 env1
    in (diags1 <> diags2, constrs1 <> constrs2, id2, env2, y)

-- | Report a diagnostic of the given severity at the given range.
--
-- The source line shown with the diagnostic is taken from the file contents
-- in the reader context, using the range's start line as a 0-based index
-- into 'lines'. If that line does not exist (e.g. a position at the very end
-- of the input), an empty line is shown instead of crashing on '!!'.
raise :: Severity -> Range -> String -> TCM ()
raise sev rng@(Range (Pos y _) _) msg = TCM $ \(fp, source) i env ->
  let srcline = case drop y (lines source) of
                  l : _ -> l
                  []    -> ""
  in (pure (Diagnostic sev fp rng [] srcline msg), mempty, i, env, ())

-- | Record a constraint in the constraint writer output.
emit :: Constr -> TCM ()
emit constr = TCM (\_ nextId env -> (mempty, pure constr, nextId, env, ()))

-- | Run a computation and intercept the constraints it emits: those for which
-- the selector returns 'Just' are removed from the writer output and returned
-- alongside the result; all others are passed on unchanged.
collectConstraints :: (Constr -> Maybe b) -> TCM a -> TCM (Bag b, a)
collectConstraints select action = TCM $ \ctx nextId env ->
  let (diags, constrs, nextId', env', result) = runTCM action ctx nextId env
      (picked, rest) = bagPartition select constrs
  in (diags, rest, nextId', env', (picked, result))

-- | Read the whole environment.
getFullEnv :: TCM Env
getFullEnv = TCM (\_ nextId env -> (mempty, mempty, nextId, env, env))

-- | Replace the whole environment.
putFullEnv :: Env -> TCM ()
putFullEnv newEnv = TCM (\_ nextId _ -> (mempty, mempty, nextId, newEnv, ()))

-- | Return a fresh id and advance the counter by one.
genId :: TCM Int
genId = TCM (\_ nextId env -> (mempty, mempty, nextId + 1, env, nextId))

-- | Look up the kind of a type name (or type variable) in scope.
getKind :: Name -> TCM (Maybe CKind)
getKind name = (\(Env tenv _ _) -> Map.lookup name tenv) <$> getFullEnv

-- | Look up the type of a value name in scope.
getType :: Name -> TCM (Maybe CType)
getType name = (\(Env _ venv _) -> Map.lookup name venv) <$> getFullEnv

-- | Look up the inverse constructor registered for a constructor name.
getInvCon :: Name -> TCM (Maybe InvCon)
getInvCon name = (\(Env _ _ icenv) -> Map.lookup name icenv) <$> getFullEnv

-- | Apply a function to the type (kind) environment.
modifyTEnv :: (Map Name CKind -> Map Name CKind) -> TCM ()
modifyTEnv f =
  getFullEnv >>= \(Env tys vals invs) -> putFullEnv (Env (f tys) vals invs)

-- | Apply a function to the value environment.
modifyVEnv :: (Map Name CType -> Map Name CType) -> TCM ()
modifyVEnv f =
  getFullEnv >>= \(Env tys vals invs) -> putFullEnv (Env tys (f vals) invs)

-- | Apply a function to the inverse-constructor environment.
modifyICEnv :: (Map Name InvCon -> Map Name InvCon) -> TCM ()
modifyICEnv f =
  getFullEnv >>= \(Env tys vals invs) -> putFullEnv (Env tys vals (f invs))

-- | Run a computation, then restore the type environment to what it was
-- before; every other part of the state keeps the computation's changes.
scopeTEnv :: TCM a -> TCM a
scopeTEnv action =
  getFullEnv >>= \(Env saved _ _) ->
    action >>= \result ->
      modifyTEnv (const saved) >> return result

-- | Run a computation, then restore the value environment to what it was
-- before; every other part of the state keeps the computation's changes.
scopeVEnv :: TCM a -> TCM a
scopeVEnv action =
  getFullEnv >>= \(Env _ saved _) ->
    action >>= \result ->
      modifyVEnv (const saved) >> return result

-- | A fresh kind unification variable.
genKUniVar :: TCM CKind
genKUniVar = fmap (KExt () . KUniVar) genId

-- | A fresh type unification variable annotated with the given kind.
genUniVar :: CKind -> TCM CType
genUniVar kind = fmap (TExt kind . TUniVar) genId

-- | Like 'getKind', but a name that is not in scope is reported as an error
-- and a fresh kind unification variable is returned so checking can go on.
getKind' :: Range -> Name -> TCM CKind
getKind' rng name = getKind name >>= maybe notInScope return
  where
    notInScope =
      raise SError rng ("Type not in scope: " ++ pretty name) >> genKUniVar

-- | Like 'getType', but a name that is not in scope is reported as an error
-- and a fresh type unification variable of kind Type is returned instead.
getType' :: Range -> Name -> TCM CType
getType' rng name = getType name >>= maybe notInScope return
  where
    notInScope =
      raise SError rng ("Variable not in scope: " ++ pretty name) >> genUniVar (KType ())

-- | Type-check a whole program.
--
-- Data definitions are kind-checked first, as a group. Every data type is
-- brought into scope with a fresh kind of the right arity
-- ('prepareDataDef'), so definitions may refer to each other. The resulting
-- kind constraints are taken out of the writer output and solved right away.
-- Next, the inverse constructors of all data types are registered for
-- pattern matching. Finally, the function definitions are checked; their
-- constraints stay in the writer output.
tcProgram :: PProgram -> TCM CProgram
tcProgram (Program ddefs1 fdefs1) = do
  (kconstrs, ddefs2) <- collectConstraints isCEqK $ do
    ks <- mapM prepareDataDef ddefs1
    zipWithM kcDataDef ks ddefs1

  solved <- solveKindVars kconstrs
  -- A kind variable that occurs in no constraint at all (e.g. the parameter
  -- of a phantom type like 'data P a = P') is absent from 'solved' and would
  -- survive substitution. Default such leftovers to Type as well. They all
  -- appear in the kinds of the data types themselves.
  let leftover = foldMap (\(DataDef k _ _ _) -> kindUniVars (substKind solved k)) ddefs2
      kinduvars = solved <> Map.fromSet (const (KType ())) leftover
      ddefs3 = map (substDdef kinduvars mempty) ddefs2
  modifyTEnv (Map.map (substKind kinduvars))

  forM_ ddefs3 $ \ddef ->
    modifyICEnv (Map.fromList (generateInvCons ddef) <>)

  -- debug output of the kind-checked data definitions
  traceM (unlines (map pretty ddefs3))

  fdefs2 <- mapM tcFunDef fdefs1

  return (Program ddefs3 fdefs2)

-- | Bring a data type's name into scope with a kind of the right arity: one
-- fresh kind unification variable per parameter, ending in Type. Returns
-- that kind together with the parameter kinds.
prepareDataDef :: PDataDef -> TCM (CKind, [CKind])
prepareDataDef (DataDef _ name params _) = do
  paramKinds <- traverse (const genKUniVar) params
  let kind = foldr (KFun ()) (KType ()) paramKinds
  modifyTEnv (Map.insert name kind)
  return (kind, paramKinds)

-- | Kind-check a data definition.
--
-- Assumes the kind of the type name itself has already been registered with
-- the correct arity; this is done by 'prepareDataDef', whose result is the
-- first argument (the type's kind and its parameter kinds).
--
-- A parameter name that repeats an earlier one is reported as an error and
-- replaced by a fresh name @t1@, @t2@, ..., chosen to not clash with any
-- parameter name. Constructor fields are checked with the parameters in
-- scope and must have kind Type.
kcDataDef :: (CKind, [CKind]) -> PDataDef -> TCM CDataDef
kcDataDef (kd, parkinds) (DataDef _ name params cons) = do
  params' <- zipWithM dedup (zip params (inits paramNames)) freshNames

  cons' <- scopeTEnv $ do
    modifyTEnv (Map.fromList (zip params' parkinds) <>)
    forM cons $ \(cname, fieldtys) -> do
      fieldtys' <- mapM (kcType (Just (KType ()))) fieldtys
      return (cname, fieldtys')

  return (DataDef kd name (zip parkinds params') cons')
  where
    paramNames = map snd params
    taken = Set.fromList paramNames
    freshNames = [nm | i <- [1 :: Int ..], let nm = Name ('t' : show i), nm `Set.notMember` taken]

    -- keep a name unless it already occurred among the earlier parameters
    dedup ((rng, pname), earlier) replacement
      | pname `elem` earlier = do
          raise SError rng $ "Duplicate type parameter: " ++ pretty pname
          return replacement
      | otherwise = return pname

-- | Build the inverse constructors of a kind-checked data definition, one per
-- constructor, keyed by constructor name.
--
-- The data type's parameters become the universally quantified variables.
-- The matched type is the data type applied to those parameters. For a type
-- without parameters it is the bare 'TCon', the same form 'kcType' produces
-- for a bare type name, instead of a 'TApp' with an empty argument list.
generateInvCons :: CDataDef -> [(Name, InvCon)]
generateInvCons (DataDef k tname params cons) =
  let tyvars = Map.fromList (map swap params)
      tycon = TCon k tname
      resty | null params = tycon
            | otherwise = TApp (KType ()) tycon (map (uncurry TVar) params)
  in [(cname, InvCon tyvars resty fields) | (cname, fields) <- cons]

-- | Use the expected kind if there is one, otherwise a fresh kind
-- unification variable.
promoteDownK :: Maybe CKind -> TCM CKind
promoteDownK = maybe genKUniVar return

-- | If an expected kind is given, emit a constraint that the observed kind
-- equals it (expected kind on the left); otherwise do nothing.
downEqK :: Range -> Maybe CKind -> CKind -> TCM ()
downEqK rng mexpected observed =
  maybe (return ()) (\expected -> emit (CEqK expected observed rng)) mexpected

-- | Given (maybe) the expected kind of this type, and a type, check it for
-- kind-correctness.
--
-- Kind equalities are recorded as 'CEqK' constraints, with the expected kind
-- on the left and the observed kind on the right, as documented on
-- 'Constr'. Whenever an expected kind is passed down, the result is
-- constrained to it: either through 'downEqK', or (for type applications)
-- because the expected kind is used directly as the result kind. Sub-terms
-- checked with @Just (KType ())@ therefore need no further constraint from
-- the parent. Emitting a second copy would report a single kind mismatch
-- twice.
kcType :: Maybe CKind -> PType -> TCM CType
kcType mdown = \case
  TApp rng t ts -> do
    t' <- kcType Nothing t
    ts' <- mapM (kcType Nothing) ts
    retk <- promoteDownK mdown
    -- the head must be a type function from the argument kinds to the result
    let expected = foldr (KFun ()) retk (map extOf ts')
    emit $ CEqK expected (extOf t') rng
    return (TApp retk t' ts')

  TTup rng ts -> do
    -- every component is checked against (and thus constrained to) Type
    ts' <- mapM (kcType (Just (KType ()))) ts
    downEqK rng mdown (KType ())
    return (TTup (KType ()) ts')

  TList rng t -> do
    t' <- kcType (Just (KType ())) t
    downEqK rng mdown (KType ())
    return (TList (KType ()) t')

  TFun rng t1 t2 -> do
    t1' <- kcType (Just (KType ())) t1
    t2' <- kcType (Just (KType ())) t2
    downEqK rng mdown (KType ())
    return (TFun (KType ()) t1' t2')

  TCon rng n -> do
    k <- getKind' rng n
    downEqK rng mdown k
    return (TCon k n)

  TVar rng n -> do
    k <- getKind' rng n
    downEqK rng mdown k
    return (TVar k n)

  TForall rng n t -> do  -- implicit forall
    -- k1: kind of the bound variable; k2: kind of the body (and the whole)
    k1 <- genKUniVar
    k2 <- genKUniVar
    downEqK rng mdown k2
    t' <- scopeTEnv $ do
      modifyTEnv (Map.insert n k1)
      kcType (Just k2) t
    return (TForall k2 n t')  -- not 'k1 -> k2' because the forall is implicit

-- | Type-check a function definition.
--
-- All equations must have the same number of argument patterns; otherwise an
-- error is reported, but checking continues. The function's type is its
-- kind-checked signature if present, else a fresh unification variable of
-- kind Type. The result records that type as an explicit signature.
tcFunDef :: PFunDef -> TCM CFunDef
tcFunDef (FunDef rng name msig eqs) = do
  unless (allEq (length . funeqPats <$> eqs)) $
    raise SError rng "Function equations have differing numbers of arguments"

  sigty <- case msig of
    TypeSig sig -> kcType (Just (KType ())) sig
    TypeSigExt NoTypeSig -> genUniVar (KType ())

  FunDef sigty name (TypeSig sigty) <$> mapM (tcFunEq sigty) eqs

-- | Type-check one equation of a function against the function's type.
--
-- The type is split into one argument type per pattern plus a result type
-- ('unfoldFunTy'). The patterns' variables are in scope only while checking
-- this equation's right-hand side.
tcFunEq :: CType -> PFunEq -> TCM CFunEq
tcFunEq expected (FunEq rng name pats rhs) = do
  (argtys, resty) <- unfoldFunTy rng (length pats) expected
  scopeVEnv $ do
    pats' <- zipWithM tcPattern argtys pats
    FunEq () name pats' <$> tcRHS resty rhs

-- | Type-check a pattern against the type it must match ('down').
--
-- Variables bound by the pattern, including as-pattern names, are added to
-- the value environment; the caller is responsible for scoping them (see
-- 'tcFunEq'). A constructor pattern instantiates the constructor's
-- universally quantified type variables with fresh unification variables.
-- Unknown constructors and non-constructor operators are reported and
-- replaced by a wildcard.
tcPattern :: CType -> PPattern -> TCM CPattern
tcPattern down = \case
  PWildcard _ -> return $ PWildcard down
  PVar _ n -> modifyVEnv (Map.insert n down) >> return (PVar down n)
  PAs _ n p -> do
    modifyVEnv (Map.insert n down)
    -- keep the as-binding in the checked tree
    PAs down n <$> tcPattern down p
  PCon rng n ps ->
    getInvCon n >>= \case
      Just (InvCon tyvars match fields) -> do
        let nfields = length fields
            npats = length ps
        when (npats /= nfields) $
          raise SError rng $
            "Constructor " ++ pretty n ++ " expects " ++ show nfields
            ++ " argument(s), but is given " ++ show npats
        unisub <- mapM genUniVar tyvars  -- substitution for the universally quantified variables
        let match' = substType mempty mempty unisub match
            fields' = map (substType mempty mempty unisub) fields
        emit $ CEq down match' rng
        -- Surplus argument patterns are still checked, against fresh types,
        -- so the variables they bind are in scope (avoids follow-up errors).
        extra <- replicateM (npats - nfields) (genUniVar (KType ()))
        PCon match' n <$> zipWithM tcPattern (fields' ++ extra) ps
      Nothing -> do
        raise SError rng $ "Constructor not in scope: " ++ pretty n
        return (PWildcard down)
  POp rng p1 op p2 ->
    case op of
      OCons -> do
        eltty <- genUniVar (KType ())
        let listty = TList (KType ()) eltty
        emit $ CEq down listty rng
        p1' <- tcPattern eltty p1
        p2' <- tcPattern listty p2
        return (POp listty p1' OCons p2')
      _ -> do
        raise SError rng $ "Operator is not a constructor: " ++ pretty op
        return (PWildcard down)
  PList rng ps -> do
    eltty <- genUniVar (KType ())
    let listty = TList (KType ()) eltty
    emit $ CEq down listty rng
    PList listty <$> mapM (tcPattern eltty) ps
  PTup rng ps -> do
    ts <- mapM (\_ -> genUniVar (KType ())) ps
    emit $ CEq down (TTup (KType ()) ts) rng
    PTup (TTup (KType ()) ts) <$> zipWithM tcPattern ts ps

-- | Type-check a right-hand side against its expected type ('tcFunEq' passes
-- the function's result type here).
--
-- NOTE(review): not implemented yet; evaluating this throws.
tcRHS :: CType -> PRHS -> TCM CRHS
tcRHS = error "tcRHS"

-- | Split a type into @n@ argument types and a result type.
--
-- Leading function arrows are peeled off directly. When the type runs out of
-- visible arrows, the remaining arguments and the result become fresh
-- unification variables of kind Type. A constraint then equates the
-- remaining type with the function type built from those variables.
unfoldFunTy :: Range -> Int -> CType -> TCM ([CType], CType)
unfoldFunTy rng n ty
  | n <= 0 = return ([], ty)
  | otherwise = case ty of
      TFun _ argty resty -> first (argty :) <$> unfoldFunTy rng (n - 1) resty
      _ -> do
        argvars <- replicateM n (genUniVar (KType ()))
        resvar <- genUniVar (KType ())
        emit $ CEq (foldr (TFun (KType ())) resvar argvars) ty rng
        return (argvars, resvar)

-- | Solve kind equality constraints (as extracted with 'isCEqK'), report the
-- unsolvable ones as errors, and return the resulting assignment of kind
-- unification variables.
--
-- The generic 'solveConstraints' solver is instantiated with kind-specific
-- operations: 'reduce' to decompose an equation, 'kindUniVars' to find the
-- variables in a kind, 'substKind' to substitute, a recogniser for a bare
-- unification variable, and 'kindSize' as a size measure.
-- NOTE(review): the precise use of the size measure is up to the solver in
-- "HSVIS.Typecheck.Solve"; confirm there.
--
-- Unification variables that occur in the assigned kinds but were not
-- assigned themselves are defaulted to Type. Only variables mentioned in the
-- solver's assignment are considered for this defaulting.
solveKindVars :: Bag (CKind, CKind, Range) -> TCM (Map Int CKind)
solveKindVars cs = do
  let (asg, errs) =
        solveConstraints
          reduce
          (foldMap pure . kindUniVars)
          substKind
          (\case KExt () (KUniVar v) -> Just v
                 _ -> Nothing)
          kindSize
          (toList cs)

  -- unequal kinds, and variables that would have to contain themselves
  forM_ errs $ \case
    UEUnequal k1 k2 rng ->
      raise SError rng $
        "Kind mismatch:\n\
        \- " ++ pretty k1 ++ "\n\
        \- " ++ pretty k2
    UERecursive uvar k rng ->
      raise SError rng $
        "Kind cannot be recursive: " ++ pretty (KExt () (KUniVar uvar)) ++ " = " ++ pretty k

  -- default unconstrained kind variables to Type
  let unconstrKUVars = foldMap kindUniVars (Map.elems asg) Set.\\ Map.keysSet asg
      defaults = Map.fromList (map (,KType ()) (toList unconstrKUVars))
      asg' = Map.map (substKind defaults) asg <> defaults

  return asg'
  where
    -- Decompose an equation into assignments to unification variables (first
    -- component) and irreducible mismatches (second component).
    reduce :: CKind -> CKind -> Range -> (Bag (Int, CKind, Range), Bag (CKind, CKind, Range))
    reduce lhs rhs rng = case (lhs, rhs) of
      -- unification variables produce constraints on a unification variable
      (KExt () (KUniVar i), KExt () (KUniVar j)) | i == j -> mempty
      (KExt () (KUniVar i), k                  ) -> (pure (i, k, rng), mempty)
      (k                  , KExt () (KUniVar i)) -> (pure (i, k, rng), mempty)

      -- if lhs and rhs have equal prefixes, recurse
      (KType ()   , KType ()   ) -> mempty
      (KFun () a b, KFun () c d) -> reduce a c rng <> reduce b d rng

      -- otherwise, this is a kind mismatch
      (k1, k2) -> (mempty, pure (k1, k2, rng))

    -- Size of a kind: Type counts 1, an arrow 1 plus its parts, and a
    -- unification variable 2.
    kindSize :: CKind -> Int
    kindSize KType{} = 1
    kindSize (KFun () a b) = 1 + kindSize a + kindSize b
    kindSize (KExt () KUniVar{}) = 2

-- | Solve the constraints left in the writer output after 'tcProgram' (all
-- except the kind constraints of data definitions, which are solved there).
-- Produces diagnostics plus the final kind and type unification variable
-- instantiations, which 'typecheck' passes to 'doneProg'.
--
-- NOTE(review): not implemented yet; evaluating this throws.
solveConstrs :: Bag Constr -> (Bag Diagnostic, Map Int TKind, Map Int TType)
solveConstrs = error "solveConstrs"

-- | Substitute unification variables throughout a whole checked program.
--
-- NOTE(review): not implemented yet (evaluating it throws), and not called
-- anywhere in this module.
substProg :: Map Int CKind  -- ^ Kind unification variable instantiations
          -> Map Int CType  -- ^ Type unification variable instantiations
          -> CProgram
          -> CProgram
substProg = error "substProg"

-- | Substitute kind and type unification variables in a data definition: in
-- its own kind, in its parameters' kinds, and in all constructor field types.
substDdef :: Map Int CKind -> Map Int CType -> CDataDef -> CDataDef
substDdef mk mt (DataDef k name pars cons) = DataDef k' name pars' cons'
  where
    k' = substKind mk k
    pars' = map (first (substKind mk)) pars
    cons' = map (second (map (substType mk mt mempty))) cons

-- | Apply three substitutions to a type at once:
--
-- * kind unification variables, in every kind annotation of the type,
--   including those on variables and unification variables left unreplaced;
-- * type unification variables ('TUniVar'); and
-- * named type variables ('TVar'), e.g. to instantiate a constructor's
--   universally quantified parameters. A 'TForall' shadows its bound name,
--   so that name is not replaced inside its body.
--
-- Replacement types taken from the maps are inserted as-is; they are not
-- substituted into again, and no capture avoidance is performed on them.
substType :: Map Int CKind  -- ^ kind uvars
          -> Map Int CType  -- ^ type uvars
          -> Map Name CType  -- ^ type variables
          -> CType -> CType
substType mk mt = go
  where
    sk = substKind mk

    go :: Map Name CType -> CType -> CType
    go mtv = \case
      TApp k t ts -> TApp (sk k) (go mtv t) (map (go mtv) ts)
      TTup k ts -> TTup (sk k) (map (go mtv) ts)
      TList k t -> TList (sk k) (go mtv t)
      TFun k t1 t2 -> TFun (sk k) (go mtv t1) (go mtv t2)
      TCon k n -> TCon (sk k) n
      TVar k n -> fromMaybe (TVar (sk k) n) (Map.lookup n mtv)
      TForall k n t -> TForall (sk k) n (go (Map.delete n mtv) t)
      TExt k (TUniVar v) -> fromMaybe (TExt (sk k) (TUniVar v)) (Map.lookup v mt)

-- | Substitute kind unification variables by their instantiations; variables
-- without an entry in the map are left in place.
substKind :: Map Int CKind -> CKind -> CKind
substKind m = go
  where
    go (KType ()) = KType ()
    go (KFun () a b) = KFun () (go a) (go b)
    go k@(KExt () (KUniVar v)) = Map.findWithDefault k v m

-- | Convert a checked program to the final typed stage, given the solved
-- instantiations of kind and type unification variables.
--
-- NOTE(review): not implemented yet; evaluating this throws, so 'typecheck'
-- cannot produce a program yet.
doneProg :: Map Int TKind  -- ^ Kind unification variable instantiations
         -> Map Int TType  -- ^ Type unification variable instantiations
         -> CProgram
         -> TProgram
doneProg = error "doneProg"

-- | The ids of all kind unification variables occurring in a kind.
kindUniVars :: CKind -> Set Int
kindUniVars kind = case kind of
  KType{} -> Set.empty
  KFun () a b -> Set.union (kindUniVars a) (kindUniVars b)
  KExt () (KUniVar v) -> Set.singleton v

-- | Whether all elements of a container are equal to its first element. True
-- for an empty container.
allEq :: (Eq a, Foldable t) => t a -> Bool
allEq = sameAsHead . toList
  where
    sameAsHead [] = True
    sameAsHead (h : rest) = all (\y -> y == h) rest

-- | The argument patterns of a function equation.
funeqPats :: FunEq t -> [Pattern t]
funeqPats = \case FunEq _ _ argPats _ -> argPats

-- | Select kind equality constraints, e.g. for use with 'collectConstraints'.
isCEqK :: Constr -> Maybe (CKind, CKind, Range)
isCEqK = \case
  CEqK expected observed rng -> Just (expected, observed, rng)
  CEq{} -> Nothing

-- | Effectful 'foldMap': run the action for each element, in the order
-- 'foldMap' visits them, and combine the results with '<>' (via the 'Ap'
-- monoid).
foldMapM :: (Applicative f, Monoid m, Foldable t) => (a -> f m) -> t a -> f m
foldMapM f xs = getAp (foldMap wrap xs)
  where
    wrap x = Ap (f x)