1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
|
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE EmptyDataDeriving #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TupleSections #-}
module HSVIS.Typecheck where
import Control.Monad
import Data.Bifunctor (first)
import Data.Foldable (toList)
import Data.Map.Strict (Map)
import Data.Maybe (fromMaybe)
import Data.Monoid (Ap(..))
import qualified Data.Map.Strict as Map
import Data.Set (Set)
import qualified Data.Set as Set
import Debug.Trace
import Data.Bag
import Data.List.NonEmpty.Util
import HSVIS.AST
import HSVIS.Parser
import HSVIS.Diagnostic
import HSVIS.Pretty
import HSVIS.Typecheck.Solve
-- | Stage tag for the AST produced by the constraint-generation pass (the
-- "C"-prefixed aliases below). Types and kinds in this stage may contain
-- unification variables.
data StageTC
-- Annotation stored in each kind of AST node at this stage.
-- NOTE(review): the families 'X' and 'E' are not defined in this file
-- (presumably they come from HSVIS.AST) -- confirm their meaning there.
type instance X DataDef StageTC = ()
type instance X FunDef StageTC = CType
type instance X FunEq StageTC = CType
type instance X Kind StageTC = ()
type instance X Type StageTC = CKind
type instance X Pattern StageTC = CType
type instance X RHS StageTC = CType
type instance X Expr StageTC = CType
-- Stage-specific extra constructors: unification variables for types and for
-- kinds, each identified by an Int. TypeSig has no extra constructors here.
data instance E Type StageTC = TUniVar Int deriving (Show)
data instance E Kind StageTC = KUniVar Int deriving (Show, Eq, Ord)
data instance E TypeSig StageTC deriving (Show)
-- Shorthand names for the AST types instantiated at 'StageTC'.
type CProgram = Program StageTC
type CDataDef = DataDef StageTC
type CFunDef = FunDef StageTC
type CFunEq = FunEq StageTC
type CKind = Kind StageTC
type CType = Type StageTC
type CPattern = Pattern StageTC
type CRHS = RHS StageTC
type CExpr = Expr StageTC
-- | Stage tag for the AST returned by 'typecheck' (the "T"-prefixed aliases
-- below). All extension data instances for this stage are empty, so no
-- unification variables can occur in it.
data StageTyped
-- Annotation stored in each kind of AST node at this stage.
type instance X DataDef StageTyped = TType
type instance X FunDef StageTyped = TType
type instance X FunEq StageTyped = TType
type instance X Kind StageTyped = ()
type instance X Type StageTyped = TKind
type instance X Pattern StageTyped = TType
type instance X RHS StageTyped = TType
type instance X Expr StageTyped = TType
-- No stage-specific extra constructors (the data instances have no
-- constructors at all).
data instance E Type StageTyped deriving (Show)
data instance E Kind StageTyped deriving (Show)
data instance E TypeSig StageTyped deriving (Show)
-- Shorthand names for the AST types instantiated at 'StageTyped'.
type TProgram = Program StageTyped
type TDataDef = DataDef StageTyped
type TFunDef = FunDef StageTyped
type TFunEq = FunEq StageTyped
type TKind = Kind StageTyped
type TType = Type StageTyped
type TPattern = Pattern StageTyped
type TRHS = RHS StageTyped
type TExpr = Expr StageTyped
-- | Kind unification variables are rendered as @?k@ followed by their id.
instance Pretty (E Kind StageTC) where
  prettysPrec _ (KUniVar n) = showString "?k" . shows n
-- | Typecheck a parsed program.
--
-- Runs constraint generation ('tcProgram') starting with id counter 1 and an
-- empty environment, solves the collected constraints, and applies the
-- resulting substitution to the program. Diagnostics from generation come
-- before those from solving.
typecheck :: FilePath -> String -> PProgram -> ([Diagnostic], TProgram)
typecheck fp source prog =
  (toList (genDiags <> solveDiags), substProg subst checked)
  where
    -- These bindings are lazy, exactly like a let-bound tuple pattern.
    (genDiags, constraints, _, _, checked) =
      runTCM (tcProgram prog) (fp, source) 1 emptyEnv
    (solveDiags, subst) = solveConstrs constraints
    emptyEnv = Env mempty mempty
-- A constraint emitted while traversing the program, to be solved later.
-- NOTE(review): the CEqK constraints emitted by kcType put the kind of the
-- checked type on the left and the required kind on the right, which looks
-- like the opposite of the convention described below -- confirm which
-- orientation is intended.
data Constr
  -- Equality constraints: "left" must be equal to "right" because of the thing
  -- at the given range. "left" is the expected thing; "right" is the observed
  -- thing.
  = CEq CType CType Range
  | CEqK CKind CKind Range
  deriving (Show)
-- | The typechecking environment. The first map gives the kind of each type
-- name in scope (queried by 'getKind'); the second gives the type of each
-- value name in scope (queried by 'getType').
data Env = Env (Map Name CKind) (Map Name CType)
  deriving (Show)
-- | The typechecker monad, written out as a plain function: it reads the file
-- context, threads the id counter and the environment as state, and
-- accumulates diagnostics and constraints as output.
newtype TCM a = TCM {
    runTCM :: (FilePath, String)  -- ^ reader context: file and file contents
           -> Int  -- ^ state: next id to generate
           -> Env  -- ^ state: type and value environment
           -> (Bag Diagnostic  -- ^ writer: diagnostics
              ,Bag Constr  -- ^ writer: constraints
              ,Int, Env, a)
  }
-- | Applies the function to the result value only; diagnostics, constraints,
-- id counter and environment pass through unchanged.
instance Functor TCM where
  fmap f (TCM run) = TCM $ \ctx n env -> mapResult (run ctx n env)
    where
      -- The lazy pattern keeps the same laziness as a let-bound tuple.
      mapResult ~(diags, constrs, n', env', x) = (diags, constrs, n', env', f x)
-- | 'pure' produces no diagnostics or constraints and leaves the state
-- untouched; '<*>' runs the function computation first, then the argument.
instance Applicative TCM where
  pure x = TCM (\_ n env -> (mempty, mempty, n, env, x))
  mf <*> mx = do
    f <- mf
    x <- mx
    return (f x)
-- | Sequencing: the second computation sees the id counter and environment
-- left behind by the first; diagnostics and constraints of both are appended
-- in order.
instance Monad TCM where
  TCM step >>= next = TCM $ \ctx n0 env0 ->
    let (diagsA, constrsA, n1, env1, a) = step ctx n0 env0
        (diagsB, constrsB, n2, env2, b) = runTCM (next a) ctx n1 env1
    in (diagsA <> diagsB, constrsA <> constrsB, n2, env2, b)
-- | Record a diagnostic with the given message at the given range.
--
-- The diagnostic carries the source line on which the range starts, found by
-- using the start position's line number as a zero-based index into the lines
-- of the file contents. If that index does not refer to an existing line (for
-- example a range positioned just past the end of the file), an empty line is
-- used instead of crashing when the diagnostic is inspected.
raise :: Range -> String -> TCM ()
raise rng@(Range (Pos y _) _) msg = TCM $ \(fp, source) i env ->
  (pure (Diagnostic fp rng [] (lineAt source) msg), mempty, i, env, ())
  where
    -- Total replacement for @lines src !! y@.
    lineAt src
      | y < 0 = ""
      | otherwise = case drop y (lines src) of
          l : _ -> l
          [] -> ""
-- | Record a single constraint; the state is left unchanged.
emit :: Constr -> TCM ()
emit constr = TCM (\_ n env -> (mempty, pure constr, n, env, ()))
-- | Run a computation and intercept some of the constraints it emitted.
--
-- The emitted constraints are split with 'bagPartition': those for which the
-- predicate returns a 'Just' are returned (as the 'Just' payloads) alongside
-- the result, and only the remaining ones stay in the constraint output.
collectConstraints :: (Constr -> Maybe b) -> TCM a -> TCM (Bag b, a)
collectConstraints predicate (TCM run) = TCM $ \ctx n env -> split (run ctx n env)
  where
    -- Lazy patterns keep the same laziness as let-bound tuples.
    split ~(diags, constrs, n', env', x) =
      let ~(selected, remaining) = bagPartition predicate constrs
      in (diags, remaining, n', env', (selected, x))
-- | Read the current environment.
getFullEnv :: TCM Env
getFullEnv = TCM (\_ n env -> (mempty, mempty, n, env, env))
-- | Replace the current environment wholesale.
putFullEnv :: Env -> TCM ()
putFullEnv newEnv = TCM (\_ n _ -> (mempty, mempty, n, newEnv, ()))
-- | Generate a fresh identifier.
--
-- Returns the current value of the id counter and advances the counter by
-- one, so that successive calls yield distinct ids. (Previously the counter
-- was passed on unchanged, so every call returned the same id and all
-- unification variables collided.)
genId :: TCM Int
genId = TCM $ \_ i env -> (mempty, mempty, i + 1, env, i)
-- | Look up the kind of a type name in the type environment, if present.
getKind :: Name -> TCM (Maybe CKind)
getKind name =
  getFullEnv >>= \(Env kinds _) -> return (Map.lookup name kinds)
-- | Look up the type of a value name in the value environment, if present.
getType :: Name -> TCM (Maybe CType)
getType name =
  getFullEnv >>= \(Env _ types) -> return (Map.lookup name types)
-- | Apply a function to the type environment (names to kinds), leaving the
-- value environment as it is.
modifyTEnv :: (Map Name CKind -> Map Name CKind) -> TCM ()
modifyTEnv f =
  getFullEnv >>= \(Env kinds types) -> putFullEnv (Env (f kinds) types)
-- | Apply a function to the value environment (names to types), leaving the
-- type environment as it is.
modifyVEnv :: (Map Name CType -> Map Name CType) -> TCM ()
modifyVEnv f =
  getFullEnv >>= \(Env kinds types) -> putFullEnv (Env kinds (f types))
-- | Run a computation and afterwards reset the type environment to what it
-- was before. Changes the computation made to the value environment are kept.
scopeTEnv :: TCM a -> TCM a
scopeTEnv action =
  getFullEnv >>= \(Env saved _) ->
    action >>= \result ->
      modifyTEnv (const saved) >> return result
-- | Run a computation and afterwards reset the value environment to what it
-- was before. Changes the computation made to the type environment are kept.
scopeVEnv :: TCM a -> TCM a
scopeVEnv action =
  getFullEnv >>= \(Env _ saved) ->
    action >>= \result ->
      modifyVEnv (const saved) >> return result
-- | Create a kind unification variable from a newly generated id.
genKUniVar :: TCM CKind
genKUniVar = fmap (\n -> KExt () (KUniVar n)) genId
-- | Create a type unification variable from a newly generated id, annotated
-- with the given kind.
genUniVar :: CKind -> TCM CType
genUniVar k = fmap (\n -> TExt k (TUniVar n)) genId
-- | Like 'getKind', but if the type name is not in scope, raise a diagnostic
-- at the given range and return a new kind unification variable instead.
getKind' :: Range -> Name -> TCM CKind
getKind' rng name = getKind name >>= maybe notInScope return
  where
    notInScope =
      raise rng ("Type not in scope: " ++ pretty name) >> genKUniVar
-- | Like 'getType', but if the variable is not in scope, raise a diagnostic
-- at the given range and return a new type unification variable (of kind
-- @KType ()@) instead.
getType' :: Range -> Name -> TCM CType
getType' rng name = getType name >>= maybe notInScope return
  where
    notInScope =
      raise rng ("Variable not in scope: " ++ pretty name) >> genUniVar (KType ())
-- | Generate constraints for a whole program.
--
-- First all data definitions are registered ('prepareDataDef') and then
-- checked ('tcDataDef'); the kind equality constraints emitted while doing so
-- are intercepted and handed to 'solveKindVars'. After that, the function
-- definitions are checked.
tcProgram :: PProgram -> TCM CProgram
tcProgram (Program dataDefs funDefs) = do
  (kindConstrs, dataDefs') <- collectConstraints isCEqK checkDataDefs
  solveKindVars kindConstrs
  funDefs' <- forM funDefs tcFunDef
  return (Program dataDefs' funDefs')
  where
    checkDataDefs = do
      forM_ dataDefs prepareDataDef
      forM dataDefs tcDataDef
-- | Register a data type's name in the type environment, before its
-- constructors are checked. Its kind is a function kind that takes one new
-- kind unification variable per type parameter and results in @KType ()@.
prepareDataDef :: PDataDef -> TCM ()
prepareDataDef (DataDef _ name params _) = do
  paramKinds <- forM params (const genKUniVar)
  modifyTEnv (Map.insert name (foldr (KFun ()) (KType ()) paramKinds))
-- | Kind-check the constructor fields of a data definition.
--
-- Assumes that the kind of the name itself has already been registered with
-- the correct arity (this is done by prepareDataDef). The type parameters are
-- brought into scope, with the parameter kinds from that registered kind, only
-- while the constructor fields are being checked.
tcDataDef :: PDataDef -> TCM CDataDef
tcDataDef (DataDef rng name params cons) = do
  declKind <- getKind' rng name
  let (paramKinds, resultKind) = splitKind declKind
      paramNames = map snd params
  -- sanity checking; would be nicer to store these in prepareDataDef already
  unless (length paramKinds == length params) $
    error "tcDataDef: Invalid param kind list length"
  case resultKind of
    Right () -> return ()
    _ -> error "tcDataDef: Invalid ret kind"
  cons' <- scopeTEnv $ do
    modifyTEnv (Map.fromList (zip paramNames paramKinds) <>)
    forM cons $ \(cname, fieldTys) -> (cname,) <$> mapM kcType fieldTys
  return (DataDef () name (zip paramKinds paramNames) cons')
-- | Kind-check a parsed type: annotate every node with its kind and emit the
-- kind equality constraints that must hold.
--
-- * Application: the head's kind must equal a function kind from the
--   argument kinds to a new kind variable, which is the kind of the result.
-- * Tuples, lists and function types: every component must have kind
--   @KType ()@, and so has the whole.
-- * Constructors and variables: the kind is looked up in the environment.
kcType :: PType -> TCM CType
kcType ty = case ty of
  TApp rng fun args -> do
    fun' <- kcType fun
    args' <- mapM kcType args
    resultKind <- genKUniVar
    emit (CEqK (extOf fun') (foldr (KFun () . extOf) resultKind args') rng)
    return (TApp resultKind fun' args')

  TTup _ elems -> do
    elems' <- mapM kcType elems
    forM_ (zip elems elems') (uncurry requireKType)
    return (TTup (KType ()) elems')

  TList _ el -> do
    el' <- kcType el
    requireKType el el'
    return (TList (KType ()) el')

  TFun _ arg res -> do
    arg' <- kcType arg
    res' <- kcType res
    requireKType arg arg'
    requireKType res res'
    return (TFun (KType ()) arg' res')

  TCon rng n -> TCon <$> getKind' rng n <*> pure n

  TVar rng n -> TVar <$> getKind' rng n <*> pure n
  where
    -- Constrain the checked type to have kind @KType ()@, blaming the range
    -- of the parsed type it came from.
    requireKType parsed checked =
      emit (CEqK (extOf checked) (KType ()) (extOf parsed))
-- | Check a function definition.
--
-- Raises a diagnostic if the equations do not all have the same number of
-- argument patterns. The function's type is its kind-checked signature if it
-- has one, and a new unification variable otherwise; every equation is then
-- checked against that type.
tcFunDef :: PFunDef -> TCM CFunDef
tcFunDef (FunDef _ name msig eqs) = do
  let arities = fmap (length . funeqPats) eqs
  unless (allEq arities) $
    raise (sconcatne (fmap extOf eqs)) "Function equations have differing numbers of arguments"
  typ <- case msig of
    TypeSig sig -> kcType sig
    TypeSigExt NoTypeSig -> genUniVar (KType ())
  eqs' <- forM eqs (tcFunEq typ)
  return (FunDef typ name (TypeSig typ) eqs')
-- | Check a single function equation against the function's type.
-- Not implemented yet: this is a placeholder that calls 'error'.
tcFunEq :: CType -> PFunEq -> TCM CFunEq
tcFunEq = error "tcFunEq"
-- | A computation in the monad @m@ that threads two maps keyed by @v@: one
-- mapping each key to a bag of @t@ (the "state" in the helper names below) and
-- one mapping each key to a single @t@ (the "log" in the helper names below).
newtype SolveM v t m a = SolveM (Map v (Bag t) -> Map v t -> m (a, Map v (Bag t), Map v t))
-- | Applies the function to the result value; both maps pass through.
instance Monad m => Functor (SolveM v t m) where
  fmap f (SolveM g) = SolveM $ \st lg ->
    g st lg >>= \(x, st', lg') -> return (f x, st', lg')
-- | 'pure' returns the value without touching either map; '<*>' runs the
-- function computation first, then the argument.
instance Monad m => Applicative (SolveM v t m) where
  pure x = SolveM (\st lg -> return (x, st, lg))
  sf <*> sx = do
    f <- sf
    x <- sx
    return (f x)
-- | Sequencing: the continuation runs on the maps left by the first step.
instance Monad m => Monad (SolveM v t m) where
  SolveM f >>= k = SolveM $ \st lg ->
    f st lg >>= \(x, st1, lg1) -> unwrap (k x) st1 lg1
    where
      unwrap (SolveM h) = h
-- | Read the current state map.
solvemStateGet :: Monad m => SolveM v t m (Map v (Bag t))
solvemStateGet = SolveM (\st lg -> return (st, st, lg))
-- | Apply a function to the state map; the log map is left alone.
solvemStateUpdate :: Monad m => (Map v (Bag t) -> Map v (Bag t)) -> SolveM v t m ()
solvemStateUpdate f = SolveM (\st lg -> return ((), f st, lg))
-- | Apply a function to the log map; the state map is left alone.
solvemLogUpdate :: Monad m => (Map v t -> Map v t) -> SolveM v t m ()
solvemLogUpdate f = SolveM (\st lg -> return ((), st, f lg))
-- | All keys currently present in the state map.
solvemStateVars :: Monad m => SolveM v t m [v]
solvemStateVars = fmap Map.keys solvemStateGet
-- | The bag stored for the given key in the state map, or the empty bag if
-- the key is absent.
solvemStateRHS :: (Ord v, Monad m) => v -> SolveM v t m (Bag t)
solvemStateRHS v = Map.findWithDefault mempty v <$> solvemStateGet
-- | Store a bag for the given key in the state map, replacing any previous
-- entry for that key.
solvemStateSet :: (Ord v, Monad m) => v -> Bag t -> SolveM v t m ()
solvemStateSet v b = solvemStateUpdate $ \st -> Map.insert v b st
-- | Record a value for the given key in the log map, replacing any previous
-- entry for that key.
solvemLogEq :: (Ord v, Monad m) => v -> t -> SolveM v t m ()
solvemLogEq v t = solvemLogUpdate $ \lg -> Map.insert v t lg
-- | Process the kind equality constraints collected while checking the data
-- definitions.
--
-- NOTE(review): at present this only prints the constraints, and the result
-- of 'solveConstraints', via 'traceShowM'; that result is not used to raise
-- diagnostics or to substitute solved kinds anywhere.
-- NOTE(review): 'solveConstraints' is not defined in this file (presumably it
-- comes from HSVIS.Typecheck.Solve); the notes on its arguments below only
-- describe the values passed here -- confirm their roles at its definition.
solveKindVars :: Bag (CKind, CKind, Range) -> TCM ()
solveKindVars cs = do
  traceShowM cs
  traceShowM $ solveConstraints
    reduce
    (foldMap pure . kindUniVars)  -- the unification variables occurring in a kind
    (\v repl -> substKind (Map.singleton v repl))  -- replace variable v by repl in a kind
    (\case KExt () (KUniVar v) -> Just v  -- is this kind a bare unification variable?
           _ -> Nothing)
    kindSize
    (map (\(a, b, _) -> (a, b)) (toList cs))  -- the constraints without their ranges
  where
    -- Break an equation between two kinds down into assignments to
    -- unification variables (first component) and pairs of kinds that cannot
    -- be equal (second component).
    reduce :: CKind -> CKind -> (Bag (Int, CKind), Bag (CKind, CKind))
    -- unification variables produce constraints on a unification variable
    reduce (KExt () (KUniVar i)) (KExt () (KUniVar j)) | i == j = mempty
    reduce (KExt () (KUniVar i)) k = (pure (i, k), mempty)
    reduce k (KExt () (KUniVar i)) = (pure (i, k), mempty)
    -- if lhs and rhs have equal prefixes, recurse
    reduce (KType ()) (KType ()) = mempty
    reduce (KFun () a b) (KFun () c d) = reduce a c <> reduce b d
    -- otherwise, this is a kind mismatch
    reduce k1 k2 = (mempty, pure (k1, k2))

    -- Number of nodes in a kind; a unification variable counts as one node.
    kindSize :: CKind -> Int
    kindSize KType{} = 1
    kindSize (KFun () a b) = 1 + kindSize a + kindSize b
    kindSize (KExt () KUniVar{}) = 1
-- | Solve the collected constraints, producing diagnostics and a
-- substitution. Not implemented yet: this is a placeholder that calls 'error'.
solveConstrs :: Bag Constr -> (Bag Diagnostic, Map Name TType)
solveConstrs = error "solveConstrs"
-- | Apply a substitution to a constraint-stage program, yielding a typed
-- program. Not implemented yet: this is a placeholder that calls 'error'.
substProg :: Map Name TType -> CProgram -> TProgram
substProg = error "substProg"
-- | Replace kind unification variables according to the given map. Variables
-- without an entry in the map are left as they are. Replacement kinds are
-- inserted as-is; they are not substituted into again.
substKind :: Map Int CKind -> CKind -> CKind
substKind subst kind = case kind of
  KType{} -> kind
  KFun () a b -> KFun () (substKind subst a) (substKind subst b)
  KExt () (KUniVar v) -> Map.findWithDefault kind v subst
-- | The ids of all kind unification variables occurring in a kind.
kindUniVars :: CKind -> Set Int
kindUniVars KType{} = Set.empty
kindUniVars (KFun () a b) = Set.union (kindUniVars a) (kindUniVars b)
kindUniVars (KExt () (KUniVar v)) = Set.singleton v
-- | Whether all elements of a container are equal to each other. Holds
-- trivially for empty and single-element containers. Every later element is
-- compared against the first one.
allEq :: (Eq a, Foldable t) => t a -> Bool
allEq l
  | x : rest <- toList l = all (== x) rest
  | otherwise = True
-- | The argument patterns of a function equation.
funeqPats :: FunEq t -> [Pattern t]
funeqPats eq = case eq of
  FunEq _ _ pats _ -> pats
-- | Split a kind into its list of argument kinds and its final result.
--
-- Function kinds are unrolled along their right-hand side. The result is
-- 'Right' with the annotation of a 'KType' node, or 'Left' with the payload
-- of a 'KExt' node.
splitKind :: Kind s -> ([Kind s], Either (E Kind s) (X Kind s))
splitKind kind = case kind of
  KType x -> ([], Right x)
  KFun _ arg rest ->
    -- A lazy let-pattern, matching the laziness of 'first' on pairs.
    let (args, ret) = splitKind rest
    in (arg : args, ret)
  KExt _ e -> ([], Left e)
-- | Select kind equality constraints: returns the fields of a 'CEqK' and
-- 'Nothing' for any other constraint.
isCEqK :: Constr -> Maybe (CKind, CKind, Range)
isCEqK = \case
  CEqK k1 k2 rng -> Just (k1, k2, rng)
  _ -> Nothing
-- | Like 'foldMap', but the mapping function runs in an applicative functor:
-- the effects are sequenced and the monoidal results combined.
foldMapM :: (Applicative f, Monoid m, Foldable t) => (a -> f m) -> t a -> f m
foldMapM f xs = getAp (foldMap (\x -> Ap (f x)) xs)
|