diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-03-17 23:08:38 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-03-17 23:08:52 +0100 |
commit | cc61cdc000481f9dc88253342c328bdb99d048a4 (patch) | |
tree | d1959086d000b3e54a9e45a7f309206e2a24b958 | |
parent | e7bed242ba52e6d3233928f2c6189e701cfa5e4c (diff) |
Typecheck work; solver is incorrect
-rw-r--r-- | examples/test-kinds.hs | 17 | ||||
-rw-r--r-- | src/HSVIS/AST.hs | 38 | ||||
-rw-r--r-- | src/HSVIS/Diagnostic.hs | 23 | ||||
-rw-r--r-- | src/HSVIS/Parser.hs | 10 | ||||
-rw-r--r-- | src/HSVIS/Typecheck.hs | 83 | ||||
-rw-r--r-- | src/HSVIS/Typecheck/Solve.hs | 131 |
6 files changed, 214 insertions, 88 deletions
diff --git a/examples/test-kinds.hs b/examples/test-kinds.hs new file mode 100644 index 0000000..1e2c18c --- /dev/null +++ b/examples/test-kinds.hs @@ -0,0 +1,17 @@ +data Tree a + = Node (Tree a) a (Tree a) + | Leaf + +data A f = A1 (f ()) | A2 (f (Tree ())) + +data Either a b = Left a | Right b + +data ExceptT e m a = ExceptT (Either e (m a)) + +data TreeF a r + = NodeF r a r + | LeafF + +data Fix f = In (f (Fix f)) + +data Tree' a = Tree' (Fix (TreeF a)) diff --git a/src/HSVIS/AST.hs b/src/HSVIS/AST.hs index 8bb2d6c..2b125b9 100644 --- a/src/HSVIS/AST.hs +++ b/src/HSVIS/AST.hs @@ -11,6 +11,7 @@ module HSVIS.AST where import Data.Bifunctor (bimap, first, second) import qualified Data.Kind as DK +import Data.List (intersperse) import Data.List.NonEmpty (NonEmpty) import Data.Proxy @@ -136,13 +137,36 @@ data Operator = OAdd | OSub | OMul | ODiv | OMod | OEqu | OPow deriving (Show) instance Pretty Name where - prettysPrec _ (Name n) = showString ("\"" ++ n ++ "\"") - -instance (X Kind s ~ (), Pretty (E Kind s)) => Pretty (Kind s) where - prettysPrec _ (KType ()) = showString "Type" - prettysPrec d (KFun () a b) = - showParen (d > -1) $ prettysPrec 0 a . showString " -> " . prettysPrec 0 b - prettysPrec d (KExt () e) = prettysPrec d e + prettysPrec _ (Name n) = showString n + +instance Pretty (E Kind s) => Pretty (Kind s) where + prettysPrec _ (KType _) = showString "Type" + prettysPrec d (KFun _ a b) = showParen (d > -1) $ + prettysPrec 0 a . showString " -> " . prettysPrec (-1) b + prettysPrec d (KExt _ e) = prettysPrec d e + +instance Pretty (E Type s) => Pretty (Type s) where + prettysPrec d (TApp _ a ts) = showParen (d > 10) $ + prettysPrec 10 a . foldr (.) id [showString " " . prettysPrec 11 t | t <- ts] + prettysPrec _ (TTup _ ts) = + showString "(" . foldr (.) id (intersperse (showString ",") (map (prettysPrec 0) ts)) . showString ")" + prettysPrec _ (TList _ t) = + showString "[" . prettysPrec 0 t . showString "]" + prettysPrec d (TFun _ a b) = showParen (d > -1) $ + prettysPrec 0 a . showString " -> " . prettysPrec (-1) b + prettysPrec _ (TCon _ n) = prettysPrec 11 n + prettysPrec _ (TVar _ n) = prettysPrec 11 n + prettysPrec d (TExt _ e) = prettysPrec d e + +instance (Pretty (X Type s), Pretty (E Type s)) => Pretty (DataDef s) where + prettysPrec _ (DataDef _ name vars cons) = + showString "data " . prettysPrec 11 name + . foldr (.) id [showString " (" . prettysPrec 11 n . showString " :: " . prettysPrec 11 k . showString ")" + | (k, n) <- vars] + . showString " = " + . foldr (.) id (intersperse (showString " | ") + [prettysPrec 11 cname . foldr (.) id [showString " " . prettysPrec 11 t | t <- fields] + | (cname, fields) <- cons]) instance HasExt DataDef where type HasXField DataDef = 'True diff --git a/src/HSVIS/Diagnostic.hs b/src/HSVIS/Diagnostic.hs index 322f9eb..675482d 100644 --- a/src/HSVIS/Diagnostic.hs +++ b/src/HSVIS/Diagnostic.hs @@ -2,6 +2,8 @@ module HSVIS.Diagnostic where import Data.List (intercalate) +import HSVIS.Pretty + data Pos = Pos { posLine :: Int -- ^ zero-based @@ -9,6 +11,9 @@ data Pos = Pos } deriving (Show, Eq, Ord) +instance Pretty Pos where + prettysPrec _ (Pos y x) = showString (show (y + 1) ++ ":" ++ show (x + 1)) + -- | Inclusive-exclusive range of positions in a file. data Range = Range Pos Pos deriving (Show) @@ -16,8 +21,15 @@ data Range = Range Pos Pos instance Semigroup Range where Range a b <> Range c d = Range (min a c) (max b d) -data Diagnostic = Diagnostic - { dFile :: FilePath -- ^ The file for which the diagnostic was raised +instance Pretty Range where + prettysPrec _ (Range (Pos y1 x1) (Pos y2 x2)) + | y2 <= y1 + 1, x2 <= x1 + 1 = showString (show (y1 + 1) ++ ":" ++ show (x1 + 1)) + | y2 <= y1 + 1 = showString (show (y1 + 1) ++ ":" ++ show (x1 + 1) ++ "-" ++ show x2) + | otherwise = + showString ("(" ++ show (y1 + 1) ++ ":" ++ show (x1 + 1) ++ ")-(" ++ show (y2 + 1) ++ ":" ++ show x2 ++ ")") + +data Diagnostic = Diagnostic + { dFile :: FilePath -- ^ The file for which the diagnostic was rai sed , dRange :: Range -- ^ Where in the file , dStk :: [String] -- ^ Stack of contexts (innermost at head) of the diagnostic , dSourceLine :: String -- ^ The line in the source file of the start of the range @@ -26,12 +38,9 @@ data Diagnostic = Diagnostic deriving (Show) printDiagnostic :: Diagnostic -> String -printDiagnostic (Diagnostic fp (Range (Pos y1 x1) (Pos y2 x2)) stk srcline msg) = +printDiagnostic (Diagnostic fp rng@(Range (Pos y1 x1) (Pos y2 x2)) stk srcline msg) = let linenum = show (y1 + 1) - locstr | y1 == y2, x1 == x2 = show y1 ++ ":" ++ show x1 - | y1 == y2 = show y1 ++ ":" ++ show x1 ++ "-" ++ show x2 - | otherwise = "(" ++ show y1 ++ ":" ++ show x1 ++ ")-(" ++ - show y1 ++ ":" ++ show x1 ++ ")" + locstr = pretty rng ncarets | y1 == y2 = max 1 (x2 - x1 + 1) | otherwise = length srcline - x1 caretsuffix | y1 == y2 = "" diff --git a/src/HSVIS/Parser.hs b/src/HSVIS/Parser.hs index 0df4aa8..b4d8754 100644 --- a/src/HSVIS/Parser.hs +++ b/src/HSVIS/Parser.hs @@ -672,9 +672,13 @@ pType = do return (TFun (Range pos1 pos2) ty1 ty2) pTypeApp :: FParser PType -pTypeApp = fasome pTypeAtom >>= \case - t :| [] -> return t - t :| ts -> return (TApp (foldMapne extOf (t :| ts)) t ts) +pTypeApp = do + pos1 <- gets psCur + ts <- fasome pTypeAtom + pos2 <- gets psCur + case ts of + t :| [] -> return t + t :| ts' -> return (TApp (Range pos1 pos2) t ts') pTypeAtom :: FParser PType pTypeAtom = faasum' [pTypeParens, pTypeList, pTypeName] diff --git a/src/HSVIS/Typecheck.hs b/src/HSVIS/Typecheck.hs index ba853a0..c97064a 100644 --- a/src/HSVIS/Typecheck.hs +++ b/src/HSVIS/Typecheck.hs @@ -78,6 +78,9 @@ type TPattern = Pattern StageTyped type TRHS = RHS StageTyped type TExpr = Expr StageTyped +instance Pretty (E Type StageTC) where + prettysPrec _ (TUniVar n) = showString ("?t" ++ show n) + instance Pretty (E Kind StageTC) where prettysPrec _ (KUniVar n) = showString ("?k" ++ show n) @@ -144,7 +147,7 @@ putFullEnv :: Env -> TCM () putFullEnv env = TCM $ \_ i _ -> (mempty, mempty, i, env, ()) genId :: TCM Int -genId = TCM $ \_ i env -> (mempty, mempty, i, env, i) +genId = TCM $ \_ i env -> (mempty, mempty, i + 1, env, i) getKind :: Name -> TCM (Maybe CKind) getKind name = do @@ -208,6 +211,8 @@ tcProgram (Program ddefs fdefs) = do solveKindVars kconstrs + traceM (unlines (map pretty ddefs')) + fdefs' <- mapM tcFunDef fdefs return (Program ddefs' fdefs') @@ -232,40 +237,59 @@ tcDataDef (DataDef rng name params cons) = do cons' <- scopeTEnv $ do modifyTEnv (Map.fromList (zip (map snd params) pkinds) <>) - mapM (\(cname, ty) -> (cname,) <$> mapM kcType ty) cons + mapM (\(cname, fieldtys) -> (cname,) <$> mapM (kcType (Just (KType ()))) fieldtys) cons return (DataDef () name (zip pkinds (map snd params)) cons') -kcType :: PType -> TCM CType -kcType = \case +promoteDown :: Maybe CKind -> TCM CKind +promoteDown Nothing = genKUniVar +promoteDown (Just k) = return k + +downEqK :: Range -> Maybe CKind -> CKind -> TCM () +downEqK _ Nothing _ = return () +downEqK rng (Just k1) k2 = emit $ CEqK k1 k2 rng + +-- | Given (maybe) the expected kind of this type, and a type, check it for +-- kind-correctness. +kcType :: Maybe CKind -> PType -> TCM CType +kcType mdown = \case TApp rng t ts -> do - t' <- kcType t - ts' <- mapM kcType ts - retk <- genKUniVar + t' <- kcType Nothing t + ts' <- mapM (kcType Nothing) ts + retk <- promoteDown mdown let expected = foldr (KFun ()) retk (map extOf ts') emit $ CEqK (extOf t') expected rng return (TApp retk t' ts') - TTup _ ts -> do - ts' <- mapM kcType ts + TTup rng ts -> do + ts' <- mapM (kcType (Just (KType ()))) ts forM_ (zip (map extOf ts) ts') $ \(trng, ct) -> emit $ CEqK (extOf ct) (KType ()) trng + downEqK rng mdown (KType ()) return (TTup (KType ()) ts') - TList _ t -> do - t' <- kcType t + TList rng t -> do + t' <- kcType (Just (KType ())) t emit $ CEqK (extOf t') (KType ()) (extOf t) + downEqK rng mdown (KType ()) return (TList (KType ()) t') - TFun _ t1 t2 -> do - t1' <- kcType t1 - t2' <- kcType t2 + TFun rng t1 t2 -> do + t1' <- kcType (Just (KType ())) t1 + t2' <- kcType (Just (KType ())) t2 emit $ CEqK (extOf t1') (KType ()) (extOf t1) emit $ CEqK (extOf t2') (KType ()) (extOf t2) + downEqK rng mdown (KType ()) return (TFun (KType ()) t1' t2') - TCon rng n -> TCon <$> getKind' rng n <*> pure n + TCon rng n -> do + k <- getKind' rng n + downEqK rng mdown k + return (TCon k n) - TVar rng n -> TVar <$> getKind' rng n <*> pure n + TVar rng n -> do + k <- getKind' rng n + downEqK rng mdown k + return (TVar k n) tcFunDef :: PFunDef -> TCM CFunDef tcFunDef (FunDef _ name msig eqs) = do @@ -273,7 +297,7 @@ tcFunDef (FunDef _ name msig eqs) = do raise (sconcatne (fmap extOf eqs)) "Function equations have differing numbers of arguments" typ <- case msig of - TypeSig sig -> kcType sig + TypeSig sig -> kcType (Just (KType ())) sig TypeSigExt NoTypeSig -> genUniVar (KType ()) eqs' <- mapM (tcFunEq typ) eqs @@ -326,18 +350,21 @@ solveKindVars cs = do (\case KExt () (KUniVar v) -> Just v _ -> Nothing) kindSize - (map (\(a, b, _) -> (a, b)) (toList cs)) + (toList cs) where - reduce :: CKind -> CKind -> (Bag (Int, CKind), Bag (CKind, CKind)) - -- unification variables produce constraints on a unification variable - reduce (KExt () (KUniVar i)) (KExt () (KUniVar j)) | i == j = mempty - reduce (KExt () (KUniVar i)) k = (pure (i, k), mempty) - reduce k (KExt () (KUniVar i)) = (pure (i, k), mempty) - -- if lhs and rhs have equal prefixes, recurse - reduce (KType ()) (KType ()) = mempty - reduce (KFun () a b) (KFun () c d) = reduce a c <> reduce b d - -- otherwise, this is a kind mismatch - reduce k1 k2 = (mempty, pure (k1, k2)) + reduce :: CKind -> CKind -> Range -> (Bag (Int, CKind, Range), Bag (CKind, CKind, Range)) + reduce lhs rhs rng = case (lhs, rhs) of + -- unification variables produce constraints on a unification variable + (KExt () (KUniVar i), KExt () (KUniVar j)) | i == j -> mempty + (KExt () (KUniVar i), k ) -> (pure (i, k, rng), mempty) + (k , KExt () (KUniVar i)) -> (pure (i, k, rng), mempty) + + -- if lhs and rhs have equal prefixes, recurse + (KType () , KType () ) -> mempty + (KFun () a b, KFun () c d) -> reduce a c rng <> reduce b d rng + + -- otherwise, this is a kind mismatch + (k1, k2) -> (mempty, pure (k1, k2, rng)) kindSize :: CKind -> Int kindSize KType{} = 1 diff --git a/src/HSVIS/Typecheck/Solve.hs b/src/HSVIS/Typecheck/Solve.hs index 184937c..5f51abe 100644 --- a/src/HSVIS/Typecheck/Solve.hs +++ b/src/HSVIS/Typecheck/Solve.hs @@ -1,10 +1,15 @@ +{-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE ScopedTypeVariables #-} -module HSVIS.Typecheck.Solve where +{-# LANGUAGE TypeApplications #-} +module HSVIS.Typecheck.Solve ( + solveConstraints, + UnifyErr(..), +) where import Control.Monad (guard, (>=>)) -import Data.Bifunctor (second) +import Data.Bifunctor (Bifunctor(..)) import Data.Foldable (toList, foldl') -import Data.List (sort) +import Data.List (sortBy, minimumBy, groupBy) import Data.Ord (comparing) import Data.Map.Strict (Map) import qualified Data.Map.Strict as Map @@ -12,26 +17,53 @@ import qualified Data.Map.Strict as Map import Debug.Trace import Data.Bag -import Data.List (minimumBy) +import HSVIS.Diagnostic (Range(..)) +import HSVIS.Pretty +import Data.Function (on) data UnifyErr v t - = UEUnequal t t - | UERecursive v t + = UEUnequal t t Range + | UERecursive v t Range deriving (Show) +data Constr a b = Constr a b Range + deriving (Show) + +instance Bifunctor Constr where + bimap f g (Constr x y r) = Constr (f x) (g y) r + +-- right-hand side of a constraint +data RConstr b = RConstr b Range + deriving (Show, Functor) + +splitConstr :: Constr a b -> (a, RConstr b) +splitConstr (Constr x y r) = (x, RConstr y r) + +unsplitConstr :: a -> RConstr b -> Constr a b +unsplitConstr x (RConstr y r) = Constr x y r + +constrUnequal :: Constr t t -> UnifyErr v t +constrUnequal (Constr x y r) = UEUnequal x y r + +constrRecursive :: Constr v t -> UnifyErr v t +constrRecursive (Constr x y r) = UERecursive x y r + +rconType :: RConstr b -> b +rconType (RConstr t _) = t + -- | Returns a pair of: -- 1. A set of unification errors; -- 2. An assignment of the variables that had any constraints on them. -- The producedure was successful if the set of errors is empty. Note that -- unconstrained variables do not appear in the output. solveConstraints - :: forall v t. (Ord v, Ord t, Show v, Show t) + :: forall v t. (Ord v, Ord t, Show v, Pretty t) -- | reduce: take two types and unify them, resulting in: -- 1. A bag of resulting constraints on variables; -- 2. A bag of errors: pairs of two types that are provably distinct but -- need to be equal for the input types to unify. - => (t -> t -> (Bag (v, t), Bag (t, t))) + => (t -> t -> Range -> (Bag (v, t, Range), Bag (t, t, Range))) -- | Free variables in a type -> (t -> Bag v) -- | \v repl term -> Substitute v by repl in term @@ -41,63 +73,76 @@ solveConstraints -- | Some kind of size measure on types -> (t -> Int) -- | Equality constraints to solve - -> [(t, t)] + -> [(t, t, Range)] -> (Map v t, Bag (UnifyErr v t)) -solveConstraints reduce frees subst detect size = \cs -> - let (vcs, errs) = foldMap (uncurry reduce) cs - asg = Map.fromListWith (<>) (map (second pure) (toList vcs)) +solveConstraints reduce frees subst detect size = \tupcs -> + let cs = map (uncurry3 Constr) tupcs :: [Constr t t] + (vcs, errs) = foldMap reduce' cs + asg = Map.fromListWith (<>) (map (second pure . splitConstr) (toList vcs)) (errs', asg') = loop asg [] - errs'' = fmap (uncurry UEUnequal) errs <> errs' - in trace ("[solver] Solving:" ++ concat ["\n- " ++ show a ++ " == " ++ show b | (a, b) <- cs]) $ + errs'' = fmap constrUnequal errs <> errs' + in trace ("[solver] Solving:" ++ concat ["\n- " ++ pretty a ++ " == " ++ pretty b ++ " {" ++ pretty r ++ "}" | Constr a b r <- cs]) $ trace ("[solver] Result: (with " ++ show (length errs'') ++ " errors)" ++ - concat ["\n- " ++ show v ++ " = " ++ show t | (v, t) <- Map.assocs asg']) + concat ["\n- " ++ show v ++ " = " ++ pretty t | (v, t) <- Map.assocs asg']) (asg', errs'') where - loop :: Map v (Bag t) -> [(v, t)] -> (Bag (UnifyErr v t), Map v t) + reduce' :: Constr t t -> (Bag (Constr v t), Bag (Constr t t)) + reduce' (Constr t1 t2 r) = bimap (fmap (uncurry3 Constr)) (fmap (uncurry3 Constr)) $ reduce t1 t2 r + + loop :: Map v (Bag (RConstr t)) -> [(v, t)] -> (Bag (UnifyErr v t), Map v t) loop m eqlog = do - traceM $ "[solver] Step:" ++ concat ["\n- " ++ show v ++ " == " ++ show t | (v, t) <- Map.assocs m] + traceM $ "[solver] Step:" ++ concat + [case toList rhss of + [] -> "\n- " ++ show v ++ " <no RHSs>" + RConstr t r : rest -> + "\n- " ++ show v ++ " == " ++ pretty t ++ " {" ++ pretty r ++ "}" ++ + concat ["\n " ++ replicate (length (show v)) ' ' ++ " == " ++ pretty t' ++ " {" ++ pretty r' ++ "}" + | RConstr t' r' <- rest] + | (v, rhss) <- Map.assocs m] + m' <- Map.traverseWithKey - (\v ts -> - let ts' = bagFromList (dedup (toList ts)) + (\v rhss -> + let rhss' = bagFromList (dedupRCs (toList rhss)) -- filter out recursive equations - (recs, nonrecs) = bagPartition (\t -> if v `elem` frees t then Just t else Nothing) ts' + (recs, nonrecs) = bagPartition (\c@(RConstr t _) -> if v `elem` frees t then Just c else Nothing) rhss' -- filter out trivial equations (v = v) - (_, nonrecs') = bagPartition (detect >=> guard . (== v)) nonrecs - in (UERecursive v <$> recs, nonrecs')) + (_, nonrecs') = bagPartition (detect . rconType >=> guard . (== v)) nonrecs + in (constrRecursive . unsplitConstr v <$> recs, nonrecs')) m - let msmallestvar = -- var with its smallest RHS, if such a var exists - minimumByMay (comparing (size . snd)) - . map (second (minimumBy (comparing size))) + let msmallestvar :: Maybe (v, RConstr t) -- var with its smallest RHS, if such a var exists + msmallestvar = + minimumByMay (comparing (size . rconType . snd)) + . map (second (minimumBy (comparing (size . rconType)))) . filter (not . null . snd) $ Map.assocs m' case msmallestvar of Nothing -> return $ applyLog eqlog mempty - Just (var, smallrhs) -> do - let (_, otherrhss) = bagPartition (guard . (== smallrhs)) (m' Map.! var) - let (newcs, errs) = foldMap (reduce smallrhs) (dedup (toList otherrhss)) - (fmap (uncurry UEUnequal) errs, ()) -- write the errors + Just (var, RConstr smallrhs _) -> do + let (_, otherrhss) = bagPartition (guard . (== smallrhs) . rconType) (m' Map.! var) + let (newcs, errs) = foldMap (reduce' . unsplitConstr smallrhs) (dedupRCs (toList otherrhss)) + (fmap constrUnequal errs, ()) -- write the errors let m'' = Map.unionWith (<>) - (Map.map (fmap (subst var smallrhs)) (Map.delete var m')) - (Map.fromListWith (<>) (map (second pure) (toList newcs))) + (Map.map (fmap @Bag (fmap @RConstr (subst var smallrhs))) + (Map.delete var m')) + (Map.fromListWith (<>) (map (second pure . splitConstr) (toList newcs))) loop m'' ((var, smallrhs) : eqlog) applyLog :: [(v, t)] -> Map v t -> Map v t applyLog ((v, t) : l) m = applyLog l $ Map.insert v t (Map.map (subst v t) m) applyLog [] m = m - dedup :: Ord t => [t] -> [t] - dedup = uniq . sort + -- If there are multiple sources for the same cosntraint, only one of them is kept. + dedupRCs :: Ord t => [RConstr t] -> [RConstr t] + dedupRCs = map head . groupBy ((==) `on` rconType) . sortBy (comparing rconType) - uniq :: Eq a => [a] -> [a] - uniq (x:y:xs) | x == y = uniq (x : xs) - | otherwise = x : uniq (y : xs) - uniq l = l +minimumByMay :: Foldable t' => (a -> a -> Ordering) -> t' a -> Maybe a +minimumByMay cmp = foldl' min' Nothing + where min' mx y = Just $! case mx of + Nothing -> y + Just x | GT <- cmp x y -> y + | otherwise -> x - minimumByMay :: Foldable t' => (a -> a -> Ordering) -> t' a -> Maybe a - minimumByMay cmp = foldl' min' Nothing - where min' mx y = Just $! case mx of - Nothing -> y - Just x | GT <- cmp x y -> y - | otherwise -> x +uncurry3 :: (a -> b -> c -> d) -> (a, b, c) -> d +uncurry3 f (x, y, z) = f x y z |