aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-03-22 21:56:35 +0100
committerTom Smeding <tom@tomsmeding.com>2024-03-22 21:56:35 +0100
commit909b7a4eacaba7323ac44f7950e60e8956e4081c (patch)
tree313f87022729ec7776332c828703677c293c8ac2
parentcc61cdc000481f9dc88253342c328bdb99d048a4 (diff)
Working kind inference
-rw-r--r--src/HSVIS/Diagnostic.hs22
-rw-r--r--src/HSVIS/Parser.hs2
-rw-r--r--src/HSVIS/Typecheck.hs158
-rw-r--r--src/HSVIS/Typecheck/Solve.hs15
4 files changed, 114 insertions, 83 deletions
diff --git a/src/HSVIS/Diagnostic.hs b/src/HSVIS/Diagnostic.hs
index 675482d..778fe34 100644
--- a/src/HSVIS/Diagnostic.hs
+++ b/src/HSVIS/Diagnostic.hs
@@ -27,9 +27,13 @@ instance Pretty Range where
| y2 <= y1 + 1 = showString (show (y1 + 1) ++ ":" ++ show (x1 + 1) ++ "-" ++ show x2)
| otherwise =
showString ("(" ++ show (y1 + 1) ++ ":" ++ show (x1 + 1) ++ ")-(" ++ show (y2 + 1) ++ ":" ++ show x2 ++ ")")
+
+data Severity = SError | SWarning
+ deriving (Show)
data Diagnostic = Diagnostic
- { dFile :: FilePath -- ^ The file for which the diagnostic was rai sed
+ { dSeverity :: Severity -- ^ Error level
+ , dFile :: FilePath -- ^ The file for which the diagnostic was raised
, dRange :: Range -- ^ Where in the file
, dStk :: [String] -- ^ Stack of contexts (innermost at head) of the diagnostic
, dSourceLine :: String -- ^ The line in the source file of the start of the range
@@ -38,17 +42,23 @@ data Diagnostic = Diagnostic
deriving (Show)
printDiagnostic :: Diagnostic -> String
-printDiagnostic (Diagnostic fp rng@(Range (Pos y1 x1) (Pos y2 x2)) stk srcline msg) =
+printDiagnostic (Diagnostic sev fp rng@(Range (Pos y1 x1) (Pos y2 x2)) stk srcline msg) =
let linenum = show (y1 + 1)
locstr = pretty rng
ncarets | y1 == y2 = max 1 (x2 - x1 + 1)
| otherwise = length srcline - x1
caretsuffix | y1 == y2 = ""
| otherwise = "..."
- in intercalate "\n" $
- map (\descr -> "In " ++ descr ++ ":") (reverse stk)
- ++ [fp ++ ":" ++ locstr ++ ": " ++ msg
- ,map (\_ -> ' ') linenum ++ " |"
+
+ mainLine =
+ (case sev of SError -> "Error: "
+ SWarning -> "Warning: ")
+ ++ fp ++ ":" ++ locstr ++ ": " ++ msg
+ revCtxTrace = reverse $ map (\(i, descr) -> "in " ++ descr ++ (if i == 0 then "" else ","))
+ (zip [0::Int ..] (reverse stk))
+ srcPointer =
+ [map (\_ -> ' ') linenum ++ " |"
,linenum ++ " | " ++ srcline
,map (\_ -> ' ') linenum ++ " | " ++ replicate x1 ' ' ++
replicate ncarets '^' ++ caretsuffix]
+ in intercalate "\n" $ [mainLine] ++ srcPointer ++ revCtxTrace
diff --git a/src/HSVIS/Parser.hs b/src/HSVIS/Parser.hs
index b4d8754..e89c679 100644
--- a/src/HSVIS/Parser.hs
+++ b/src/HSVIS/Parser.hs
@@ -896,7 +896,7 @@ raise fat msg = gets psCur >>= \pos -> raiseAt pos fat msg
raiseAt :: (KnownFallible fail, FatalCtx fatal a) => Pos -> Fatality fatal -> String -> Parser fail a
raiseAt pos fat msg = do
Context { ctxFile = fp , ctxStack = stk, ctxLines = srcLines } <- ask
- let err = Diagnostic fp (Range pos pos) stk (srcLines !! posLine pos) msg
+ let err = Diagnostic SError fp (Range pos pos) stk (srcLines !! posLine pos) msg
case fat of
Error -> dictate (pure err)
-- Fatal -> confess (pure err)
diff --git a/src/HSVIS/Typecheck.hs b/src/HSVIS/Typecheck.hs
index c97064a..0347e81 100644
--- a/src/HSVIS/Typecheck.hs
+++ b/src/HSVIS/Typecheck.hs
@@ -5,11 +5,18 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TupleSections #-}
-module HSVIS.Typecheck where
+{-# LANGUAGE GADTs #-}
+module HSVIS.Typecheck (
+ StageTyped,
+ typecheck,
+ -- * Typed AST synonyms
+ -- TProgram, TDataDef, TFunDef, TFunEq, TKind, TType, TPattern, TRHS, TExpr,
+) where
import Control.Monad
-import Data.Bifunctor (first)
+import Data.Bifunctor (first, second)
import Data.Foldable (toList)
+import Data.List (find)
import Data.Map.Strict (Map)
import Data.Maybe (fromMaybe)
import Data.Monoid (Ap(..))
@@ -89,8 +96,8 @@ typecheck :: FilePath -> String -> PProgram -> ([Diagnostic], TProgram)
typecheck fp source prog =
let (ds1, cs, _, _, progtc) =
runTCM (tcProgram prog) (fp, source) 1 (Env mempty mempty)
- (ds2, sub) = solveConstrs cs
- in (toList (ds1 <> ds2), substProg sub progtc)
+ (ds2, subK, subT) = solveConstrs cs
+ in (toList (ds1 <> ds2), doneProg subK subT progtc)
data Constr
-- Equality constraints: "left" must be equal to "right" because of the thing
@@ -127,9 +134,9 @@ instance Monad TCM where
(ds3, cs3, i3, env3, y) = runTCM (g x) ctx i2 env2
in (ds2 <> ds3, cs2 <> cs3, i3, env3, y)
-raise :: Range -> String -> TCM ()
-raise rng@(Range (Pos y _) _) msg = TCM $ \(fp, source) i env ->
- (pure (Diagnostic fp rng [] (lines source !! y) msg), mempty, i, env, ())
+raise :: Severity -> Range -> String -> TCM ()
+raise sev rng@(Range (Pos y _) _) msg = TCM $ \(fp, source) i env ->
+ (pure (Diagnostic sev fp rng [] (lines source !! y) msg), mempty, i, env, ())
emit :: Constr -> TCM ()
emit c = TCM $ \_ i env -> (mempty, pure c, i, env, ())
@@ -192,30 +199,31 @@ genUniVar k = TExt k . TUniVar <$> genId
getKind' :: Range -> Name -> TCM CKind
getKind' rng name = getKind name >>= \case
Nothing -> do
- raise rng $ "Type not in scope: " ++ pretty name
+ raise SError rng $ "Type not in scope: " ++ pretty name
genKUniVar
Just k -> return k
getType' :: Range -> Name -> TCM CType
getType' rng name = getType name >>= \case
Nothing -> do
- raise rng $ "Variable not in scope: " ++ pretty name
+ raise SError rng $ "Variable not in scope: " ++ pretty name
genUniVar (KType ())
Just k -> return k
tcProgram :: PProgram -> TCM CProgram
-tcProgram (Program ddefs fdefs) = do
- (kconstrs, ddefs') <- collectConstraints isCEqK $ do
- mapM_ prepareDataDef ddefs
- mapM tcDataDef ddefs
+tcProgram (Program ddefs1 fdefs1) = do
+ (kconstrs, ddefs2) <- collectConstraints isCEqK $ do
+ mapM_ prepareDataDef ddefs1
+ mapM tcDataDef ddefs1
- solveKindVars kconstrs
+ kinduvars <- solveKindVars kconstrs
+ let ddefs3 = map (substDdef kinduvars mempty) ddefs2
- traceM (unlines (map pretty ddefs'))
+ traceM (unlines (map pretty ddefs3))
- fdefs' <- mapM tcFunDef fdefs
+ fdefs2 <- mapM tcFunDef fdefs1
- return (Program ddefs' fdefs')
+ return (Program ddefs3 fdefs2)
prepareDataDef :: PDataDef -> TCM ()
prepareDataDef (DataDef _ name params _) = do
@@ -224,7 +232,7 @@ prepareDataDef (DataDef _ name params _) = do
modifyTEnv (Map.insert name k)
-- Assumes that the kind of the name itself has already been registered with
--- the correct arity (this is doen by prepareDataDef).
+-- the correct arity (this is done by prepareDataDef).
tcDataDef :: PDataDef -> TCM CDataDef
tcDataDef (DataDef rng name params cons) = do
kd <- getKind' rng name
@@ -292,9 +300,9 @@ kcType mdown = \case
return (TVar k n)
tcFunDef :: PFunDef -> TCM CFunDef
-tcFunDef (FunDef _ name msig eqs) = do
+tcFunDef (FunDef rng name msig eqs) = do
when (not $ allEq (fmap (length . funeqPats) eqs)) $
- raise (sconcatne (fmap extOf eqs)) "Function equations have differing numbers of arguments"
+ raise SError rng "Function equations have differing numbers of arguments"
typ <- case msig of
TypeSig sig -> kcType (Just (KType ())) sig
@@ -305,52 +313,36 @@ tcFunDef (FunDef _ name msig eqs) = do
return (FunDef typ name (TypeSig typ) eqs')
tcFunEq :: CType -> PFunEq -> TCM CFunEq
-tcFunEq = error "tcFunEq"
-
-newtype SolveM v t m a = SolveM (Map v (Bag t) -> Map v t -> m (a, Map v (Bag t), Map v t))
-instance Monad m => Functor (SolveM v t m) where
- fmap f (SolveM g) = SolveM $ \m r -> do (x, m', r') <- g m r
- return (f x, m', r')
-instance Monad m => Applicative (SolveM v t m) where
- pure x = SolveM $ \m r -> return (x, m, r)
- (<*>) = ap
-instance Monad m => Monad (SolveM v t m) where
- SolveM f >>= g = SolveM $ \m r -> do (x, m1, r1) <- f m r
- let SolveM h = g x
- h m1 r1
-
-solvemStateGet :: Monad m => SolveM v t m (Map v (Bag t))
-solvemStateGet = SolveM $ \m r -> return (m, m, r)
-
-solvemStateUpdate :: Monad m => (Map v (Bag t) -> Map v (Bag t)) -> SolveM v t m ()
-solvemStateUpdate f = SolveM $ \m r -> return ((), f m, r)
-
-solvemLogUpdate :: Monad m => (Map v t -> Map v t) -> SolveM v t m ()
-solvemLogUpdate f = SolveM $ \m r -> return ((), m, f r)
-
-solvemStateVars :: Monad m => SolveM v t m [v]
-solvemStateVars = Map.keys <$> solvemStateGet
+tcFunEq down (FunEq rng name pats rhs) = error "tcFunEq"
-solvemStateRHS :: (Ord v, Monad m) => v -> SolveM v t m (Bag t)
-solvemStateRHS v = fromMaybe mempty . Map.lookup v <$> solvemStateGet
-
-solvemStateSet :: (Ord v, Monad m) => v -> Bag t -> SolveM v t m ()
-solvemStateSet v b = solvemStateUpdate (Map.insert v b)
-
-solvemLogEq :: (Ord v, Monad m) => v -> t -> SolveM v t m ()
-solvemLogEq v t = solvemLogUpdate (Map.insert v t)
-
-solveKindVars :: Bag (CKind, CKind, Range) -> TCM ()
+solveKindVars :: Bag (CKind, CKind, Range) -> TCM (Map Int CKind)
solveKindVars cs = do
- traceShowM cs
- traceShowM $ solveConstraints
- reduce
- (foldMap pure . kindUniVars)
- (\v repl -> substKind (Map.singleton v repl))
- (\case KExt () (KUniVar v) -> Just v
- _ -> Nothing)
- kindSize
- (toList cs)
+ let (asg, errs) =
+ solveConstraints
+ reduce
+ (foldMap pure . kindUniVars)
+ substKind
+ (\case KExt () (KUniVar v) -> Just v
+ _ -> Nothing)
+ kindSize
+ (toList cs)
+
+ forM_ errs $ \case
+ UEUnequal k1 k2 rng ->
+ raise SError rng $
+ "Kind mismatch:\n\
+ \- " ++ pretty k1 ++ "\n\
+ \- " ++ pretty k2
+ UERecursive uvar k rng ->
+ raise SError rng $
+ "Kind cannot be recursive: " ++ pretty (KExt () (KUniVar uvar)) ++ " = " ++ pretty k
+
+ -- default unconstrained kind variables to Type
+ let unconstrKUVars = foldMap kindUniVars (Map.elems asg) Set.\\ Map.keysSet asg
+ defaults = Map.fromList (map (,KType ()) (toList unconstrKUVars))
+ asg' = Map.map (substKind defaults) asg <> defaults
+
+ return asg'
where
reduce :: CKind -> CKind -> Range -> (Bag (Int, CKind, Range), Bag (CKind, CKind, Range))
reduce lhs rhs rng = case (lhs, rhs) of
@@ -369,18 +361,44 @@ solveKindVars cs = do
kindSize :: CKind -> Int
kindSize KType{} = 1
kindSize (KFun () a b) = 1 + kindSize a + kindSize b
- kindSize (KExt () KUniVar{}) = 1
+ kindSize (KExt () KUniVar{}) = 2
-solveConstrs :: Bag Constr -> (Bag Diagnostic, Map Name TType)
+solveConstrs :: Bag Constr -> (Bag Diagnostic, Map Int TKind, Map Int TType)
solveConstrs = error "solveConstrs"
-substProg :: Map Name TType -> CProgram -> TProgram
+substProg :: Map Int CKind -- ^ Kind variable instantiations
+ -> Map Int CType -- ^ Type variable instantiations
+ -> CProgram
+ -> CProgram
substProg = error "substProg"
+substDdef :: Map Int CKind -> Map Int CType -> CDataDef -> CDataDef
+substDdef mk mt (DataDef () name pars cons) =
+ DataDef () name
+ (map (first (substKind mk)) pars)
+ (map (second (map (substType mk mt))) cons)
+
+substType :: Map Int CKind -> Map Int CType -> CType -> CType
+substType mk mt = \case
+ TApp k t ts -> TApp (substKind mk k) (substType mk mt t) (map (substType mk mt) ts)
+ TTup k ts -> TTup (substKind mk k) (map (substType mk mt) ts)
+ TList k t -> TList (substKind mk k) (substType mk mt t)
+ TFun k t1 t2 -> TFun (substKind mk k) (substType mk mt t1) (substType mk mt t2)
+ TCon k n -> TCon (substKind mk k) n
+ TVar k n -> TVar (substKind mk k) n
+ t@(TExt _ (TUniVar v)) -> fromMaybe t (Map.lookup v mt)
+
substKind :: Map Int CKind -> CKind -> CKind
-substKind _ k@KType{} = k
-substKind m (KFun () k1 k2) = KFun () (substKind m k1) (substKind m k2)
-substKind m k@(KExt () (KUniVar v)) = fromMaybe k (Map.lookup v m)
+substKind m = \case
+ KType () -> KType ()
+ KFun () k1 k2 -> KFun () (substKind m k1) (substKind m k2)
+ k@(KExt () (KUniVar v)) -> fromMaybe k (Map.lookup v m)
+
+doneProg :: Map Int TKind -- ^ Kind variable instantiations
+ -> Map Int TType -- ^ Type variable instantiations
+ -> CProgram
+ -> TProgram
+doneProg = error "doneProg"
kindUniVars :: CKind -> Set Int
kindUniVars = \case
diff --git a/src/HSVIS/Typecheck/Solve.hs b/src/HSVIS/Typecheck/Solve.hs
index 5f51abe..7250e79 100644
--- a/src/HSVIS/Typecheck/Solve.hs
+++ b/src/HSVIS/Typecheck/Solve.hs
@@ -9,7 +9,7 @@ module HSVIS.Typecheck.Solve (
import Control.Monad (guard, (>=>))
import Data.Bifunctor (Bifunctor(..))
import Data.Foldable (toList, foldl')
-import Data.List (sortBy, minimumBy, groupBy)
+import Data.List (sortBy, minimumBy, groupBy, intercalate)
import Data.Ord (comparing)
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
@@ -66,8 +66,8 @@ solveConstraints
=> (t -> t -> Range -> (Bag (v, t, Range), Bag (t, t, Range)))
-- | Free variables in a type
-> (t -> Bag v)
- -- | \v repl term -> Substitute v by repl in term
- -> (v -> t -> t -> t)
+ -- | \mapping term -> term with variables in the mapping substituted by their values
+ -> (Map v t -> t -> t)
-- | Detect bare-variable types
-> (t -> Maybe v)
-- | Some kind of size measure on types
@@ -118,19 +118,22 @@ solveConstraints reduce frees subst detect size = \tupcs ->
$ Map.assocs m'
case msmallestvar of
- Nothing -> return $ applyLog eqlog mempty
+ Nothing -> do
+ traceM $ "[solver] Log = [" ++ intercalate ", " [show v ++ " = " ++ pretty t | (v, t) <- eqlog] ++ "]"
+ return $ applyLog eqlog mempty
Just (var, RConstr smallrhs _) -> do
+ traceM $ "[solver] Retiring " ++ show var ++ " = " ++ pretty smallrhs
let (_, otherrhss) = bagPartition (guard . (== smallrhs) . rconType) (m' Map.! var)
let (newcs, errs) = foldMap (reduce' . unsplitConstr smallrhs) (dedupRCs (toList otherrhss))
(fmap constrUnequal errs, ()) -- write the errors
let m'' = Map.unionWith (<>)
- (Map.map (fmap @Bag (fmap @RConstr (subst var smallrhs)))
+ (Map.map (fmap @Bag (fmap @RConstr (subst (Map.singleton var smallrhs))))
(Map.delete var m'))
(Map.fromListWith (<>) (map (second pure . splitConstr) (toList newcs)))
loop m'' ((var, smallrhs) : eqlog)
applyLog :: [(v, t)] -> Map v t -> Map v t
- applyLog ((v, t) : l) m = applyLog l $ Map.insert v t (Map.map (subst v t) m)
+ applyLog ((v, t) : l) m = applyLog l $ Map.insert v (subst m t) m
applyLog [] m = m
-- If there are multiple sources for the same cosntraint, only one of them is kept.