1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
|
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE EmptyDataDeriving #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE GADTs #-}
{-# OPTIONS -Wno-unused-top-binds #-}
{-# OPTIONS -Wno-unused-imports #-}
{-# LANGUAGE DataKinds #-}
module HSVIS.Typecheck (
StageTyped,
typecheck,
-- * Typed AST synonyms
-- TProgram, TDataDef, TFunDef, TFunEq, TKind, TType, TPattern, TRHS, TExpr,
) where
import Control.Monad
import Data.Bifunctor (first, second, bimap)
import Data.Foldable (toList)
import Data.List (find, inits)
import Data.Map.Strict (Map)
import Data.Maybe (fromMaybe)
import Data.Monoid (Ap(..))
import qualified Data.Map.Strict as Map
import Data.Tuple (swap)
import Data.Semigroup (First(..))
import Data.Set (Set)
import qualified Data.Set as Set
import GHC.Stack
import Debug.Trace
import Data.Bag
import Data.List.NonEmpty.Util
import Data.Map.Monoidal (MMap(..))
import qualified Data.Map.Monoidal as MMap
import HSVIS.AST
import HSVIS.Parser
import HSVIS.Diagnostic
import HSVIS.Pretty
import HSVIS.Typecheck.Solve
-- | Intermediate stage used while type checking: every node is annotated
-- with a kind or type that may still contain unification variables.
data StageTC
type instance X DataDef StageTC = CKind
type instance X FunDef StageTC = CType
type instance X FunEq StageTC = ()
type instance X Kind StageTC = ()
type instance X Type StageTC = CKind
type instance X Pattern StageTC = CType
type instance X RHS StageTC = CType
type instance X Expr StageTC = CType
-- | Type unification variable; the integer comes from 'genId'.
data instance E Type StageTC = TUniVar Int deriving (Show, Eq, Ord)
-- | Kind unification variable; the integer comes from 'genId'.
data instance E Kind StageTC = KUniVar Int deriving (Show, Eq, Ord)
-- | No extension constructors for type signatures in this stage.
data instance E TypeSig StageTC deriving (Show)
-- Synonyms for the AST nodes of the type-checking stage.
type CProgram = Program StageTC
type CDataDef = DataDef StageTC
type CFunDef = FunDef StageTC
type CFunEq = FunEq StageTC
type CKind = Kind StageTC
type CType = Type StageTC
type CPattern = Pattern StageTC
type CRHS = RHS StageTC
type CExpr = Expr StageTC
-- | Final stage produced by 'typecheck': nodes carry kinds/types, and the
-- extension families below are empty, so no unification variables can occur.
data StageTyped
type instance X DataDef StageTyped = TKind
type instance X FunDef StageTyped = TType
type instance X FunEq StageTyped = ()
type instance X Kind StageTyped = ()
type instance X Type StageTyped = TKind
type instance X Pattern StageTyped = TType
type instance X RHS StageTyped = TType
type instance X Expr StageTyped = TType
-- Empty extension types: a typed AST has no unification variables.
data instance E Type StageTyped deriving (Show)
data instance E Kind StageTyped deriving (Show)
data instance E TypeSig StageTyped deriving (Show)
-- Synonyms for the AST nodes of the typed stage.
type TProgram = Program StageTyped
type TDataDef = DataDef StageTyped
type TFunDef = FunDef StageTyped
type TFunEq = FunEq StageTyped
type TKind = Kind StageTyped
type TType = Type StageTyped
type TPattern = Pattern StageTyped
type TRHS = RHS StageTyped
type TExpr = Expr StageTyped
-- | Type unification variables are printed as @?t\<n\>@.
instance Pretty (E Type StageTC) where
  prettysPrec _ = \case
    TUniVar n -> showString "?t" . shows n

-- | Kind unification variables are printed as @?k\<n\>@.
instance Pretty (E Kind StageTC) where
  prettysPrec _ = \case
    KUniVar n -> showString "?k" . shows n
-- | Type check a parsed program. The file path and source text are used to
-- build diagnostics. Returns all diagnostics together with the typed program.
-- It is an internal error (crash) if any constraint is left unsolved.
typecheck :: FilePath -> String -> PProgram -> ([Diagnostic], TProgram)
typecheck fp source prog =
  trace ("[tc] resprog = " ++ show resprog) outcome
  where
    initialEnv = Env mempty mempty mempty
    -- ids for unification variables start at 1
    (diags, leftover, _, _, resprog) = runTCM (tcTop prog) (fp, source) 1 initialEnv
    outcome
      | null leftover = (toList diags, resprog)
      | otherwise = error $ "Constraints left after typechecker completion: " ++ show leftover
-- | Constraints emitted during checking and solved afterwards. In both
-- constructors the first component is the expected thing and the second the
-- observed thing; the range points at the thing that caused the constraint.
data Constr
  = CEq CType CType Range
    -- ^ the two types must be equal
  | CEqK CKind CKind Range
    -- ^ the two kinds must be equal
  deriving (Show)
-- | The environment threaded as state through 'TCM'.
data Env = Env
  (Map Name CKind)   -- ^ types in scope (including type variables)
  (Map Name CType)   -- ^ values in scope (variables; NOTE(review): described as
                     --   also holding constructors, but nothing here inserts them)
  (Map Name InvCon)  -- ^ patterns in scope (inverse constructors)
  deriving (Show)

-- | An inverse constructor: the information needed to check a constructor
-- pattern.
data InvCon = InvCon
  (Map Name CKind)   -- ^ universally quantified type variables, with their kinds
  CType              -- ^ input type of the inverse constructor (result of the constructor)
  [CType]            -- ^ field types
  deriving (Show)
-- | The type-checking monad: a hand-rolled combination of a reader (file
-- name and contents), state (id counter and environment) and writer
-- (diagnostics and constraints).
newtype TCM a = TCM
  { runTCM :: (FilePath, String)  -- ^ reader context: file and file contents
           -> Int                 -- ^ state: next id to generate
           -> Env                 -- ^ state: type and value environment
           -> ( Bag Diagnostic    -- writer: diagnostics
              , Bag Constr        -- writer: constraints
              , Int
              , Env
              , a )
  }
-- | Map over the result; writer output and state pass through untouched.
instance Functor TCM where
  fmap f action = TCM step
    where
      step ctx counter env = (diags, constrs, counter', env', f result)
        where (diags, constrs, counter', env', result) = runTCM action ctx counter env

instance Applicative TCM where
  pure value = TCM (\_ counter env -> (mempty, mempty, counter, env, value))
  (<*>) = ap

-- | Sequencing threads the state and concatenates both writer outputs,
-- earlier output first.
instance Monad TCM where
  action >>= continue = TCM step
    where
      step ctx counter0 env0 =
        (diags1 <> diags2, constrs1 <> constrs2, counter2, env2, result)
        where
          (diags1, constrs1, counter1, env1, value) = runTCM action ctx counter0 env0
          (diags2, constrs2, counter2, env2, result) = runTCM (continue value) ctx counter1 env1
-- | Monads able to report a diagnostic with a severity, a source range and
-- a message.
class Monad m => MonadRaise m where
  raise
    :: Severity  -- ^ severity of the diagnostic
    -> Range     -- ^ where in the source it applies
    -> String    -- ^ message
    -> m ()
-- | Diagnostics are appended to the writer output, together with the text of
-- the source line (indexed by the range's start line) they refer to.
instance MonadRaise TCM where
  raise sev rng@(Range (Pos y _) _) msg = TCM $ \(fp, source) i env ->
    (pure (Diagnostic sev fp rng [] (sourceLine y source) msg), mempty, i, env, ())
    where
      -- Total replacement for @lines source !! y@: a line index outside the
      -- file (e.g. a range at end-of-file, or a negative index) yields an
      -- empty line instead of crashing the type checker.
      sourceLine lineNo src
        | lineNo < 0 = ""
        | otherwise = case drop lineNo (lines src) of
            l : _ -> l
            [] -> ""
-- | Record a constraint in the writer output; it is solved later.
emit :: Constr -> TCM ()
emit constr = TCM (\_ counter env -> (mempty, pure constr, counter, env, ()))
-- | Run an action and intercept the constraints it emits that the selector
-- maps to 'Just'. Those are returned, while all other constraints are passed
-- through to the writer output as usual.
collectConstraints :: (Constr -> Maybe b) -> TCM a -> TCM (Bag b, a)
collectConstraints select action = TCM step
  where
    step ctx counter env = (diags, passedOn, counter', env', (selected, result))
      where
        (diags, constrs, counter', env', result) = runTCM action ctx counter env
        (selected, passedOn) = bagPartition select constrs
-- | Read the complete environment.
getFullEnv :: TCM Env
getFullEnv = TCM (\_ counter env -> (mempty, mempty, counter, env, env))

-- | Replace the complete environment.
putFullEnv :: Env -> TCM ()
putFullEnv newEnv = TCM (\_ counter _ -> (mempty, mempty, counter, newEnv, ()))

-- | Produce a fresh id; ids are handed out in increasing order.
genId :: TCM Int
genId = TCM (\_ counter env -> (mempty, mempty, counter + 1, env, counter))
-- | Look up the kind of a type name in scope.
getKind :: Name -> TCM (Maybe CKind)
getKind name = (\(Env tenv _ _) -> Map.lookup name tenv) <$> getFullEnv

-- | Look up the type of a value name in scope.
getType :: Name -> TCM (Maybe CType)
getType name = (\(Env _ venv _) -> Map.lookup name venv) <$> getFullEnv

-- | Look up the inverse constructor for a constructor name.
getInvCon :: Name -> TCM (Maybe InvCon)
getInvCon name = (\(Env _ _ icenv) -> Map.lookup name icenv) <$> getFullEnv
-- | Apply a function to the type environment.
modifyTEnv :: (Map Name CKind -> Map Name CKind) -> TCM ()
modifyTEnv f =
  getFullEnv >>= \(Env tenv venv icenv) -> putFullEnv (Env (f tenv) venv icenv)

-- | Apply a function to the value environment.
modifyVEnv :: (Map Name CType -> Map Name CType) -> TCM ()
modifyVEnv f =
  getFullEnv >>= \(Env tenv venv icenv) -> putFullEnv (Env tenv (f venv) icenv)

-- | Apply a function to the inverse-constructor environment.
modifyICEnv :: (Map Name InvCon -> Map Name InvCon) -> TCM ()
modifyICEnv f =
  getFullEnv >>= \(Env tenv venv icenv) -> putFullEnv (Env tenv venv (f icenv))
-- | Run an action, then restore the type environment to what it was before.
-- The other environments keep whatever the action did to them.
scopeTEnv :: TCM a -> TCM a
scopeTEnv action = do
  Env savedTEnv _ _ <- getFullEnv
  result <- action
  modifyTEnv (const savedTEnv)
  pure result

-- | Run an action, then restore the value environment to what it was before.
-- The other environments keep whatever the action did to them.
scopeVEnv :: TCM a -> TCM a
scopeVEnv action = do
  Env _ savedVEnv _ <- getFullEnv
  result <- action
  modifyVEnv (const savedVEnv)
  pure result
-- | A fresh kind unification variable.
genKUniVar :: TCM CKind
genKUniVar = fmap (\n -> KExt () (KUniVar n)) genId

-- | A fresh type unification variable of the given kind.
genUniVar :: CKind -> TCM CType
genUniVar k = fmap (\n -> TExt k (TUniVar n)) genId
-- | Like 'getKind', but reports an out-of-scope name as an error and
-- continues with a fresh kind unification variable.
getKind' :: Range -> Name -> TCM CKind
getKind' rng name = getKind name >>= maybe notFound return
  where
    notFound = do
      raise SError rng $ "Type not in scope: " ++ pretty name
      genKUniVar

-- | Like 'getType', but reports an out-of-scope name as an error and
-- continues with a fresh type unification variable of kind Type.
getType' :: Range -> Name -> TCM CType
getType' rng name = getType name >>= maybe notFound return
  where
    notFound = do
      raise SError rng $ "Variable not in scope: " ++ pretty name
      genUniVar (KType ())
-- | Check the whole program, solve every constraint it produced in one go,
-- apply the solution and finalise the result.
tcTop :: PProgram -> TCM TProgram
tcTop prog = do
  (constrs, checked) <- collectConstraints Just (tcProgram prog)
  (kindSub, typeSub) <- solveConstrs constrs
  pure (finaliseProg (substProg kindSub typeSub checked))
-- | Check a program: kind-check the data definitions and solve their kinds,
-- register their inverse constructors, then check all function definitions
-- as one block.
tcProgram :: PProgram -> TCM CProgram
tcProgram (Program ddefs1 fdefs1) = do
  -- Kind-check the data definitions. Only the kind constraints are captured
  -- here; any other constraint is passed through.
  (kconstrs, ddefs2) <- collectConstraints isCEqK $
    mapM prepareDataDef ddefs1 >>= \prepared -> zipWithM kcDataDef prepared ddefs1
  -- Solve them right away so the data types get their final kinds.
  kinduvars <- solveKindVars kconstrs
  let ddefs3 = [substDdef kinduvars mempty ddef | ddef <- ddefs2]
  modifyTEnv (Map.map (substKind kinduvars))
  -- Make the inverse constructors available to pattern matching.
  mapM_ (\ddef -> modifyICEnv (Map.fromList (generateInvCons ddef) <>)) ddefs3
  traceM (unlines (map pretty ddefs3))
  fdefs2 <- tcFunDefBlock fdefs1
  return (Program ddefs3 fdefs2)
-- | Bring a data type name into scope with a kind of the right arity: one
-- fresh kind unification variable per parameter, ending in Type. Returns
-- that kind together with the parameter kinds.
prepareDataDef :: PDataDef -> TCM (CKind, [CKind])
prepareDataDef (DataDef _ name params _) = do
  parkinds <- forM params (const genKUniVar)
  let fullKind = foldr (KFun ()) (KType ()) parkinds
  modifyTEnv (Map.insert name fullKind)
  return (fullKind, parkinds)
-- | Kind-check one data definition, given its kind and parameter kinds from
-- 'prepareDataDef'. That function must already have registered the data
-- type's own kind with the correct arity.
--
-- A duplicate type parameter is reported and replaced by a fresh name
-- (t1, t2, ...) that does not clash with any of the given parameter names.
kcDataDef :: (CKind, [CKind]) -> PDataDef -> TCM CDataDef
kcDataDef (kd, parkinds) (DataDef _ name params cons) = do
  params' <- mapM uniqueName (zip3 params (inits paramNames) freshNames)
  -- kind-check the constructor fields with the parameters in scope
  cons' <- scopeTEnv $ do
    modifyTEnv (Map.fromList (zip params' parkinds) <>)
    forM cons $ \(cname, fieldtys) -> do
      fieldtys' <- mapM (kcType KCTMNormal (Just (KType ()))) fieldtys
      return (cname, fieldtys')
  return (DataDef kd name (zip parkinds params') cons')
  where
    paramNames = map snd params
    taken = Set.fromList paramNames
    freshNames =
      [nm | i <- [1 :: Int ..], let nm = Name ('t' : show i), nm `Set.notMember` taken]
    uniqueName ((rng, pname), earlier, replacement)
      | pname `elem` earlier = do
          raise SError rng $ "Duplicate type parameter: " ++ pretty pname
          return replacement
      | otherwise = return pname
-- | One inverse constructor per constructor of a (kind-checked) data type.
-- The data type's parameters become the universally quantified type
-- variables, the fully applied data type is the matched type, and the
-- constructor's fields are the field types.
generateInvCons :: CDataDef -> [(Name, InvCon)]
generateInvCons (DataDef k tname params cons) =
  let tyvars = Map.fromList (map swap params)
      -- A nullary type is represented as a bare 'TCon', the same form the
      -- kind checker produces for a type like @Bool@, and not as a 'TApp'
      -- with no arguments. The type unifier does not equate those two forms,
      -- so the 'TApp' form made every nullary constructor pattern a mismatch.
      resty
        | null params = TCon k tname
        | otherwise = TApp (KType ()) (TCon k tname) (map (uncurry TVar) params)
  in [(cname, InvCon tyvars resty fields) | (cname, fields) <- cons]
-- | The expected kind if one was given, otherwise a fresh kind unification
-- variable.
promoteDownK :: Maybe CKind -> TCM CKind
promoteDownK = maybe genKUniVar return

-- | If an expected kind was given, require the observed kind to equal it.
downEqK :: Range -> Maybe CKind -> CKind -> TCM ()
downEqK rng mdown k2 = maybe (return ()) (\k1 -> emit (CEqK k1 k2 rng)) mdown
-- | How 'kcType' treats type variables that are not in scope. The indices
-- select the extra information collected ('ext') and what 'kcType' returns
-- ('ret').
data KCTypeMode ext ret where
  -- | Kind-check a normal type: out-of-scope type variables are reported as errors.
  KCTMNormal :: KCTypeMode () CType
  -- | Kind-check an open type: out-of-scope type variables are returned. This
  -- is used to check function type signatures, which may have an implicit
  -- forall telescope at the head.
  KCTMOpen :: KCTypeMode (MMap Name (First CKind)) (CType, Map Name CKind)

-- | Given (maybe) the expected kind of this type, and a type, check it for
-- kind-correctness.
--
-- In 'KCTMOpen' mode the result also maps every type variable that was not in
-- scope to its (shared) kind. Such variables are in scope only while this one
-- type is checked, so they do not leak into later signatures.
kcType :: forall ext ret. KCTypeMode ext ret -> Maybe CKind -> PType -> TCM ret
kcType KCTMNormal mdown t = snd <$> kcType' KCTMNormal mdown t
kcType KCTMOpen mdown t =
  scopeTEnv $
    second (\(MMap m) -> Map.map getFirst m) . swap <$> kcType' KCTMOpen mdown t

-- | Worker for 'kcType': returns the mode-specific extra information next to
-- the kind-annotated type.
kcType' :: forall ext ret. Monoid ext => KCTypeMode ext ret -> Maybe CKind -> PType -> TCM (ext, CType)
kcType' mode mdown = \case
  TApp rng t ts -> do
    (ext1, t') <- kcType' mode Nothing t
    (ext2, ts') <- sequence <$> mapM (kcType' mode Nothing) ts
    retk <- promoteDownK mdown
    let expected = foldr (KFun ()) retk (map extOf ts')
    emit $ CEqK (extOf t') expected rng
    return (ext1 <> ext2, TApp retk t' ts')
  TTup rng ts -> do
    (ext, ts') <- sequence <$> mapM (kcType' mode (Just (KType ()))) ts
    forM_ (zip (map extOf ts) ts') $ \(trng, ct) ->
      emit $ CEqK (extOf ct) (KType ()) trng
    downEqK rng mdown (KType ())
    return (ext, TTup (KType ()) ts')
  TList rng t -> do
    (ext, t') <- kcType' mode (Just (KType ())) t
    emit $ CEqK (extOf t') (KType ()) (extOf t)
    downEqK rng mdown (KType ())
    return (ext, TList (KType ()) t')
  TFun rng t1 t2 -> do
    (ext1, t1') <- kcType' mode (Just (KType ())) t1
    (ext2, t2') <- kcType' mode (Just (KType ())) t2
    emit $ CEqK (extOf t1') (KType ()) (extOf t1)
    emit $ CEqK (extOf t2') (KType ()) (extOf t2)
    downEqK rng mdown (KType ())
    return (ext1 <> ext2, TFun (KType ()) t1' t2')
  TCon rng n -> do
    k <- getKind' rng n
    downEqK rng mdown k
    return (mempty, TCon k n)
  TVar rng n -> do
    -- Signature fixed so the GADT match below is checked against a known type.
    let lookupVar :: TCM (ext, CKind)
        lookupVar = case mode of
          KCTMNormal -> ((),) <$> getKind' rng n
          KCTMOpen -> getKind n >>= \case
            Just known -> return (mempty, known)
            Nothing -> do
              -- Implicitly quantified variable: give it a fresh kind and put
              -- it in scope so later occurrences in this type share that kind.
              fresh <- genKUniVar
              modifyTEnv (Map.insert n fresh)
              return (MMap.singleton n (pure fresh), fresh)
    (ext, k) <- lookupVar
    downEqK rng mdown k
    return (ext, TVar k n)
  TForall rng n t -> do -- implicit forall
    k1 <- genKUniVar
    k2 <- genKUniVar
    downEqK rng mdown k2
    (ext, t') <- scopeTEnv $ do
      modifyTEnv (Map.insert n k1)
      kcType' mode (Just k2) t
    return (ext, TForall k2 n t') -- not 'k1 -> k2' because the forall is implicit
-- | Type check a group of function definitions that may refer to each other.
tcFunDefBlock :: [PFunDef] -> TCM [CFunDef]
tcFunDefBlock fdefs = do
  -- A preliminary unification variable stands in for each function's type
  -- while the other functions of the group are checked.
  placeholders <- forM fdefs $ \(FunDef _ n _ _) -> (,) n <$> genUniVar (KType ())
  checked <- forM fdefs $ \def@(FunDef _ self _ _) -> scopeVEnv $ do
    -- The function itself is brought into scope by tcFunDef.
    modifyVEnv (Map.fromList (filter ((/= self) . fst) placeholders) <>)
    tcFunDef def
  -- Link each placeholder to the type that was actually found.
  -- NOTE(review): the expected/observed order and the range used here were
  -- chosen arbitrarily in the original.
  forM_ (zip3 fdefs placeholders checked) $ \(fdef, (_, tvar), FunDef ty _ _ _) ->
    emit $ CEq ty tvar (extOf fdef)
  return checked
-- | Type check one function definition. With a type signature, the
-- signature's free type variables are quantified by foralls. Without one,
-- the type starts as a fresh unification variable. The function is in scope
-- in its own equations.
tcFunDef :: PFunDef -> TCM CFunDef
tcFunDef (FunDef rng name msig eqs) = do
  unless (allEq (fmap (length . funeqPats) eqs)) $
    raise SError rng "Function equations have differing numbers of arguments"
  typ <- case msig of
    TypeSig sig -> do
      (sigty, freetvars) <- kcType KCTMOpen (Just (KType ())) sig
      -- TODO: We need to check that these free type variables do not escape.
      -- Perhaps with levels on unification variables? Associate a level
      -- to a generated uvar, and increment the global level counter when
      -- passing below a forall.
      -- But how do we deal with functions without a type signature
      -- anyway? We should be able to infer a polymorphic type for them.
      return $ foldr (\(n, k) -> TForall k n) sigty (Map.assocs freetvars)
    TypeSigExt NoTypeSig -> genUniVar (KType ())
  eqs' <- scopeVEnv $ do
    modifyVEnv (Map.insert name typ) -- allow function to be recursive
    mapM (tcFunEq typ) eqs
  return (FunDef typ name (TypeSig typ) eqs')
-- | Type check one equation against the function's type. The type is split
-- into one argument type per pattern plus a result type for the right-hand
-- side. Pattern variables are in scope only for this equation.
tcFunEq :: CType -> PFunEq -> TCM CFunEq
tcFunEq down (FunEq rng name pats rhs) = do
  (argtys, rhsty) <- unfoldFunTy rng (length pats) down
  scopeVEnv $
    FunEq () name <$> zipWithM tcPattern argtys pats <*> tcRHS rhsty rhs
-- | Check a pattern against the expected type @down@, emitting equality
-- constraints. Brings the bound variables in scope (in the value environment
-- of the caller).
tcPattern :: CType -> PPattern -> TCM CPattern
tcPattern down = \case
  PWildcard _ -> return $ PWildcard down
  PVar _ n -> modifyVEnv (Map.insert n down) >> return (PVar down n)
  -- keep the as-binding in the typed tree (it used to be dropped)
  PAs _ n p -> modifyVEnv (Map.insert n down) >> (PAs down n <$> tcPattern down p)
  PCon rng n ps ->
    getInvCon n >>= \case
      Just (InvCon tyvars match fields) -> do
        unisub <- mapM genUniVar tyvars -- substitution for the universally quantified variables
        let match' = substType mempty mempty unisub match
            fields' = map (substType mempty mempty unisub) fields
        emit $ CEq down match' rng
        when (length fields' /= length ps) $
          raise SError rng $
            "Constructor " ++ pretty n ++ " expects " ++ show (length fields')
            ++ " arguments, but the pattern gives " ++ show (length ps)
        -- Surplus sub-patterns are still checked against fresh types so that
        -- their variables get bound; zipWithM would silently drop them.
        extras <- replicateM (length ps - length fields') (genUniVar (KType ()))
        PCon match' n <$> zipWithM tcPattern (fields' ++ extras) ps
      Nothing -> do
        raise SError rng $ "Constructor not in scope: " ++ pretty n
        return (PWildcard down)
  POp rng p1 op p2 ->
    case op of
      OCons -> do
        eltty <- genUniVar (KType ())
        let listty = TList (KType ()) eltty
        emit $ CEq down listty rng
        p1' <- tcPattern eltty p1
        p2' <- tcPattern listty p2
        return (POp listty p1' OCons p2')
      _ -> do
        raise SError rng $ "Operator is not a constructor: " ++ pretty op
        return (PWildcard down)
  PList rng ps -> do
    eltty <- genUniVar (KType ())
    let listty = TList (KType ()) eltty
    emit $ CEq down listty rng
    PList listty <$> mapM (tcPattern eltty) ps
  PTup rng ps -> do
    ts <- mapM (\_ -> genUniVar (KType ())) ps
    emit $ CEq down (TTup (KType ()) ts) rng
    PTup (TTup (KType ()) ts) <$> zipWithM tcPattern ts ps
-- | Check a right-hand side against the expected type. Guarded right-hand
-- sides are not supported yet.
tcRHS :: CType -> PRHS -> TCM CRHS
tcRHS down = \case
  Guarded _ _ -> error "typecheck: Guards not yet supported"
  Plain _ e -> (\e' -> Plain (extOf e') e') <$> tcExpr down e
-- | Check an expression against the expected type @down@. Equality
-- constraints are emitted and the type-annotated expression is returned.
tcExpr :: CType -> PExpr -> TCM CExpr
tcExpr down = \case
  ELit rng lit -> do
    let ty = case lit of
          LInt{} -> TCon (KType ()) (Name "Int")
          LFloat{} -> TCon (KType ()) (Name "Double")
          LChar{} -> TCon (KType ()) (Name "Char")
          LString{} -> TList (KType ()) (TCon (KType ()) (Name "Char"))
    emit $ CEq down ty rng
    return (ELit ty lit)
  EVar rng n -> do
    ty <- getType' rng n
    emit $ CEq down ty rng
    return $ EVar ty n
  ECon rng n -> do
    ty <- getType' rng n
    emit $ CEq down ty rng
    -- a constructor stays a constructor node in the typed tree
    return $ ECon ty n
  EList rng es -> do
    eltty <- genUniVar (KType ())
    let listty = TList (KType ()) eltty
    emit $ CEq down listty rng
    -- every element has the element type, not the list type
    EList listty <$> mapM (tcExpr eltty) es
  ETup rng es -> do
    ts <- mapM (\_ -> genUniVar (KType ())) es
    emit $ CEq down (TTup (KType ()) ts) rng
    ETup (TTup (KType ()) ts) <$> zipWithM tcExpr ts es
  EApp _ e1 es -> do
    argtys <- mapM (\_ -> genUniVar (KType ())) es
    let funty = foldr (TFun (KType ())) down argtys
    EApp funty <$> tcExpr funty e1
               <*> zipWithM tcExpr argtys es
  -- TODO: these types are way too monomorphic and in any case these
  -- ~operators~ functions should not be built-in
  EOp rng e1 op e2 -> do
    let int = TCon (KType ()) (Name "Int")
        bool = TCon (KType ()) (Name "Bool")
    (rty, aty1, aty2) <- case op of
      OAdd -> return (int, int, int)
      OSub -> return (int, int, int)
      OMul -> return (int, int, int)
      ODiv -> return (int, int, int)
      OMod -> return (int, int, int)
      OEqu -> return (bool, int, int)
      OPow -> return (int, int, int)
      OCons -> do
        eltty <- genUniVar (KType ())
        let listty = TList (KType ()) eltty
        return (listty, eltty, listty)
    emit $ CEq down rty rng
    e1' <- tcExpr aty1 e1
    e2' <- tcExpr aty2 e2
    return (EOp rty e1' op e2')
  EIf _ e1 e2 e3 -> do
    e1' <- tcExpr (TCon (KType ()) (Name "Bool")) e1
    e2' <- tcExpr down e2
    e3' <- tcExpr down e3
    return (EIf down e1' e2' e3')
  ECase _ e1 alts -> do
    ty <- genUniVar (KType ())
    e1' <- tcExpr ty e1
    alts' <- forM alts $ \(pat, rhs) ->
      scopeVEnv $
        (,) <$> tcPattern ty pat <*> tcRHS down rhs
    return $ ECase down e1' alts'
  ELet _ defs body -> do
    defs' <- tcFunDefBlock defs
    let bound2 = map (\(FunDef ty n _ _) -> (n, ty)) defs'
    scopeVEnv $ do
      modifyVEnv (Map.fromList bound2 <>)
      body' <- tcExpr down body
      return $ ELet down defs' body'
  EError _ -> return $ EError down
-- | Split a type into @n@ argument types and a result type. Visible function
-- arrows are peeled off directly. Once no arrow is visible, fresh unification
-- variables are generated for the remaining arguments and the result, and
-- the rest of the type is required to equal that function type.
unfoldFunTy :: Range -> Int -> CType -> TCM ([CType], CType)
unfoldFunTy rng n ty
  | n <= 0 = return ([], ty)
  | TFun _ t1 t2 <- ty = first (t1 :) <$> unfoldFunTy rng (n - 1) t2
  | otherwise = do
      vars <- replicateM n (genUniVar (KType ()))
      core <- genUniVar (KType ())
      emit $ CEq (foldr (TFun (KType ())) core vars) ty rng
      return (vars, core)
-- | Solve all constraints: kinds first, then types. The kind solution is
-- also applied inside the type solution.
solveConstrs :: MonadRaise m => Bag Constr -> m (Map Int CKind, Map Int CType)
solveConstrs constrs = do
  subK <- solveKindVars kindConstrs
  subT <- solveTypeVars typeConstrs
  return (subK, Map.map (substType subK mempty mempty) subT)
  where
    (typeConstrs, kindConstrs) = partitionConstrs constrs
-- | Solve kind equality constraints. Mismatches and recursive kinds are
-- reported as errors. Kind unification variables that occur in the solution
-- without being assigned themselves are defaulted to Type.
solveKindVars :: MonadRaise m => Bag (CKind, CKind, Range) -> m (Map Int CKind)
solveKindVars cs = do
  let (asg, errs) =
        solveConstraints reduce (foldMap pure . kindUniVars) substKind uniVarOf kindSize (toList cs)
  mapM_ report errs
  -- default unconstrained kind variables to Type
  let dangling = foldMap kindUniVars (Map.elems asg) Set.\\ Map.keysSet asg
      defaults = Map.fromList [(v, KType ()) | v <- toList dangling]
  return (Map.map (substKind defaults) asg <> defaults)
  where
    uniVarOf :: CKind -> Maybe Int
    uniVarOf (KExt () (KUniVar v)) = Just v
    uniVarOf _ = Nothing

    report err = case err of
      UEUnequal k1 k2 rng ->
        raise SError rng $ "Kind mismatch:\n- " ++ pretty k1 ++ "\n- " ++ pretty k2
      UERecursive uvar k rng ->
        raise SError rng $
          "Kind cannot be recursive: " ++ pretty (KExt () (KUniVar uvar)) ++ " = " ++ pretty k

    reduce :: CKind -> CKind -> Range -> (Bag (Int, CKind, Range), Bag (CKind, CKind, Range))
    reduce lhs rhs rng = case (lhs, rhs) of
      -- a unification variable equal to itself says nothing
      (KExt () (KUniVar i), KExt () (KUniVar j)) | i == j -> mempty
      -- otherwise a unification variable yields an assignment candidate
      (KExt () (KUniVar i), k) -> (pure (i, k, rng), mempty)
      (k, KExt () (KUniVar i)) -> (pure (i, k, rng), mempty)
      -- same outer shape: recurse into the components
      (KType (), KType ()) -> mempty
      (KFun () a b, KFun () c d) -> reduce a c rng <> reduce b d rng
      -- anything else is a kind mismatch
      (k1, k2) -> (mempty, pure (k1, k2, rng))

    kindSize :: CKind -> Int
    kindSize = \case
      KType{} -> 1
      KFun () a b -> 1 + kindSize a + kindSize b
      KExt () KUniVar{} -> 2
-- | Solve type equality constraints. Mismatches and recursive types are
-- reported as errors.
solveTypeVars :: MonadRaise m => Bag (CType, CType, Range) -> m (Map Int CType)
solveTypeVars cs = do
  let (asg, errs) =
        solveConstraints
          reduce
          (foldMap pure . typeUniVars)
          (\m -> substType mempty m mempty)
          (\case TExt _ (TUniVar v) -> Just v
                 _ -> Nothing)
          typeSize
          (toList cs)
  forM_ errs $ \case
    UEUnequal t1 t2 rng ->
      raise SError rng $
        "Type mismatch:\n- " ++ pretty t1 ++ "\n- " ++ pretty t2
    UERecursive uvar t rng ->
      raise SError rng $
        "Type cannot be recursive: " ++ pretty (TExt (extOf t) (TUniVar uvar)) ++ " = " ++ pretty t
  return asg
  where
    reduce :: CType -> CType -> Range -> (Bag (Int, CType, Range), Bag (CType, CType, Range))
    reduce lhs rhs rng = case (lhs, rhs) of
      -- unification variables produce constraints on a unification variable
      (TExt _ (TUniVar i), TExt _ (TUniVar j)) | i == j -> mempty
      (TExt _ (TUniVar i), t) -> (pure (i, t, rng), mempty)
      (t, TExt _ (TUniVar i)) -> (pure (i, t, rng), mempty)
      -- Applications are aligned from the right. Surplus leading arguments
      -- on one side are folded into that side's head, so @?f Int@ against
      -- @Either String Int@ gives @?f := Either String@. Plain zipping would
      -- pair up the wrong arguments.
      (TApp k t ts, TApp k' t' ts') ->
        let m = min (length ts) (length ts')
            (hd, args) = splitApp m k t ts
            (hd', args') = splitApp m k' t' ts'
        in reduce hd hd' rng <> reduceAll rng args args'
      -- tuples of different arity are a mismatch, not a silent truncation
      (TTup _ ts, TTup _ ts') | length ts == length ts' -> reduceAll rng ts ts'
      (TList _ t, TList _ t') -> reduce t t' rng
      (TFun _ t1 t2, TFun _ t1' t2') -> reduce t1 t1' rng <> reduce t2 t2' rng
      (TCon _ n1, TCon _ n2) | n1 == n2 -> mempty
      (TVar _ n1, TVar _ n2) | n1 == n2 -> mempty
      (TForall _ n1 t1, TForall k n2 t2) ->
        reduce t1 (substType mempty mempty (Map.singleton n2 (TVar k n1)) t2) rng
      -- otherwise, this is a type mismatch
      (t1, t2) -> (mempty, pure (t1, t2, rng))

    -- | Reduce two argument lists of equal length pointwise.
    reduceAll :: Range -> [CType] -> [CType] -> (Bag (Int, CType, Range), Bag (CType, CType, Range))
    reduceAll rng xs ys = foldMap (\(a, b) -> reduce a b rng) (zip xs ys)

    -- | Split the application @hd args@ (of result kind @k@) so that exactly
    -- @m@ trailing arguments remain and the leading ones move into the head.
    -- The partial application gets kind @kind(rest_1) -> ... -> k@.
    -- Assumes @m <= length args@.
    splitApp :: Int -> CKind -> CType -> [CType] -> (CType, [CType])
    splitApp m k hd args =
      case splitAt (length args - m) args of
        ([], rest) -> (hd, rest)
        (pre, rest) -> (TApp (foldr (KFun ()) k (map extOf rest)) hd pre, rest)

    typeSize :: CType -> Int
    typeSize (TApp _ t ts) = typeSize t + sum (map typeSize ts)
    typeSize (TTup _ ts) = sum (map typeSize ts)
    typeSize (TList _ t) = 1 + typeSize t
    typeSize (TFun _ t1 t2) = typeSize t1 + typeSize t2
    typeSize (TCon _ _) = 1
    typeSize (TVar _ _) = 1
    typeSize (TForall _ _ t) = 1 + typeSize t
    typeSize (TExt _ TUniVar{}) = 2
-- | Separate type equalities from kind equalities.
partitionConstrs :: Foldable t => t Constr -> (Bag (CType, CType, Range), Bag (CKind, CKind, Range))
partitionConstrs = foldMap classify
  where
    classify :: Constr -> (Bag (CType, CType, Range), Bag (CKind, CKind, Range))
    classify (CEq t1 t2 r) = (pure (t1, t2, r), mempty)
    classify (CEqK k1 k2 r) = (mempty, pure (k1, k2, r))
-------------------- SUBSTITUTION FUNCTIONS --------------------
-- Each of these takes a subset of the following arguments, always in this
-- order:
-- - an instantiation map for kind unification variables (Map Int CKind)
-- - an instantiation map for type unification variables (Map Int CType)
-- - an instantiation map for type variables (Map Name CType)
-- | Apply kind and type uvar instantiations to a whole program.
substProg :: HasCallStack
          => Map Int CKind -> Map Int CType -> CProgram -> CProgram
substProg mk mt (Program ds fs) =
  Program [substDdef mk mt d | d <- ds] [substFdef mk mt f | f <- fs]
-- | Apply the instantiations to a data definition: its kind, its parameter
-- kinds and its constructor field types.
substDdef :: HasCallStack
          => Map Int CKind -> Map Int CType -> CDataDef -> CDataDef
substDdef mk mt (DataDef k name pars cons) = DataDef k' name pars' cons'
  where
    k' = substKind mk k
    pars' = map (first (substKind mk)) pars
    cons' = map (second (map (substType mk mt mempty))) cons
-- | Apply the instantiations to a function definition. In this stage a
-- function always carries a 'TypeSig'.
substFdef :: HasCallStack
          => Map Int CKind -> Map Int CType -> CFunDef -> CFunDef
substFdef mk mt (FunDef t n (TypeSig sig) eqs) =
  FunDef (sub t) n (TypeSig (sub sig)) (substFunEq mk mt <$> eqs)
  where
    sub = substType mk mt mempty
-- | Apply the instantiations to one function equation.
substFunEq :: HasCallStack
           => Map Int CKind -> Map Int CType -> CFunEq -> CFunEq
substFunEq mk mt (FunEq () n ps rhs) = FunEq () n ps' rhs'
  where
    ps' = [substPattern mk mt p | p <- ps]
    rhs' = substRHS mk mt rhs
-- | Apply the instantiations to a right-hand side. Guarded right-hand sides
-- are not supported.
substRHS :: HasCallStack
         => Map Int CKind -> Map Int CType -> CRHS -> CRHS
substRHS mk mt = \case
  Guarded _ _ -> error "typecheck: guards unsupported"
  Plain t e -> Plain (substType mk mt mempty t) (substExpr mk mt e)
-- | Apply the instantiations to every type annotation in a pattern.
substPattern :: HasCallStack
             => Map Int CKind -> Map Int CType -> CPattern -> CPattern
substPattern mk mt = walk
  where
    sub = substType mk mt mempty
    walk = \case
      PWildcard t -> PWildcard (sub t)
      PVar t n -> PVar (sub t) n
      PAs t n p -> PAs (sub t) n (walk p)
      PCon t n ps -> PCon (sub t) n (map walk ps)
      POp t p1 op p2 -> POp (sub t) (walk p1) op (walk p2)
      PList t ps -> PList (sub t) (map walk ps)
      PTup t ps -> PTup (sub t) (map walk ps)
-- | Apply the instantiations to every type annotation in an expression,
-- including nested patterns, right-hand sides and local definitions.
substExpr :: HasCallStack
          => Map Int CKind -> Map Int CType -> CExpr -> CExpr
substExpr mk mt = walk
  where
    sub = substType mk mt mempty
    walk = \case
      ELit t lit -> ELit (sub t) lit
      EVar t n -> EVar (sub t) n
      ECon t n -> ECon (sub t) n
      EList t es -> EList (sub t) (map walk es)
      ETup t es -> ETup (sub t) (map walk es)
      EApp t e1 es -> EApp (sub t) (walk e1) (map walk es)
      EOp t e1 op e2 -> EOp (sub t) (walk e1) op (walk e2)
      EIf t e1 e2 e3 -> EIf (sub t) (walk e1) (walk e2) (walk e3)
      ECase t e1 alts ->
        ECase (sub t) (walk e1) (map (bimap (substPattern mk mt) (substRHS mk mt)) alts)
      ELet t defs body -> ELet (sub t) (map (substFdef mk mt) defs) (walk body)
      EError t -> EError (sub t)
-- | Substitute in a type: kind unification variables (everywhere kinds
-- occur), type unification variables, and named type variables. Replacement
-- types are inserted as-is, without being substituted further.
substType :: HasCallStack
          => Map Int CKind -> Map Int CType -> Map Name CType -> CType -> CType
substType mk mt mtv = go
  where
    go (TApp k t ts) = TApp (goKind k) (go t) (map go ts)
    go (TTup k ts) = TTup (goKind k) (map go ts)
    go (TList k t) = TList (goKind k) (go t)
    go (TFun k t1 t2) = TFun (goKind k) (go t1) (go t2)
    go (TCon k n) = TCon (goKind k) n
    go (TVar k n) = fromMaybe (TVar (goKind k) n) (Map.lookup n mtv)
    -- The forall binds @n@ in its body, so a substitution for a type variable
    -- of that name must not reach inside. NOTE(review): replacements are not
    -- renamed, so one that mentions @n@ itself would still be captured.
    go (TForall k n t) = TForall (goKind k) n (substType mk mt (Map.delete n mtv) t)
    go (TExt k (TUniVar v)) = fromMaybe (TExt (goKind k) (TUniVar v)) (Map.lookup v mt)
    goKind = substKind mk
-- | Replace kind unification variables using the given instantiation map.
-- Variables not in the map are left alone.
substKind :: HasCallStack
          => Map Int CKind -> CKind -> CKind
substKind m = walk
  where
    walk (KType ()) = KType ()
    walk (KFun () k1 k2) = KFun () (walk k1) (walk k2)
    walk k@(KExt () (KUniVar v)) = Map.findWithDefault k v m
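-------------------- FINALISATION FUNCTIONS --------------------
-- The finalise* functions (e.g. finaliseProg, used by tcTop) are meant to
-- report any type unification variables left free after solving. The helpers
-- below collect the unification variables occurring in a type or kind.
-- TODO: implement the finalise* functions.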
-- | All type unification variables occurring in a type. Kinds annotating
-- the type are not inspected.
typeUniVars :: CType -> Set Int
typeUniVars (TApp _ t ts) = foldMap typeUniVars (t : ts)
typeUniVars (TTup _ ts) = foldMap typeUniVars ts
typeUniVars (TList _ t) = typeUniVars t
typeUniVars (TFun _ t1 t2) = typeUniVars t1 `Set.union` typeUniVars t2
typeUniVars TCon{} = Set.empty
typeUniVars TVar{} = Set.empty
typeUniVars (TForall _ _ t) = typeUniVars t
typeUniVars (TExt _ (TUniVar v)) = Set.singleton v
-- | All kind unification variables occurring in a kind.
kindUniVars :: CKind -> Set Int
kindUniVars KType{} = Set.empty
kindUniVars (KFun () a b) = kindUniVars a `Set.union` kindUniVars b
kindUniVars (KExt () (KUniVar v)) = Set.singleton v
-- | Whether every element equals the first one. True for an empty container.
allEq :: (Eq a, Foldable t) => t a -> Bool
allEq container = check (toList container)
  where
    check [] = True
    check (x : rest) = and [y == x | y <- rest]
-- | The argument patterns of a function equation.
funeqPats :: FunEq t -> [Pattern t]
funeqPats eq = case eq of
  FunEq _ _ pats _ -> pats
-- | Select kind equality constraints (for use with 'collectConstraints').
isCEqK :: Constr -> Maybe (CKind, CKind, Range)
isCEqK = \case
  CEqK k1 k2 rng -> Just (k1, k2, rng)
  CEq{} -> Nothing
-- | Effectful 'foldMap': run the action on each element from left to right
-- and combine the results monoidally.
foldMapM :: (Applicative f, Monoid m, Foldable t) => (a -> f m) -> t a -> f m
foldMapM f = foldr (\x acc -> (<>) <$> f x <*> acc) (pure mempty)
|