diff options
Diffstat (limited to 'src/HSVIS/Typecheck.hs')
-rw-r--r-- | src/HSVIS/Typecheck.hs | 167 |
1 files changed, 98 insertions, 69 deletions
diff --git a/src/HSVIS/Typecheck.hs b/src/HSVIS/Typecheck.hs index 1e46a99..8b642eb 100644 --- a/src/HSVIS/Typecheck.hs +++ b/src/HSVIS/Typecheck.hs @@ -1,13 +1,14 @@ +{-# LANGUAGE BangPatterns #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE EmptyDataDeriving #-} {-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TupleSections #-} -{-# LANGUAGE GADTs #-} +{-# LANGUAGE TypeFamilies #-} module HSVIS.Typecheck ( StageTyped, typecheck, @@ -52,7 +53,7 @@ type instance X Pattern StageTC = (Range, CType) type instance X RHS StageTC = CType type instance X Expr StageTC = (Range, CType) -data instance E Type StageTC = TUniVar Int deriving (Show, Eq, Ord) +data instance E Type StageTC = TUniVar Int | TForallC (Name, CKind) CType deriving (Show, Eq, Ord) data instance E Kind StageTC = KUniVar Int deriving (Show, Eq, Ord) data instance E TypeSig StageTC deriving (Show) @@ -60,6 +61,7 @@ type CProgram = Program StageTC type CDataDef = DataDef StageTC type CDataField = DataField StageTC type CFunDef = FunDef StageTC +type CTypeSig = TypeSig StageTC type CFunEq = FunEq StageTC type CKind = Kind StageTC type CType = Type StageTC @@ -80,7 +82,7 @@ type instance X Pattern StageTyped = TType type instance X RHS StageTyped = TType type instance X Expr StageTyped = TType -data instance E Type StageTyped deriving (Show) +data instance E Type StageTyped = TForall (Name, TKind) TType deriving (Show) data instance E Kind StageTyped deriving (Show) data instance E TypeSig StageTyped deriving (Show) @@ -97,6 +99,8 @@ type TExpr = Expr StageTyped instance Pretty (E Type StageTC) where prettysPrec _ (TUniVar n) = showString ("?t" ++ show n) + prettysPrec d (TForallC (n, k) t) = showParen (d > 0) $ + showString "forall (" . prettys n . showString " :: " . prettys k . showString "). " . prettys t instance Pretty (E Kind StageTC) where prettysPrec _ (KUniVar n) = showString ("?k" ++ show n) @@ -140,17 +144,18 @@ newtype TCM a = TCM { instance Functor TCM where fmap f (TCM g) = TCM $ \ctx i env -> - let (ds, cs, i', env', x) = g ctx i env - in (ds, cs, i', env', f x) + let !(ds, cs, i', env', !x) = g ctx i env + !res = f x + in (ds, cs, i', env', res) instance Applicative TCM where - pure x = TCM $ \_ i env -> (mempty, mempty, i, env, x) + pure x = TCM $ \_ i env -> x `seq` (mempty, mempty, i, env, x) (<*>) = ap instance Monad TCM where TCM f >>= g = TCM $ \ctx i1 env1 -> - let (ds2, cs2, i2, env2, x) = f ctx i1 env1 - (ds3, cs3, i3, env3, y) = runTCM (g x) ctx i2 env2 + let !(ds2, cs2, i2, env2, !x) = f ctx i1 env1 + !(ds3, cs3, i3, env3, !y) = runTCM (g x) ctx i2 env2 in (ds2 <> ds3, cs2 <> cs3, i3, env3, y) class Monad m => MonadRaise m where @@ -232,14 +237,18 @@ getKind' :: Range -> Name -> TCM CKind getKind' rng name = getKind name >>= \case Nothing -> do raise SError rng $ "Type not in scope: " ++ pretty name - genKUniVar + k <- genKUniVar -- insert it now so that all occurrences of this out-of-scope name get the same kind + modifyTEnv (Map.insert name k) + return k Just k -> return k getType' :: Range -> Name -> TCM CType getType' rng name = getType name >>= \case Nothing -> do raise SError rng $ "Variable not in scope: " ++ pretty name - genUniVar (KType ()) + t <- genUniVar (KType ()) -- insert it now so that all occurrences of this out-of-scope name get the same type + modifyVEnv (Map.insert name t) + return t Just k -> return k tcTop :: PProgram -> TCM TProgram @@ -260,8 +269,9 @@ tcProgram (Program ddefs1 fdefs1) = do let ddefs3 = map (substDdef kinduvars mempty) ddefs2 modifyTEnv (Map.map (substKind kinduvars)) - -- generate inverse constructors for all data types - forM_ ddefs3 $ \ddef -> + -- generate constructor values and inverse constructors for all data types + forM_ ddefs3 $ \ddef -> do + modifyVEnv (Map.fromList (generateConstructors ddef) <>) modifyICEnv (Map.fromList (generateInvCons ddef) <>) traceM (unlines (map pretty ddefs3)) @@ -309,6 +319,13 @@ generateInvCons (DataDef k tname params cons) = resty = TApp (KType ()) (TCon k tname) (map (uncurry TVar) params) in [(cname, InvCon tyvars resty (map dataFieldType fields)) | (cname, fields) <- cons] +generateConstructors :: CDataDef -> [(Name, CType)] +generateConstructors (DataDef k tname params cons) = + let resty = TApp (KType ()) (TCon k tname) (map (uncurry TVar) params) + in [let funty = foldr (TFun (KType ())) resty (map dataFieldType fields) + in (cname, foldr (\(k1, n1) -> TExt (KType ()) . TForallC (n1, k1)) funty params) + | (cname, fields) <- cons] + promoteDownK :: Maybe CKind -> TCM CKind promoteDownK Nothing = genKUniVar promoteDownK (Just k) = return k @@ -338,7 +355,7 @@ kcType' :: forall ext ret. Monoid ext => KCTypeMode ext ret -> Maybe CKind -> PT kcType' mode mdown = \case TApp rng t ts -> do (ext1, t') <- kcType' mode Nothing t - (ext2, ts') <- sequence <$> mapM (kcType' mode Nothing) ts + (ext2, ts') <- sequence <$> mapM (kcType' mode Nothing) ts -- TODO: give more useful down kinds retk <- promoteDownK mdown let expected = foldr (KFun ()) retk (map extOf ts') emit $ CEqK (extOf t') expected rng @@ -346,22 +363,17 @@ kcType' mode mdown = \case TTup rng ts -> do (ext, ts') <- sequence <$> mapM (kcType' mode (Just (KType ()))) ts - forM_ (zip (map extOf ts) ts') $ \(trng, ct) -> - emit $ CEqK (extOf ct) (KType ()) trng downEqK rng mdown (KType ()) return (ext, TTup (KType ()) ts') TList rng t -> do (ext, t') <- kcType' mode (Just (KType ())) t - emit $ CEqK (extOf t') (KType ()) (extOf t) downEqK rng mdown (KType ()) return (ext, TList (KType ()) t') TFun rng t1 t2 -> do (ext1, t1') <- kcType' mode (Just (KType ())) t1 (ext2, t2') <- kcType' mode (Just (KType ())) t2 - emit $ CEqK (extOf t1') (KType ()) (extOf t1) - emit $ CEqK (extOf t2') (KType ()) (extOf t2) downEqK rng mdown (KType ()) return (ext1 <> ext2, TFun (KType ()) t1' t2') @@ -371,61 +383,66 @@ kcType' mode mdown = \case return (mempty, TCon k n) TVar rng n -> do - k <- getKind' rng n - downEqK rng mdown k - return (case mode of KCTMNormal -> () - KCTMOpen -> MMap.singleton n (pure k) - ,TVar k n) - - TForall rng n t -> do -- implicit forall - k1 <- genKUniVar + mk <- getKind n + case mk of + Nothing -> do + k <- promoteDownK mdown + return (case mode of KCTMNormal -> () + KCTMOpen -> MMap.singleton n (pure k) + ,TVar k n) + Just k -> do + downEqK rng mdown k + -- TODO: need to instantiate top-level foralls in k here + return (mempty, TVar k n) + + TExt rng (TForallP n mk1 t) -> do -- implicit forall + k1 <- maybe genKUniVar (return . checkKind) mk1 k2 <- genKUniVar downEqK rng mdown k2 (ext, t') <- scopeTEnv $ do modifyTEnv (Map.insert n k1) kcType' mode (Just k2) t - return (ext, TForall k2 n t') -- not 'k1 -> k2' because the forall is implicit + return (ext, TExt k2 (TForallC (n, k1) t')) -- 'k2', not 'k1 -> k2', because the forall is implicit tcFunDefBlock :: [PFunDef] -> TCM [CFunDef] tcFunDefBlock fdefs = do - -- generate preliminary unification variables for the functions' types - bound <- mapM (\(FunDef _ n _ _) -> (n,) <$> genUniVar (KType ())) fdefs - defs' <- forM fdefs $ \def@(FunDef _ name _ _) -> - scopeVEnv $ do - modifyVEnv (Map.fromList [(n, t) | (n, t) <- bound, n /= name] <>) - tcFunDef def - - -- take the actual found types for typechecking the body (and link them - -- to the variables generated above) - let bound2 = map (\(FunDef _ n (TypeSig _ ty) _) -> (n, ty)) defs' - forM_ (zip3 fdefs bound bound2) $ \(fdef, (_, tvar), (_, ty)) -> - emit $ CEq ty tvar (extOf fdef) -- which is expected/observed? which range? /shrug/ + -- collect types for each of the bound functions, or unification variables if there's no type signature + bound <- mapM (\(FunDef _ n ts _) -> (n,) <$> tcTypeSig ts) fdefs + defs' <- scopeVEnv $ do + modifyVEnv (Map.fromList [(n, t) | (n, TypeSig _ t) <- bound] <>) + forM (zip fdefs bound) $ \(def, (_, sig)) -> + tcFunDef sig def + + -- -- take the actual found types for typechecking the body (and link them + -- -- to the variables generated above) + -- let bound2 = map (\(FunDef _ n (TypeSig _ ty) _) -> (n, ty)) defs' + -- forM_ (zip3 fdefs bound bound2) $ \(fdef, (_, tvar), (_, ty)) -> + -- emit $ CEq ty tvar (extOf fdef) -- which is expected/observed? which range? /shrug/ return defs' -tcFunDef :: PFunDef -> TCM CFunDef -tcFunDef (FunDef rng name msig eqs) = do +-- | Just the identity, because there's nothing to be checked in our simple kinds +checkKind :: PKind -> CKind +checkKind (KType ()) = KType () +checkKind (KFun () k1 k2) = KFun () (checkKind k1) (checkKind k2) + +tcTypeSig :: PTypeSig -> TCM CTypeSig +tcTypeSig (TypeSig () sig) = do + (typ, freetvars) <- kcType KCTMOpen (Just (KType ())) sig + return $ TypeSig (Just (extOf sig)) $ foldr (\(n, k) -> TExt (KType ()) . TForallC (n, k)) typ (Map.assocs freetvars) +tcTypeSig (TypeSigExt () NoTypeSig) = TypeSig Nothing <$> genUniVar (KType ()) + +-- | Typechecks the function definition, but assumes its signature has already +-- been checked, and is passed separately. Thus the PTypeSig in the PFunDef is +-- ignored. +tcFunDef :: CTypeSig -> PFunDef -> TCM CFunDef +tcFunDef typesig@(TypeSig _ funtyp) (FunDef rng name _ eqs) = do when (not $ allEq (fmap (length . funeqPats) eqs)) $ raise SError rng "Function equations have differing numbers of arguments" - (typ, msigrng) <- case msig of - TypeSig _ sig -> do - (typ, freetvars) <- kcType KCTMOpen (Just (KType ())) sig - TODO -- We need to check that these free type variables do not escape. - -- Perhaps with levels on unification variables? Associate a level - -- to a generated uvar, and increment the global level counter when - -- passing below a forall. - -- But how do we deal with functions without a type signature - -- anyway? We should be able to infer a polymorphic type for them. - return (foldr (\(n, k) -> TForall k n) typ (Map.assocs freetvars) - ,Just (extOf sig)) - TypeSigExt _ NoTypeSig -> (,Nothing) <$> genUniVar (KType ()) - - eqs' <- scopeVEnv $ do - modifyVEnv (Map.insert name typ) -- allow function to be recursive - mapM (tcFunEq typ) eqs - - return (FunDef rng name (TypeSig msigrng typ) eqs') + eqs' <- scopeVEnv $ mapM (tcFunEq funtyp) eqs + + return (FunDef rng name typesig eqs') tcFunEq :: CType -> PFunEq -> TCM CFunEq tcFunEq down (FunEq rng name pats rhs) = do @@ -503,13 +520,15 @@ tcExpr down = \case EVar rng n -> do ty <- getType' rng n - emit $ CEq down ty rng - return $ EVar (rng, ty) n + ty' <- instantiateTForallsUni ty + emit $ CEq down ty' rng + return $ EVar (rng, ty') n ECon rng n -> do ty <- getType' rng n - emit $ CEq down ty rng - return $ EVar (rng, ty) n + ty' <- instantiateTForallsUni ty + emit $ CEq down ty' rng + return $ EVar (rng, ty') n EList rng es -> do eltty <- genUniVar (KType ()) @@ -574,6 +593,14 @@ tcExpr down = \case EError rng -> return $ EError (rng, down) +instantiateTForallsUni :: CType -> TCM CType +instantiateTForallsUni = go mempty + where + go sub (TExt _ (TForallC (n, k1) t)) = do + var <- genUniVar k1 + go (Map.insert n var sub) t + go sub t = return $ substType mempty mempty sub t + unfoldFunTy :: Range -> Int -> CType -> TCM ([CType], CType) unfoldFunTy _ n t | n <= 0 = return ([], t) unfoldFunTy rng n (TFun _ t1 t2) = do @@ -680,7 +707,8 @@ solveTypeVars cs = do (TFun _ t1 t2, TFun _ t1' t2') -> reduce t1 t1' rng <> reduce t2 t2' rng (TCon _ n1, TCon _ n2) | n1 == n2 -> mempty (TVar _ n1, TVar _ n2) | n1 == n2 -> mempty - (TForall _ n1 t1, TForall k n2 t2) -> + -- TODO: this doesn't check that the types are kind-correct. Did we already check that? + (TExt _ (TForallC (n1, k) t1), TExt _ (TForallC (n2, _) t2)) -> reduce t1 (substType mempty mempty (Map.singleton n2 (TVar k n1)) t2) rng -- otherwise, this is a kind mismatch @@ -693,8 +721,8 @@ solveTypeVars cs = do typeSize (TFun _ t1 t2) = typeSize t1 + typeSize t2 typeSize (TCon _ _) = 1 typeSize (TVar _ _) = 1 - typeSize (TForall _ _ t) = 1 + typeSize t typeSize (TExt _ TUniVar{}) = 2 + typeSize (TExt _ (TForallC _ t)) = 1 + typeSize t partitionConstrs :: Foldable t => t Constr -> (Bag (CType, CType, Range), Bag (CKind, CKind, Range)) partitionConstrs = foldMap $ \case CEq t1 t2 r -> (pure (t1, t2, r), mempty) @@ -773,8 +801,8 @@ substType mk mt mtv = go go (TFun k t1 t2) = TFun (goKind k) (go t1) (go t2) go (TCon k n) = TCon (goKind k) n go (TVar k n) = fromMaybe (TVar (goKind k) n) (Map.lookup n mtv) - go (TForall k n t) = TForall (goKind k) n (go t) go (TExt k (TUniVar v)) = fromMaybe (TExt (goKind k) (TUniVar v)) (Map.lookup v mt) + go (TExt k (TForallC (n, k1) t)) = TExt (goKind k) (TForallC (n, goKind k1) (go t)) goKind = substKind mk @@ -847,16 +875,17 @@ finaliseExpr = \case finaliseType :: MonadRaise m => Range -> CType -> m TType finaliseType rng toptype = go toptype where + go :: MonadRaise m => Type StageTC -> m TType go (TApp k t ts) = TApp <$> finaliseKind k <*> go t <*> traverse go ts go (TTup k ts) = TTup <$> finaliseKind k <*> traverse go ts go (TList k t) = TList <$> finaliseKind k <*> go t go (TFun k t1 t2) = TFun <$> finaliseKind k <*> go t1 <*> go t2 go (TCon k n) = TCon <$> finaliseKind k <*> pure n go (TVar k n) = TVar <$> finaliseKind k <*> pure n - go (TForall k n t) = TForall <$> finaliseKind k <*> pure n <*> go t go t@(TExt k TUniVar{}) = do raise SError rng $ "Ambiguous type unification variable " ++ pretty t ++ " in type: " ++ pretty toptype TVar <$> finaliseKind k <*> pure (Name "$_error") + go (TExt k (TForallC (n, k1) t)) = TExt <$> finaliseKind k <*> (TForall <$> ((,) <$> pure n <*> finaliseKind k1) <*> go t) finaliseKind :: MonadRaise m => CKind -> m TKind finaliseKind = \case @@ -874,8 +903,8 @@ typeUniVars = \case TFun _ t1 t2 -> typeUniVars t1 <> typeUniVars t2 TCon _ _ -> mempty TVar _ _ -> mempty - TForall _ _ t -> typeUniVars t TExt _ (TUniVar v) -> Set.singleton v + TExt _ (TForallC _ t) -> typeUniVars t kindUniVars :: CKind -> Set Int kindUniVars = \case |