-- Source: src/HSVIS/Typecheck.hs
-- (blob ba853a003ba6584316ffac06f6ba66b4c29f14b0)
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE EmptyDataDeriving #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TupleSections #-}
module HSVIS.Typecheck where

import Control.Monad
import Data.Bifunctor (first)
import Data.Foldable (toList)
import Data.Map.Strict (Map)
import Data.Maybe (fromMaybe)
import Data.Monoid (Ap(..))
import qualified Data.Map.Strict as Map
import Data.Set (Set)
import qualified Data.Set as Set

import Debug.Trace

import Data.Bag
import Data.List.NonEmpty.Util
import HSVIS.AST
import HSVIS.Parser
import HSVIS.Diagnostic
import HSVIS.Pretty
import HSVIS.Typecheck.Solve


-- Stage tag for the AST while it is being typechecked. The type has no
-- constructors; it is only used as an index for the 'X' and 'E' families.
data StageTC

-- Annotation stored in each AST node at this stage: data definitions and
-- kinds carry nothing, types carry their kind, and function definitions,
-- equations, patterns, right-hand sides and expressions carry a type.
type instance X DataDef StageTC = ()
type instance X FunDef  StageTC = CType
type instance X FunEq   StageTC = CType
type instance X Kind    StageTC = ()
type instance X Type    StageTC = CKind
type instance X Pattern StageTC = CType
type instance X RHS     StageTC = CType
type instance X Expr    StageTC = CType

-- Extension constructors available at this stage: types and kinds gain
-- unification variables identified by an Int; type signatures gain no extra
-- constructor.
data instance E Type StageTC = TUniVar Int deriving (Show)
data instance E Kind StageTC = KUniVar Int deriving (Show, Eq, Ord)
data instance E TypeSig StageTC deriving (Show)

-- Shorthands for the AST types instantiated at the typechecking stage.
type CProgram = Program StageTC
type CDataDef = DataDef StageTC
type CFunDef  = FunDef  StageTC
type CFunEq   = FunEq   StageTC
type CKind    = Kind    StageTC
type CType    = Type    StageTC
type CPattern = Pattern StageTC
type CRHS     = RHS     StageTC
type CExpr    = Expr    StageTC

-- Stage tag for the AST that 'typecheck' returns. Like 'StageTC' it has no
-- constructors and only serves as an index for the 'X' and 'E' families.
data StageTyped

-- Annotation stored in each AST node at this stage. Compared to 'StageTC',
-- data definitions also carry a type, and all annotations are expressed in
-- the typed-stage 'TType' / 'TKind'.
type instance X DataDef StageTyped = TType
type instance X FunDef  StageTyped = TType
type instance X FunEq   StageTyped = TType
type instance X Kind    StageTyped = ()
type instance X Type    StageTyped = TKind
type instance X Pattern StageTyped = TType
type instance X RHS     StageTyped = TType
type instance X Expr    StageTyped = TType

-- No extension constructors exist at this stage, so in particular no
-- unification variables can be represented in typed-stage types and kinds.
data instance E Type StageTyped deriving (Show)
data instance E Kind StageTyped deriving (Show)
data instance E TypeSig StageTyped deriving (Show)

-- Shorthands for the AST types instantiated at the typed stage.
type TProgram = Program StageTyped
type TDataDef = DataDef StageTyped
type TFunDef  = FunDef  StageTyped
type TFunEq   = FunEq   StageTyped
type TKind    = Kind    StageTyped
type TType    = Type    StageTyped
type TPattern = Pattern StageTyped
type TRHS     = RHS     StageTyped
type TExpr    = Expr    StageTyped

-- Kind unification variables are rendered as "?k" followed by their id; the
-- precedence argument is ignored.
instance Pretty (E Kind StageTC) where
  prettysPrec _ (KUniVar ident) = showString "?k" . shows ident


-- Typecheck a parsed program.
--
-- Takes the file path and the full source text (both are only used for
-- building diagnostics) plus the parsed program. Runs 'tcProgram' starting
-- from id 1 and an empty environment, solves the collected constraints with
-- 'solveConstrs', and applies the resulting substitution with 'substProg'.
-- Returns the diagnostics of both phases (constraint generation first)
-- together with the typed program.
typecheck :: FilePath -> String -> PProgram -> ([Diagnostic], TProgram)
typecheck fp source prog =
  (toList (genDiags <> solveDiags), substProg subst checked)
  where
    initialEnv = Env mempty mempty
    (genDiags, constrs, _, _, checked) =
      runTCM (tcProgram prog) (fp, source) 1 initialEnv
    (solveDiags, subst) = solveConstrs constrs

-- A constraint emitted during typechecking, to be solved afterwards.
data Constr
  -- Equality constraints: "left" must be equal to "right" because of the thing
  -- at the given range. "left" is the expected thing; "right" is the observed
  -- thing.
  -- CEq relates two types; CEqK relates two kinds.
  = CEq CType CType Range
  | CEqK CKind CKind Range
  deriving (Show)

-- The typechecking environment. The first map gives the kind of each type
-- name in scope (see 'getKind'); the second map gives the type of each value
-- name in scope (see 'getType').
data Env = Env (Map Name CKind) (Map Name CType)
  deriving (Show)

-- The typechecker monad, written out as a plain function: it reads the file
-- context, threads an id counter and an environment as state, and
-- accumulates diagnostics and constraints as output.
newtype TCM a = TCM {
  runTCM :: (FilePath, String)  -- ^ reader context: file and file contents
         -> Int  -- ^ state: next id to generate
         -> Env  -- ^ state: type and value environment
         -> (Bag Diagnostic  -- ^ writer: diagnostics
            ,Bag Constr      -- ^ writer: constraints
            ,Int, Env, a)    -- updated id counter, updated environment, result
  }

-- Map over the result of a computation; diagnostics, constraints and state
-- are passed through untouched.
instance Functor TCM where
  fmap f (TCM run) = TCM step
    where
      step ctx i env = (diags, constrs, next, env', f result)
        where
          -- lazy pattern binding: the underlying computation is only run
          -- when one of the components is demanded
          (diags, constrs, next, env', result) = run ctx i env

-- 'pure' produces no diagnostics or constraints and leaves the state alone;
-- application is derived from the 'Monad' instance.
instance Applicative TCM where
  pure result = TCM (\_ next env -> (mempty, mempty, next, env, result))
  mf <*> mx = ap mf mx

-- Sequencing: the second computation sees the id counter and environment
-- left behind by the first, and the outputs of both are appended in order.
instance Monad TCM where
  m >>= k = TCM $ \ctx i0 env0 ->
    let (dsA, csA, i1, env1, a) = runTCM m ctx i0 env0
        (dsB, csB, i2, env2, b) = runTCM (k a) ctx i1 env1
    in (dsA <> dsB, csA <> csB, i2, env2, b)

-- Report a diagnostic with the given message at the given source range.
--
-- The diagnostic quotes the source line at the row of the range's start
-- position. If that row does not exist in the source text (negative, or past
-- the last line) an empty line is quoted instead; the previous
-- @lines source !! y@ would crash when the diagnostic was rendered.
raise :: Range -> String -> TCM ()
raise rng@(Range (Pos y _) _) msg = TCM $ \(fp, source) i env ->
  (pure (Diagnostic fp rng [] (lineAt y source) msg), mempty, i, env, ())
  where
    -- Total lookup of the n'th (0-based) line of the source text.
    lineAt :: Int -> String -> String
    lineAt n src
      | n < 0 = ""
      | otherwise = case drop n (lines src) of
          l : _ -> l
          [] -> ""

-- Add a single constraint to the constraint output; state is unchanged.
emit :: Constr -> TCM ()
emit constr = TCM (\_ next env -> (mempty, pure constr, next, env, ()))

-- Run a computation and split off some of the constraints it emitted.
--
-- The constraints are partitioned with 'bagPartition' using the predicate;
-- the first component of that partition is returned alongside the result
-- and the second component stays in the constraint output. Diagnostics and
-- state changes of the inner computation are kept as they are.
collectConstraints :: (Constr -> Maybe b) -> TCM a -> TCM (Bag b, a)
collectConstraints predicate action = TCM $ \ctx i env ->
  let (diags, emitted, next, env', result) = runTCM action ctx i env
      (captured, remaining) = bagPartition predicate emitted
  in (diags, remaining, next, env', (captured, result))

-- Read the current environment (both the kind and the type map).
getFullEnv :: TCM Env
getFullEnv = TCM (\_ next env -> (mempty, mempty, next, env, env))

-- Replace the current environment wholesale; the id counter is untouched.
putFullEnv :: Env -> TCM ()
putFullEnv newEnv = TCM (\_ next _ -> (mempty, mempty, next, newEnv, ()))

-- Generate a fresh identifier: returns the current value of the id counter
-- and advances the counter by one, so that successive calls yield distinct
-- ids. (Previously the counter was passed on unchanged, which made every
-- generated unification variable share the same id.)
genId :: TCM Int
genId = TCM $ \_ i env -> (mempty, mempty, i + 1, env, i)

-- Look up the kind of a type name in the current environment; 'Nothing' if
-- the name is not bound there.
getKind :: Name -> TCM (Maybe CKind)
getKind name =
  getFullEnv >>= \(Env kinds _) -> return (Map.lookup name kinds)

-- Look up the type of a value name in the current environment; 'Nothing' if
-- the name is not bound there.
getType :: Name -> TCM (Maybe CType)
getType name =
  getFullEnv >>= \(Env _ types) -> return (Map.lookup name types)

-- Apply a function to the kind map of the environment, leaving the type map
-- as it is.
modifyTEnv :: (Map Name CKind -> Map Name CKind) -> TCM ()
modifyTEnv f =
  getFullEnv >>= \(Env kinds types) -> putFullEnv (Env (f kinds) types)

-- Apply a function to the type map of the environment, leaving the kind map
-- as it is.
modifyVEnv :: (Map Name CType -> Map Name CType) -> TCM ()
modifyVEnv f =
  getFullEnv >>= \(Env kinds types) -> putFullEnv (Env kinds (f types))

-- Run a computation and afterwards restore the kind map to what it was
-- before. Changes the computation made to the type map are kept.
scopeTEnv :: TCM a -> TCM a
scopeTEnv action = do
  Env savedKinds _ <- getFullEnv
  result <- action
  modifyTEnv (const savedKinds)
  return result

-- Run a computation and afterwards restore the type map to what it was
-- before. Changes the computation made to the kind map are kept.
scopeVEnv :: TCM a -> TCM a
scopeVEnv action = do
  Env _ savedTypes <- getFullEnv
  result <- action
  modifyVEnv (const savedTypes)
  return result

-- Create a kind unification variable from a newly generated id.
genKUniVar :: TCM CKind
genKUniVar = fmap (KExt () . KUniVar) genId

-- Create a type unification variable of the given kind from a newly
-- generated id.
genUniVar :: CKind -> TCM CType
genUniVar kind = fmap (TExt kind . TUniVar) genId

-- Like 'getKind', but never fails: if the type name is not in scope, a
-- diagnostic is raised at the given range and a kind unification variable
-- is returned in its place.
getKind' :: Range -> Name -> TCM CKind
getKind' rng name = getKind name >>= maybe notInScope return
  where
    notInScope = do
      raise rng ("Type not in scope: " ++ pretty name)
      genKUniVar

-- Like 'getType', but never fails: if the variable is not in scope, a
-- diagnostic is raised at the given range and a type unification variable
-- of kind Type is returned in its place.
getType' :: Range -> Name -> TCM CType
getType' rng name = getType name >>= maybe notInScope return
  where
    notInScope = do
      raise rng ("Variable not in scope: " ++ pretty name)
      genUniVar (KType ())

-- Typecheck a whole program.
--
-- First all data definitions are registered ('prepareDataDef') and then
-- checked ('tcDataDef'); the kind equality constraints produced by that are
-- captured and handed to 'solveKindVars' rather than being left in the
-- constraint output. After that the function definitions are checked.
tcProgram :: PProgram -> TCM CProgram
tcProgram (Program ddefs fdefs) = do
  let checkDataDefs = mapM_ prepareDataDef ddefs >> mapM tcDataDef ddefs
  (kconstrs, ddefs') <- collectConstraints isCEqK checkDataDefs

  solveKindVars kconstrs

  fdefs' <- mapM tcFunDef fdefs
  return (Program ddefs' fdefs')

-- Register a data type in the kind environment before its definition is
-- checked: each type parameter gets a new kind unification variable, and the
-- data type's name is bound to the function kind from those to Type.
prepareDataDef :: PDataDef -> TCM ()
prepareDataDef (DataDef _ name params _) = do
  paramKinds <- mapM (const genKUniVar) params
  modifyTEnv (Map.insert name (foldr (KFun ()) (KType ()) paramKinds))

-- Assumes that the kind of the name itself has already been registered with
-- the correct arity (this is done by prepareDataDef).
-- Check a data definition: the kind registered for the data type is split
-- into parameter kinds and result kind, the type parameters are bound to
-- those kinds for the duration of the check, and every constructor field
-- type is kind-checked with 'kcType'. The returned definition pairs each
-- parameter name with its kind.
--
-- Calls 'error' if the registered kind does not have one argument per type
-- parameter or does not end in Type.
tcDataDef :: PDataDef -> TCM CDataDef
tcDataDef (DataDef rng name params cons) = do
  declared <- getKind' rng name
  let (paramKinds, resultKind) = splitKind declared
      paramNames = map snd params

  -- sanity checking; would be nicer to store these in prepareDataDef already
  unless (length paramKinds == length params) $
    error "tcDataDef: Invalid param kind list length"
  case resultKind of
    Right () -> return ()
    Left _ -> error "tcDataDef: Invalid ret kind"

  -- the type parameters are only in scope while checking the constructors
  cons' <- scopeTEnv $ do
    modifyTEnv (Map.union (Map.fromList (zip paramNames paramKinds)))
    forM cons $ \(cname, fields) -> (cname,) <$> mapM kcType fields
  return (DataDef () name (zip paramKinds paramNames) cons')

-- Kind-check a parsed type, producing the same type annotated with kinds.
--
-- No kind mismatch is reported here directly; instead kind equality
-- constraints ('CEqK') are emitted. Names that are not in scope are reported
-- by 'getKind''.
kcType :: PType -> TCM CType
kcType ty = case ty of
  -- application: the head must be a function kind from the argument kinds
  -- to a new result kind variable
  TApp rng fun args -> do
    fun' <- kcType fun
    args' <- mapM kcType args
    resKind <- genKUniVar
    emit (CEqK (extOf fun') (foldr (KFun () . extOf) resKind args') rng)
    return (TApp resKind fun' args')

  -- tuple: every component must have kind Type
  TTup _ elems -> do
    elems' <- mapM kcType elems
    forM_ (zip elems elems') (uncurry requireType)
    return (TTup (KType ()) elems')

  -- list: the element must have kind Type
  TList _ el -> do
    el' <- kcType el
    requireType el el'
    return (TList (KType ()) el')

  -- function: argument and result must both have kind Type
  TFun _ arg res -> do
    arg' <- kcType arg
    res' <- kcType res
    requireType arg arg'
    requireType res res'
    return (TFun (KType ()) arg' res')

  TCon rng n -> do
    kind <- getKind' rng n
    return (TCon kind n)

  TVar rng n -> do
    kind <- getKind' rng n
    return (TVar kind n)
  where
    -- Constrain an already-checked type to have kind Type, attributing the
    -- constraint to the annotation of the corresponding parsed type.
    requireType parsed checked =
      emit (CEqK (extOf checked) (KType ()) (extOf parsed))

-- Check a function definition.
--
-- Raises a diagnostic if the equations do not all have the same number of
-- argument patterns. The function's type is its kind-checked signature if it
-- has one, and a new unification variable of kind Type otherwise; every
-- equation is checked against that type, and the result always carries the
-- type as an explicit signature.
tcFunDef :: PFunDef -> TCM CFunDef
tcFunDef (FunDef _ name msig eqs) = do
  let arities = fmap (length . funeqPats) eqs
  unless (allEq arities) $
    raise (sconcatne (fmap extOf eqs)) "Function equations have differing numbers of arguments"

  typ <- case msig of
    TypeSig sig -> kcType sig
    TypeSigExt NoTypeSig -> genUniVar (KType ())

  eqs' <- mapM (tcFunEq typ) eqs

  return (FunDef typ name (TypeSig typ) eqs')

-- Check one function equation against the type of its function.
-- Not implemented yet: evaluating this function calls 'error'.
tcFunEq :: CType -> PFunEq -> TCM CFunEq
tcFunEq = error "tcFunEq"

-- A state monad transformer over @m@ with two pieces of state: a map from
-- variables to bags of terms (accessed through the @solvemState*@
-- functions) and a map from variables to single terms (accessed through the
-- @solvemLog*@ functions).
newtype SolveM v t m a = SolveM (Map v (Bag t) -> Map v t -> m (a, Map v (Bag t), Map v t))

instance Monad m => Functor (SolveM v t m) where
  fmap f (SolveM run) = SolveM $ \st lg ->
    run st lg >>= \(x, st', lg') -> return (f x, st', lg')

instance Monad m => Applicative (SolveM v t m) where
  pure x = SolveM (\st lg -> return (x, st, lg))
  mf <*> mx = ap mf mx

instance Monad m => Monad (SolveM v t m) where
  -- both state components are threaded from the first action into the second
  SolveM run >>= k = SolveM $ \st lg ->
    run st lg >>= \(x, st1, lg1) ->
      case k x of
        SolveM next -> next st1 lg1

-- Read the variable-to-bag state map.
solvemStateGet :: Monad m => SolveM v t m (Map v (Bag t))
solvemStateGet = SolveM (\st lg -> return (st, st, lg))

-- Apply a function to the variable-to-bag state map.
solvemStateUpdate :: Monad m => (Map v (Bag t) -> Map v (Bag t)) -> SolveM v t m ()
solvemStateUpdate f = SolveM (\st lg -> return ((), f st, lg))

-- Apply a function to the variable-to-term log map.
solvemLogUpdate :: Monad m => (Map v t -> Map v t) -> SolveM v t m ()
solvemLogUpdate f = SolveM (\st lg -> return ((), st, f lg))

-- All variables that currently have an entry in the state map.
solvemStateVars :: Monad m => SolveM v t m [v]
solvemStateVars = fmap Map.keys solvemStateGet

-- The bag stored for a variable in the state map, or the empty bag if the
-- variable has no entry.
solvemStateRHS :: (Ord v, Monad m) => v -> SolveM v t m (Bag t)
solvemStateRHS v = do
  st <- solvemStateGet
  return (fromMaybe mempty (Map.lookup v st))

-- Set (or overwrite) the bag stored for a variable in the state map.
solvemStateSet :: (Ord v, Monad m) => v -> Bag t -> SolveM v t m ()
solvemStateSet v = solvemStateUpdate . Map.insert v

-- Record (or overwrite) the term associated with a variable in the log map.
solvemLogEq :: (Ord v, Monad m) => v -> t -> SolveM v t m ()
solvemLogEq v = solvemLogUpdate . Map.insert v

-- Solve the given kind equality constraints.
--
-- NOTE(review): this currently only prints, via Debug.Trace, the constraints
-- and the result of 'solveConstraints'; that result is not used for anything
-- else yet, and the ranges of the constraints are dropped before solving.
-- The exact contract of 'solveConstraints' is defined elsewhere -- the
-- argument descriptions below are read off from the values passed here and
-- should be confirmed against its definition.
solveKindVars :: Bag (CKind, CKind, Range) -> TCM ()
solveKindVars cs = do
  traceShowM cs
  traceShowM $ solveConstraints
                 -- decompose an equation between two kinds
                 reduce
                 -- the unification variables occurring in a kind
                 (foldMap pure . kindUniVars)
                 -- replace one variable by a kind inside another kind
                 (\v repl -> substKind (Map.singleton v repl))
                 -- recognise a kind that is just a unification variable
                 (\case KExt () (KUniVar v) -> Just v
                        _ -> Nothing)
                 -- size measure on kinds
                 kindSize
                 -- the equations themselves, without their source ranges
                 (map (\(a, b, _) -> (a, b)) (toList cs))
  where
    -- Decompose an equation between two kinds into assignments to
    -- unification variables (first component) and irreconcilable pairs
    -- (second component). The equations overlap, so their order matters.
    reduce :: CKind -> CKind -> (Bag (Int, CKind), Bag (CKind, CKind))
    -- unification variables produce constraints on a unification variable
    -- (an equation of a variable with itself produces nothing)
    reduce (KExt () (KUniVar i)) (KExt () (KUniVar j)) | i == j = mempty
    reduce (KExt () (KUniVar i)) k = (pure (i, k), mempty)
    reduce k (KExt () (KUniVar i)) = (pure (i, k), mempty)
    -- if lhs and rhs have equal prefixes, recurse
    reduce (KType ()) (KType ()) = mempty
    reduce (KFun () a b) (KFun () c d) = reduce a c <> reduce b d
    -- otherwise, this is a kind mismatch
    reduce k1 k2 = (mempty, pure (k1, k2))

    -- Number of nodes in a kind; a unification variable counts as one node.
    kindSize :: CKind -> Int
    kindSize KType{} = 1
    kindSize (KFun () a b) = 1 + kindSize a + kindSize b
    kindSize (KExt () KUniVar{}) = 1

-- Solve the constraints collected during typechecking, yielding diagnostics
-- and a map from names to typed-stage types (consumed by 'substProg').
-- Not implemented yet: evaluating this function calls 'error'.
solveConstrs :: Bag Constr -> (Bag Diagnostic, Map Name TType)
solveConstrs = error "solveConstrs"

-- Turn a typechecking-stage program into a typed-stage program using the
-- map produced by 'solveConstrs'.
-- Not implemented yet: evaluating this function calls 'error'.
substProg :: Map Name TType -> CProgram -> TProgram
substProg = error "substProg"

-- Replace kind unification variables by the kinds assigned to them in the
-- given map; variables without an entry are left in place. The substitution
-- is applied once: it is not applied again inside the replacement kinds.
substKind :: Map Int CKind -> CKind -> CKind
substKind subst = go
  where
    go kind = case kind of
      KType{} -> kind
      KFun () arg res -> KFun () (go arg) (go res)
      KExt () (KUniVar v) -> fromMaybe kind (Map.lookup v subst)

-- The ids of all unification variables that occur in a kind.
kindUniVars :: CKind -> Set Int
kindUniVars KType{} = mempty
kindUniVars (KFun () arg res) = kindUniVars arg <> kindUniVars res
kindUniVars (KExt () (KUniVar v)) = Set.singleton v

-- Whether all elements of a container are equal to each other. Empty and
-- single-element containers count as all-equal.
allEq :: (Eq a, Foldable t) => t a -> Bool
allEq = go . toList
  where
    go [] = True
    go (x : rest) = all (== x) rest

-- The argument patterns of a function equation.
funeqPats :: FunEq t -> [Pattern t]
funeqPats eq = case eq of
  FunEq _ _ patterns _ -> patterns

-- Split a kind into the argument kinds of its (right-nested) function
-- arrows and whatever it ends in: 'Right' with the annotation of the final
-- Type kind, or 'Left' with the extension value if it ends in an extension
-- constructor.
splitKind :: Kind s -> ([Kind s], Either (E Kind s) (X Kind s))
splitKind = \case
  KType x -> ([], Right x)
  KFun _ arg res ->
    let (args, final) = splitKind res
    in (arg : args, final)
  KExt _ e -> ([], Left e)

-- Select kind equality constraints: yields the components of a 'CEqK' and
-- 'Nothing' for any other constraint.
isCEqK :: Constr -> Maybe (CKind, CKind, Range)
isCEqK constr = case constr of
  CEqK expected observed rng -> Just (expected, observed, rng)
  _ -> Nothing

-- Effectful 'foldMap': run the action on every element and combine the
-- monoidal results, sequencing the effects through the 'Ap' wrapper.
foldMapM :: (Applicative f, Monoid m, Foldable t) => (a -> f m) -> t a -> f m
foldMapM f xs = getAp (foldMap (\x -> Ap (f x)) xs)