From 909b7a4eacaba7323ac44f7950e60e8956e4081c Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Fri, 22 Mar 2024 21:56:35 +0100 Subject: Working kind inference --- src/HSVIS/Diagnostic.hs | 22 ++++-- src/HSVIS/Parser.hs | 2 +- src/HSVIS/Typecheck.hs | 158 ++++++++++++++++++++++++------------------- src/HSVIS/Typecheck/Solve.hs | 15 ++-- 4 files changed, 114 insertions(+), 83 deletions(-) diff --git a/src/HSVIS/Diagnostic.hs b/src/HSVIS/Diagnostic.hs index 675482d..778fe34 100644 --- a/src/HSVIS/Diagnostic.hs +++ b/src/HSVIS/Diagnostic.hs @@ -27,9 +27,13 @@ instance Pretty Range where | y2 <= y1 + 1 = showString (show (y1 + 1) ++ ":" ++ show (x1 + 1) ++ "-" ++ show x2) | otherwise = showString ("(" ++ show (y1 + 1) ++ ":" ++ show (x1 + 1) ++ ")-(" ++ show (y2 + 1) ++ ":" ++ show x2 ++ ")") + +data Severity = SError | SWarning + deriving (Show) data Diagnostic = Diagnostic - { dFile :: FilePath -- ^ The file for which the diagnostic was rai sed + { dSeverity :: Severity -- ^ Error level + , dFile :: FilePath -- ^ The file for which the diagnostic was raised , dRange :: Range -- ^ Where in the file , dStk :: [String] -- ^ Stack of contexts (innermost at head) of the diagnostic , dSourceLine :: String -- ^ The line in the source file of the start of the range @@ -38,17 +42,23 @@ data Diagnostic = Diagnostic deriving (Show) printDiagnostic :: Diagnostic -> String -printDiagnostic (Diagnostic fp rng@(Range (Pos y1 x1) (Pos y2 x2)) stk srcline msg) = +printDiagnostic (Diagnostic sev fp rng@(Range (Pos y1 x1) (Pos y2 x2)) stk srcline msg) = let linenum = show (y1 + 1) locstr = pretty rng ncarets | y1 == y2 = max 1 (x2 - x1 + 1) | otherwise = length srcline - x1 caretsuffix | y1 == y2 = "" | otherwise = "..." - in intercalate "\n" $ - map (\descr -> "In " ++ descr ++ ":") (reverse stk) - ++ [fp ++ ":" ++ locstr ++ ": " ++ msg - ,map (\_ -> ' ') linenum ++ " |" + + mainLine = + (case sev of SError -> "Error: " + SWarning -> "Warning: ") + ++ fp ++ ":" ++ locstr ++ ": " ++ msg + revCtxTrace = reverse $ map (\(i, descr) -> "in " ++ descr ++ (if i == 0 then "" else ",")) + (zip [0::Int ..] (reverse stk)) + srcPointer = + [map (\_ -> ' ') linenum ++ " |" ,linenum ++ " | " ++ srcline ,map (\_ -> ' ') linenum ++ " | " ++ replicate x1 ' ' ++ replicate ncarets '^' ++ caretsuffix] + in intercalate "\n" $ [mainLine] ++ srcPointer ++ revCtxTrace diff --git a/src/HSVIS/Parser.hs b/src/HSVIS/Parser.hs index b4d8754..e89c679 100644 --- a/src/HSVIS/Parser.hs +++ b/src/HSVIS/Parser.hs @@ -896,7 +896,7 @@ raise fat msg = gets psCur >>= \pos -> raiseAt pos fat msg raiseAt :: (KnownFallible fail, FatalCtx fatal a) => Pos -> Fatality fatal -> String -> Parser fail a raiseAt pos fat msg = do Context { ctxFile = fp , ctxStack = stk, ctxLines = srcLines } <- ask - let err = Diagnostic fp (Range pos pos) stk (srcLines !! posLine pos) msg + let err = Diagnostic SError fp (Range pos pos) stk (srcLines !! posLine pos) msg case fat of Error -> dictate (pure err) -- Fatal -> confess (pure err) diff --git a/src/HSVIS/Typecheck.hs b/src/HSVIS/Typecheck.hs index c97064a..0347e81 100644 --- a/src/HSVIS/Typecheck.hs +++ b/src/HSVIS/Typecheck.hs @@ -5,11 +5,18 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TupleSections #-} -module HSVIS.Typecheck where +{-# LANGUAGE GADTs #-} +module HSVIS.Typecheck ( + StageTyped, + typecheck, + -- * Typed AST synonyms + -- TProgram, TDataDef, TFunDef, TFunEq, TKind, TType, TPattern, TRHS, TExpr, +) where import Control.Monad -import Data.Bifunctor (first) +import Data.Bifunctor (first, second) import Data.Foldable (toList) +import Data.List (find) import Data.Map.Strict (Map) import Data.Maybe (fromMaybe) import Data.Monoid (Ap(..)) @@ -89,8 +96,8 @@ typecheck :: FilePath -> String -> PProgram -> ([Diagnostic], TProgram) typecheck fp source prog = let (ds1, cs, _, _, progtc) = runTCM (tcProgram prog) (fp, source) 1 (Env mempty mempty) - (ds2, sub) = solveConstrs cs - in (toList (ds1 <> ds2), substProg sub progtc) + (ds2, subK, subT) = solveConstrs cs + in (toList (ds1 <> ds2), doneProg subK subT progtc) data Constr -- Equality constraints: "left" must be equal to "right" because of the thing @@ -127,9 +134,9 @@ instance Monad TCM where (ds3, cs3, i3, env3, y) = runTCM (g x) ctx i2 env2 in (ds2 <> ds3, cs2 <> cs3, i3, env3, y) -raise :: Range -> String -> TCM () -raise rng@(Range (Pos y _) _) msg = TCM $ \(fp, source) i env -> - (pure (Diagnostic fp rng [] (lines source !! y) msg), mempty, i, env, ()) +raise :: Severity -> Range -> String -> TCM () +raise sev rng@(Range (Pos y _) _) msg = TCM $ \(fp, source) i env -> + (pure (Diagnostic sev fp rng [] (lines source !! y) msg), mempty, i, env, ()) emit :: Constr -> TCM () emit c = TCM $ \_ i env -> (mempty, pure c, i, env, ()) @@ -192,30 +199,31 @@ genUniVar k = TExt k . TUniVar <$> genId getKind' :: Range -> Name -> TCM CKind getKind' rng name = getKind name >>= \case Nothing -> do - raise rng $ "Type not in scope: " ++ pretty name + raise SError rng $ "Type not in scope: " ++ pretty name genKUniVar Just k -> return k getType' :: Range -> Name -> TCM CType getType' rng name = getType name >>= \case Nothing -> do - raise rng $ "Variable not in scope: " ++ pretty name + raise SError rng $ "Variable not in scope: " ++ pretty name genUniVar (KType ()) Just k -> return k tcProgram :: PProgram -> TCM CProgram -tcProgram (Program ddefs fdefs) = do - (kconstrs, ddefs') <- collectConstraints isCEqK $ do - mapM_ prepareDataDef ddefs - mapM tcDataDef ddefs +tcProgram (Program ddefs1 fdefs1) = do + (kconstrs, ddefs2) <- collectConstraints isCEqK $ do + mapM_ prepareDataDef ddefs1 + mapM tcDataDef ddefs1 - solveKindVars kconstrs + kinduvars <- solveKindVars kconstrs + let ddefs3 = map (substDdef kinduvars mempty) ddefs2 - traceM (unlines (map pretty ddefs')) + traceM (unlines (map pretty ddefs3)) - fdefs' <- mapM tcFunDef fdefs + fdefs2 <- mapM tcFunDef fdefs1 - return (Program ddefs' fdefs') + return (Program ddefs3 fdefs2) prepareDataDef :: PDataDef -> TCM () prepareDataDef (DataDef _ name params _) = do @@ -224,7 +232,7 @@ prepareDataDef (DataDef _ name params _) = do modifyTEnv (Map.insert name k) -- Assumes that the kind of the name itself has already been registered with --- the correct arity (this is doen by prepareDataDef). +-- the correct arity (this is done by prepareDataDef). tcDataDef :: PDataDef -> TCM CDataDef tcDataDef (DataDef rng name params cons) = do kd <- getKind' rng name @@ -292,9 +300,9 @@ kcType mdown = \case return (TVar k n) tcFunDef :: PFunDef -> TCM CFunDef -tcFunDef (FunDef _ name msig eqs) = do +tcFunDef (FunDef rng name msig eqs) = do when (not $ allEq (fmap (length . funeqPats) eqs)) $ - raise (sconcatne (fmap extOf eqs)) "Function equations have differing numbers of arguments" + raise SError rng "Function equations have differing numbers of arguments" typ <- case msig of TypeSig sig -> kcType (Just (KType ())) sig @@ -305,52 +313,36 @@ tcFunDef (FunDef _ name msig eqs) = do return (FunDef typ name (TypeSig typ) eqs') tcFunEq :: CType -> PFunEq -> TCM CFunEq -tcFunEq = error "tcFunEq" - -newtype SolveM v t m a = SolveM (Map v (Bag t) -> Map v t -> m (a, Map v (Bag t), Map v t)) -instance Monad m => Functor (SolveM v t m) where - fmap f (SolveM g) = SolveM $ \m r -> do (x, m', r') <- g m r - return (f x, m', r') -instance Monad m => Applicative (SolveM v t m) where - pure x = SolveM $ \m r -> return (x, m, r) - (<*>) = ap -instance Monad m => Monad (SolveM v t m) where - SolveM f >>= g = SolveM $ \m r -> do (x, m1, r1) <- f m r - let SolveM h = g x - h m1 r1 - -solvemStateGet :: Monad m => SolveM v t m (Map v (Bag t)) -solvemStateGet = SolveM $ \m r -> return (m, m, r) - -solvemStateUpdate :: Monad m => (Map v (Bag t) -> Map v (Bag t)) -> SolveM v t m () -solvemStateUpdate f = SolveM $ \m r -> return ((), f m, r) - -solvemLogUpdate :: Monad m => (Map v t -> Map v t) -> SolveM v t m () -solvemLogUpdate f = SolveM $ \m r -> return ((), m, f r) - -solvemStateVars :: Monad m => SolveM v t m [v] -solvemStateVars = Map.keys <$> solvemStateGet +tcFunEq down (FunEq rng name pats rhs) = error "tcFunEq" -solvemStateRHS :: (Ord v, Monad m) => v -> SolveM v t m (Bag t) -solvemStateRHS v = fromMaybe mempty . Map.lookup v <$> solvemStateGet - -solvemStateSet :: (Ord v, Monad m) => v -> Bag t -> SolveM v t m () -solvemStateSet v b = solvemStateUpdate (Map.insert v b) - -solvemLogEq :: (Ord v, Monad m) => v -> t -> SolveM v t m () -solvemLogEq v t = solvemLogUpdate (Map.insert v t) - -solveKindVars :: Bag (CKind, CKind, Range) -> TCM () +solveKindVars :: Bag (CKind, CKind, Range) -> TCM (Map Int CKind) solveKindVars cs = do - traceShowM cs - traceShowM $ solveConstraints - reduce - (foldMap pure . kindUniVars) - (\v repl -> substKind (Map.singleton v repl)) - (\case KExt () (KUniVar v) -> Just v - _ -> Nothing) - kindSize - (toList cs) + let (asg, errs) = + solveConstraints + reduce + (foldMap pure . kindUniVars) + substKind + (\case KExt () (KUniVar v) -> Just v + _ -> Nothing) + kindSize + (toList cs) + + forM_ errs $ \case + UEUnequal k1 k2 rng -> + raise SError rng $ + "Kind mismatch:\n\ + \- " ++ pretty k1 ++ "\n\ + \- " ++ pretty k2 + UERecursive uvar k rng -> + raise SError rng $ + "Kind cannot be recursive: " ++ pretty (KExt () (KUniVar uvar)) ++ " = " ++ pretty k + + -- default unconstrained kind variables to Type + let unconstrKUVars = foldMap kindUniVars (Map.elems asg) Set.\\ Map.keysSet asg + defaults = Map.fromList (map (,KType ()) (toList unconstrKUVars)) + asg' = Map.map (substKind defaults) asg <> defaults + + return asg' where reduce :: CKind -> CKind -> Range -> (Bag (Int, CKind, Range), Bag (CKind, CKind, Range)) reduce lhs rhs rng = case (lhs, rhs) of @@ -369,18 +361,44 @@ solveKindVars cs = do kindSize :: CKind -> Int kindSize KType{} = 1 kindSize (KFun () a b) = 1 + kindSize a + kindSize b - kindSize (KExt () KUniVar{}) = 1 + kindSize (KExt () KUniVar{}) = 2 -solveConstrs :: Bag Constr -> (Bag Diagnostic, Map Name TType) +solveConstrs :: Bag Constr -> (Bag Diagnostic, Map Int TKind, Map Int TType) solveConstrs = error "solveConstrs" -substProg :: Map Name TType -> CProgram -> TProgram +substProg :: Map Int CKind -- ^ Kind variable instantiations + -> Map Int CType -- ^ Type variable instantiations + -> CProgram + -> CProgram substProg = error "substProg" +substDdef :: Map Int CKind -> Map Int CType -> CDataDef -> CDataDef +substDdef mk mt (DataDef () name pars cons) = + DataDef () name + (map (first (substKind mk)) pars) + (map (second (map (substType mk mt))) cons) + +substType :: Map Int CKind -> Map Int CType -> CType -> CType +substType mk mt = \case + TApp k t ts -> TApp (substKind mk k) (substType mk mt t) (map (substType mk mt) ts) + TTup k ts -> TTup (substKind mk k) (map (substType mk mt) ts) + TList k t -> TList (substKind mk k) (substType mk mt t) + TFun k t1 t2 -> TFun (substKind mk k) (substType mk mt t1) (substType mk mt t2) + TCon k n -> TCon (substKind mk k) n + TVar k n -> TVar (substKind mk k) n + t@(TExt _ (TUniVar v)) -> fromMaybe t (Map.lookup v mt) + substKind :: Map Int CKind -> CKind -> CKind -substKind _ k@KType{} = k -substKind m (KFun () k1 k2) = KFun () (substKind m k1) (substKind m k2) -substKind m k@(KExt () (KUniVar v)) = fromMaybe k (Map.lookup v m) +substKind m = \case + KType () -> KType () + KFun () k1 k2 -> KFun () (substKind m k1) (substKind m k2) + k@(KExt () (KUniVar v)) -> fromMaybe k (Map.lookup v m) + +doneProg :: Map Int TKind -- ^ Kind variable instantiations + -> Map Int TType -- ^ Type variable instantiations + -> CProgram + -> TProgram +doneProg = error "doneProg" kindUniVars :: CKind -> Set Int kindUniVars = \case diff --git a/src/HSVIS/Typecheck/Solve.hs b/src/HSVIS/Typecheck/Solve.hs index 5f51abe..7250e79 100644 --- a/src/HSVIS/Typecheck/Solve.hs +++ b/src/HSVIS/Typecheck/Solve.hs @@ -9,7 +9,7 @@ module HSVIS.Typecheck.Solve ( import Control.Monad (guard, (>=>)) import Data.Bifunctor (Bifunctor(..)) import Data.Foldable (toList, foldl') -import Data.List (sortBy, minimumBy, groupBy) +import Data.List (sortBy, minimumBy, groupBy, intercalate) import Data.Ord (comparing) import Data.Map.Strict (Map) import qualified Data.Map.Strict as Map @@ -66,8 +66,8 @@ solveConstraints => (t -> t -> Range -> (Bag (v, t, Range), Bag (t, t, Range))) -- | Free variables in a type -> (t -> Bag v) - -- | \v repl term -> Substitute v by repl in term - -> (v -> t -> t -> t) + -- | \mapping term -> term with variables in the mapping substituted by their values + -> (Map v t -> t -> t) -- | Detect bare-variable types -> (t -> Maybe v) -- | Some kind of size measure on types @@ -118,19 +118,22 @@ solveConstraints reduce frees subst detect size = \tupcs -> $ Map.assocs m' case msmallestvar of - Nothing -> return $ applyLog eqlog mempty + Nothing -> do + traceM $ "[solver] Log = [" ++ intercalate ", " [show v ++ " = " ++ pretty t | (v, t) <- eqlog] ++ "]" + return $ applyLog eqlog mempty Just (var, RConstr smallrhs _) -> do + traceM $ "[solver] Retiring " ++ show var ++ " = " ++ pretty smallrhs let (_, otherrhss) = bagPartition (guard . (== smallrhs) . rconType) (m' Map.! var) let (newcs, errs) = foldMap (reduce' . unsplitConstr smallrhs) (dedupRCs (toList otherrhss)) (fmap constrUnequal errs, ()) -- write the errors let m'' = Map.unionWith (<>) - (Map.map (fmap @Bag (fmap @RConstr (subst var smallrhs))) + (Map.map (fmap @Bag (fmap @RConstr (subst (Map.singleton var smallrhs)))) (Map.delete var m')) (Map.fromListWith (<>) (map (second pure . splitConstr) (toList newcs))) loop m'' ((var, smallrhs) : eqlog) applyLog :: [(v, t)] -> Map v t -> Map v t - applyLog ((v, t) : l) m = applyLog l $ Map.insert v t (Map.map (subst v t) m) + applyLog ((v, t) : l) m = applyLog l $ Map.insert v (subst m t) m applyLog [] m = m -- If there are multiple sources for the same cosntraint, only one of them is kept. -- cgit v1.2.3-70-g09d2