aboutsummaryrefslogtreecommitdiff
path: root/src/HSVIS/Typecheck.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-01-19 22:46:39 +0100
committerTom Smeding <tom@tomsmeding.com>2025-01-19 22:46:39 +0100
commitc7619a27f841d24b5acb4c99ed486e95bd5130d8 (patch)
tree9aae2e1c9665b83090e1c3d80f71c0b9fdffea34 /src/HSVIS/Typecheck.hs
parente13b0a681108697f8b67d8c836edd54c042aad55 (diff)
Noodling on the type checker
Diffstat (limited to 'src/HSVIS/Typecheck.hs')
-rw-r--r--src/HSVIS/Typecheck.hs167
1 files changed, 98 insertions, 69 deletions
diff --git a/src/HSVIS/Typecheck.hs b/src/HSVIS/Typecheck.hs
index 1e46a99..8b642eb 100644
--- a/src/HSVIS/Typecheck.hs
+++ b/src/HSVIS/Typecheck.hs
@@ -1,13 +1,14 @@
+{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE EmptyDataDeriving #-}
{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TupleSections #-}
-{-# LANGUAGE GADTs #-}
+{-# LANGUAGE TypeFamilies #-}
module HSVIS.Typecheck (
StageTyped,
typecheck,
@@ -52,7 +53,7 @@ type instance X Pattern StageTC = (Range, CType)
type instance X RHS StageTC = CType
type instance X Expr StageTC = (Range, CType)
-data instance E Type StageTC = TUniVar Int deriving (Show, Eq, Ord)
+data instance E Type StageTC = TUniVar Int | TForallC (Name, CKind) CType deriving (Show, Eq, Ord)
data instance E Kind StageTC = KUniVar Int deriving (Show, Eq, Ord)
data instance E TypeSig StageTC deriving (Show)
@@ -60,6 +61,7 @@ type CProgram = Program StageTC
type CDataDef = DataDef StageTC
type CDataField = DataField StageTC
type CFunDef = FunDef StageTC
+type CTypeSig = TypeSig StageTC
type CFunEq = FunEq StageTC
type CKind = Kind StageTC
type CType = Type StageTC
@@ -80,7 +82,7 @@ type instance X Pattern StageTyped = TType
type instance X RHS StageTyped = TType
type instance X Expr StageTyped = TType
-data instance E Type StageTyped deriving (Show)
+data instance E Type StageTyped = TForall (Name, TKind) TType deriving (Show)
data instance E Kind StageTyped deriving (Show)
data instance E TypeSig StageTyped deriving (Show)
@@ -97,6 +99,8 @@ type TExpr = Expr StageTyped
instance Pretty (E Type StageTC) where
prettysPrec _ (TUniVar n) = showString ("?t" ++ show n)
+ prettysPrec d (TForallC (n, k) t) = showParen (d > 0) $
+ showString "forall (" . prettys n . showString " :: " . prettys k . showString "). " . prettys t
instance Pretty (E Kind StageTC) where
prettysPrec _ (KUniVar n) = showString ("?k" ++ show n)
@@ -140,17 +144,18 @@ newtype TCM a = TCM {
instance Functor TCM where
fmap f (TCM g) = TCM $ \ctx i env ->
- let (ds, cs, i', env', x) = g ctx i env
- in (ds, cs, i', env', f x)
+ let !(ds, cs, i', env', !x) = g ctx i env
+ !res = f x
+ in (ds, cs, i', env', res)
instance Applicative TCM where
- pure x = TCM $ \_ i env -> (mempty, mempty, i, env, x)
+ pure x = TCM $ \_ i env -> x `seq` (mempty, mempty, i, env, x)
(<*>) = ap
instance Monad TCM where
TCM f >>= g = TCM $ \ctx i1 env1 ->
- let (ds2, cs2, i2, env2, x) = f ctx i1 env1
- (ds3, cs3, i3, env3, y) = runTCM (g x) ctx i2 env2
+ let !(ds2, cs2, i2, env2, !x) = f ctx i1 env1
+ !(ds3, cs3, i3, env3, !y) = runTCM (g x) ctx i2 env2
in (ds2 <> ds3, cs2 <> cs3, i3, env3, y)
class Monad m => MonadRaise m where
@@ -232,14 +237,18 @@ getKind' :: Range -> Name -> TCM CKind
getKind' rng name = getKind name >>= \case
Nothing -> do
raise SError rng $ "Type not in scope: " ++ pretty name
- genKUniVar
+ k <- genKUniVar -- insert it now so that all occurrences of this out-of-scope name get the same kind
+ modifyTEnv (Map.insert name k)
+ return k
Just k -> return k
getType' :: Range -> Name -> TCM CType
getType' rng name = getType name >>= \case
Nothing -> do
raise SError rng $ "Variable not in scope: " ++ pretty name
- genUniVar (KType ())
+ t <- genUniVar (KType ()) -- insert it now so that all occurrences of this out-of-scope name get the same type
+ modifyVEnv (Map.insert name t)
+ return t
Just k -> return k
tcTop :: PProgram -> TCM TProgram
@@ -260,8 +269,9 @@ tcProgram (Program ddefs1 fdefs1) = do
let ddefs3 = map (substDdef kinduvars mempty) ddefs2
modifyTEnv (Map.map (substKind kinduvars))
- -- generate inverse constructors for all data types
- forM_ ddefs3 $ \ddef ->
+ -- generate constructor values and inverse constructors for all data types
+ forM_ ddefs3 $ \ddef -> do
+ modifyVEnv (Map.fromList (generateConstructors ddef) <>)
modifyICEnv (Map.fromList (generateInvCons ddef) <>)
traceM (unlines (map pretty ddefs3))
@@ -309,6 +319,13 @@ generateInvCons (DataDef k tname params cons) =
resty = TApp (KType ()) (TCon k tname) (map (uncurry TVar) params)
in [(cname, InvCon tyvars resty (map dataFieldType fields)) | (cname, fields) <- cons]
+generateConstructors :: CDataDef -> [(Name, CType)]
+generateConstructors (DataDef k tname params cons) =
+ let resty = TApp (KType ()) (TCon k tname) (map (uncurry TVar) params)
+ in [let funty = foldr (TFun (KType ())) resty (map dataFieldType fields)
+ in (cname, foldr (\(k1, n1) -> TExt (KType ()) . TForallC (n1, k1)) funty params)
+ | (cname, fields) <- cons]
+
promoteDownK :: Maybe CKind -> TCM CKind
promoteDownK Nothing = genKUniVar
promoteDownK (Just k) = return k
@@ -338,7 +355,7 @@ kcType' :: forall ext ret. Monoid ext => KCTypeMode ext ret -> Maybe CKind -> PT
kcType' mode mdown = \case
TApp rng t ts -> do
(ext1, t') <- kcType' mode Nothing t
- (ext2, ts') <- sequence <$> mapM (kcType' mode Nothing) ts
+ (ext2, ts') <- sequence <$> mapM (kcType' mode Nothing) ts -- TODO: give more useful down kinds
retk <- promoteDownK mdown
let expected = foldr (KFun ()) retk (map extOf ts')
emit $ CEqK (extOf t') expected rng
@@ -346,22 +363,17 @@ kcType' mode mdown = \case
TTup rng ts -> do
(ext, ts') <- sequence <$> mapM (kcType' mode (Just (KType ()))) ts
- forM_ (zip (map extOf ts) ts') $ \(trng, ct) ->
- emit $ CEqK (extOf ct) (KType ()) trng
downEqK rng mdown (KType ())
return (ext, TTup (KType ()) ts')
TList rng t -> do
(ext, t') <- kcType' mode (Just (KType ())) t
- emit $ CEqK (extOf t') (KType ()) (extOf t)
downEqK rng mdown (KType ())
return (ext, TList (KType ()) t')
TFun rng t1 t2 -> do
(ext1, t1') <- kcType' mode (Just (KType ())) t1
(ext2, t2') <- kcType' mode (Just (KType ())) t2
- emit $ CEqK (extOf t1') (KType ()) (extOf t1)
- emit $ CEqK (extOf t2') (KType ()) (extOf t2)
downEqK rng mdown (KType ())
return (ext1 <> ext2, TFun (KType ()) t1' t2')
@@ -371,61 +383,66 @@ kcType' mode mdown = \case
return (mempty, TCon k n)
TVar rng n -> do
- k <- getKind' rng n
- downEqK rng mdown k
- return (case mode of KCTMNormal -> ()
- KCTMOpen -> MMap.singleton n (pure k)
- ,TVar k n)
-
- TForall rng n t -> do -- implicit forall
- k1 <- genKUniVar
+ mk <- getKind n
+ case mk of
+ Nothing -> do
+ k <- promoteDownK mdown
+ return (case mode of KCTMNormal -> ()
+ KCTMOpen -> MMap.singleton n (pure k)
+ ,TVar k n)
+ Just k -> do
+ downEqK rng mdown k
+ -- TODO: need to instantiate top-level foralls in k here
+ return (mempty, TVar k n)
+
+ TExt rng (TForallP n mk1 t) -> do -- implicit forall
+ k1 <- maybe genKUniVar (return . checkKind) mk1
k2 <- genKUniVar
downEqK rng mdown k2
(ext, t') <- scopeTEnv $ do
modifyTEnv (Map.insert n k1)
kcType' mode (Just k2) t
- return (ext, TForall k2 n t') -- not 'k1 -> k2' because the forall is implicit
+ return (ext, TExt k2 (TForallC (n, k1) t')) -- 'k2', not 'k1 -> k2', because the forall is implicit
tcFunDefBlock :: [PFunDef] -> TCM [CFunDef]
tcFunDefBlock fdefs = do
- -- generate preliminary unification variables for the functions' types
- bound <- mapM (\(FunDef _ n _ _) -> (n,) <$> genUniVar (KType ())) fdefs
- defs' <- forM fdefs $ \def@(FunDef _ name _ _) ->
- scopeVEnv $ do
- modifyVEnv (Map.fromList [(n, t) | (n, t) <- bound, n /= name] <>)
- tcFunDef def
-
- -- take the actual found types for typechecking the body (and link them
- -- to the variables generated above)
- let bound2 = map (\(FunDef _ n (TypeSig _ ty) _) -> (n, ty)) defs'
- forM_ (zip3 fdefs bound bound2) $ \(fdef, (_, tvar), (_, ty)) ->
- emit $ CEq ty tvar (extOf fdef) -- which is expected/observed? which range? /shrug/
+ -- collect types for each of the bound functions, or unification variables if there's no type signature
+ bound <- mapM (\(FunDef _ n ts _) -> (n,) <$> tcTypeSig ts) fdefs
+ defs' <- scopeVEnv $ do
+ modifyVEnv (Map.fromList [(n, t) | (n, TypeSig _ t) <- bound] <>)
+ forM (zip fdefs bound) $ \(def, (_, sig)) ->
+ tcFunDef sig def
+
+ -- -- take the actual found types for typechecking the body (and link them
+ -- -- to the variables generated above)
+ -- let bound2 = map (\(FunDef _ n (TypeSig _ ty) _) -> (n, ty)) defs'
+ -- forM_ (zip3 fdefs bound bound2) $ \(fdef, (_, tvar), (_, ty)) ->
+ -- emit $ CEq ty tvar (extOf fdef) -- which is expected/observed? which range? /shrug/
return defs'
-tcFunDef :: PFunDef -> TCM CFunDef
-tcFunDef (FunDef rng name msig eqs) = do
+-- | Just the identity, because there's nothing to be checked in our simple kinds
+checkKind :: PKind -> CKind
+checkKind (KType ()) = KType ()
+checkKind (KFun () k1 k2) = KFun () (checkKind k1) (checkKind k2)
+
+tcTypeSig :: PTypeSig -> TCM CTypeSig
+tcTypeSig (TypeSig () sig) = do
+ (typ, freetvars) <- kcType KCTMOpen (Just (KType ())) sig
+ return $ TypeSig (Just (extOf sig)) $ foldr (\(n, k) -> TExt (KType ()) . TForallC (n, k)) typ (Map.assocs freetvars)
+tcTypeSig (TypeSigExt () NoTypeSig) = TypeSig Nothing <$> genUniVar (KType ())
+
+-- | Typechecks the function definition, but assumes its signature has already
+-- been checked, and is passed separately. Thus the PTypeSig in the PFunDef is
+-- ignored.
+tcFunDef :: CTypeSig -> PFunDef -> TCM CFunDef
+tcFunDef typesig@(TypeSig _ funtyp) (FunDef rng name _ eqs) = do
when (not $ allEq (fmap (length . funeqPats) eqs)) $
raise SError rng "Function equations have differing numbers of arguments"
- (typ, msigrng) <- case msig of
- TypeSig _ sig -> do
- (typ, freetvars) <- kcType KCTMOpen (Just (KType ())) sig
- TODO -- We need to check that these free type variables do not escape.
- -- Perhaps with levels on unification variables? Associate a level
- -- to a generated uvar, and increment the global level counter when
- -- passing below a forall.
- -- But how do we deal with functions without a type signature
- -- anyway? We should be able to infer a polymorphic type for them.
- return (foldr (\(n, k) -> TForall k n) typ (Map.assocs freetvars)
- ,Just (extOf sig))
- TypeSigExt _ NoTypeSig -> (,Nothing) <$> genUniVar (KType ())
-
- eqs' <- scopeVEnv $ do
- modifyVEnv (Map.insert name typ) -- allow function to be recursive
- mapM (tcFunEq typ) eqs
-
- return (FunDef rng name (TypeSig msigrng typ) eqs')
+ eqs' <- scopeVEnv $ mapM (tcFunEq funtyp) eqs
+
+ return (FunDef rng name typesig eqs')
tcFunEq :: CType -> PFunEq -> TCM CFunEq
tcFunEq down (FunEq rng name pats rhs) = do
@@ -503,13 +520,15 @@ tcExpr down = \case
EVar rng n -> do
ty <- getType' rng n
- emit $ CEq down ty rng
- return $ EVar (rng, ty) n
+ ty' <- instantiateTForallsUni ty
+ emit $ CEq down ty' rng
+ return $ EVar (rng, ty') n
ECon rng n -> do
ty <- getType' rng n
- emit $ CEq down ty rng
- return $ EVar (rng, ty) n
+ ty' <- instantiateTForallsUni ty
+ emit $ CEq down ty' rng
+ return $ EVar (rng, ty') n
EList rng es -> do
eltty <- genUniVar (KType ())
@@ -574,6 +593,14 @@ tcExpr down = \case
EError rng -> return $ EError (rng, down)
+instantiateTForallsUni :: CType -> TCM CType
+instantiateTForallsUni = go mempty
+ where
+ go sub (TExt _ (TForallC (n, k1) t)) = do
+ var <- genUniVar k1
+ go (Map.insert n var sub) t
+ go sub t = return $ substType mempty mempty sub t
+
unfoldFunTy :: Range -> Int -> CType -> TCM ([CType], CType)
unfoldFunTy _ n t | n <= 0 = return ([], t)
unfoldFunTy rng n (TFun _ t1 t2) = do
@@ -680,7 +707,8 @@ solveTypeVars cs = do
(TFun _ t1 t2, TFun _ t1' t2') -> reduce t1 t1' rng <> reduce t2 t2' rng
(TCon _ n1, TCon _ n2) | n1 == n2 -> mempty
(TVar _ n1, TVar _ n2) | n1 == n2 -> mempty
- (TForall _ n1 t1, TForall k n2 t2) ->
+ -- TODO: this doesn't check that the types are kind-correct. Did we already check that?
+ (TExt _ (TForallC (n1, k) t1), TExt _ (TForallC (n2, _) t2)) ->
reduce t1 (substType mempty mempty (Map.singleton n2 (TVar k n1)) t2) rng
-- otherwise, this is a kind mismatch
@@ -693,8 +721,8 @@ solveTypeVars cs = do
typeSize (TFun _ t1 t2) = typeSize t1 + typeSize t2
typeSize (TCon _ _) = 1
typeSize (TVar _ _) = 1
- typeSize (TForall _ _ t) = 1 + typeSize t
typeSize (TExt _ TUniVar{}) = 2
+ typeSize (TExt _ (TForallC _ t)) = 1 + typeSize t
partitionConstrs :: Foldable t => t Constr -> (Bag (CType, CType, Range), Bag (CKind, CKind, Range))
partitionConstrs = foldMap $ \case CEq t1 t2 r -> (pure (t1, t2, r), mempty)
@@ -773,8 +801,8 @@ substType mk mt mtv = go
go (TFun k t1 t2) = TFun (goKind k) (go t1) (go t2)
go (TCon k n) = TCon (goKind k) n
go (TVar k n) = fromMaybe (TVar (goKind k) n) (Map.lookup n mtv)
- go (TForall k n t) = TForall (goKind k) n (go t)
go (TExt k (TUniVar v)) = fromMaybe (TExt (goKind k) (TUniVar v)) (Map.lookup v mt)
+ go (TExt k (TForallC (n, k1) t)) = TExt (goKind k) (TForallC (n, goKind k1) (go t))
goKind = substKind mk
@@ -847,16 +875,17 @@ finaliseExpr = \case
finaliseType :: MonadRaise m => Range -> CType -> m TType
finaliseType rng toptype = go toptype
where
+ go :: MonadRaise m => Type StageTC -> m TType
go (TApp k t ts) = TApp <$> finaliseKind k <*> go t <*> traverse go ts
go (TTup k ts) = TTup <$> finaliseKind k <*> traverse go ts
go (TList k t) = TList <$> finaliseKind k <*> go t
go (TFun k t1 t2) = TFun <$> finaliseKind k <*> go t1 <*> go t2
go (TCon k n) = TCon <$> finaliseKind k <*> pure n
go (TVar k n) = TVar <$> finaliseKind k <*> pure n
- go (TForall k n t) = TForall <$> finaliseKind k <*> pure n <*> go t
go t@(TExt k TUniVar{}) = do
raise SError rng $ "Ambiguous type unification variable " ++ pretty t ++ " in type: " ++ pretty toptype
TVar <$> finaliseKind k <*> pure (Name "$_error")
+ go (TExt k (TForallC (n, k1) t)) = TExt <$> finaliseKind k <*> (TForall <$> ((,) <$> pure n <*> finaliseKind k1) <*> go t)
finaliseKind :: MonadRaise m => CKind -> m TKind
finaliseKind = \case
@@ -874,8 +903,8 @@ typeUniVars = \case
TFun _ t1 t2 -> typeUniVars t1 <> typeUniVars t2
TCon _ _ -> mempty
TVar _ _ -> mempty
- TForall _ _ t -> typeUniVars t
TExt _ (TUniVar v) -> Set.singleton v
+ TExt _ (TForallC _ t) -> typeUniVars t
kindUniVars :: CKind -> Set Int
kindUniVars = \case