From cc61cdc000481f9dc88253342c328bdb99d048a4 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Sun, 17 Mar 2024 23:08:38 +0100 Subject: Typecheck work; solver is incorrect --- src/HSVIS/Typecheck/Solve.hs | 131 +++++++++++++++++++++++++++++-------------- 1 file changed, 88 insertions(+), 43 deletions(-) (limited to 'src/HSVIS/Typecheck/Solve.hs') diff --git a/src/HSVIS/Typecheck/Solve.hs b/src/HSVIS/Typecheck/Solve.hs index 184937c..5f51abe 100644 --- a/src/HSVIS/Typecheck/Solve.hs +++ b/src/HSVIS/Typecheck/Solve.hs @@ -1,10 +1,15 @@ +{-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE ScopedTypeVariables #-} -module HSVIS.Typecheck.Solve where +{-# LANGUAGE TypeApplications #-} +module HSVIS.Typecheck.Solve ( + solveConstraints, + UnifyErr(..), +) where import Control.Monad (guard, (>=>)) -import Data.Bifunctor (second) +import Data.Bifunctor (Bifunctor(..)) import Data.Foldable (toList, foldl') -import Data.List (sort) +import Data.List (sortBy, minimumBy, groupBy) import Data.Ord (comparing) import Data.Map.Strict (Map) import qualified Data.Map.Strict as Map @@ -12,26 +17,53 @@ import qualified Data.Map.Strict as Map import Debug.Trace import Data.Bag -import Data.List (minimumBy) +import HSVIS.Diagnostic (Range(..)) +import HSVIS.Pretty +import Data.Function (on) data UnifyErr v t - = UEUnequal t t - | UERecursive v t + = UEUnequal t t Range + | UERecursive v t Range deriving (Show) +data Constr a b = Constr a b Range + deriving (Show) + +instance Bifunctor Constr where + bimap f g (Constr x y r) = Constr (f x) (g y) r + +-- right-hand side of a constraint +data RConstr b = RConstr b Range + deriving (Show, Functor) + +splitConstr :: Constr a b -> (a, RConstr b) +splitConstr (Constr x y r) = (x, RConstr y r) + +unsplitConstr :: a -> RConstr b -> Constr a b +unsplitConstr x (RConstr y r) = Constr x y r + +constrUnequal :: Constr t t -> UnifyErr v t +constrUnequal (Constr x y r) = UEUnequal x y r + +constrRecursive :: Constr v t -> UnifyErr v t +constrRecursive (Constr x y r) = UERecursive x y r + +rconType :: RConstr b -> b +rconType (RConstr t _) = t + -- | Returns a pair of: -- 1. A set of unification errors; -- 2. An assignment of the variables that had any constraints on them. -- The producedure was successful if the set of errors is empty. Note that -- unconstrained variables do not appear in the output. solveConstraints - :: forall v t. (Ord v, Ord t, Show v, Show t) + :: forall v t. (Ord v, Ord t, Show v, Pretty t) -- | reduce: take two types and unify them, resulting in: -- 1. A bag of resulting constraints on variables; -- 2. A bag of errors: pairs of two types that are provably distinct but -- need to be equal for the input types to unify. - => (t -> t -> (Bag (v, t), Bag (t, t))) + => (t -> t -> Range -> (Bag (v, t, Range), Bag (t, t, Range))) -- | Free variables in a type -> (t -> Bag v) -- | \v repl term -> Substitute v by repl in term @@ -41,63 +73,76 @@ solveConstraints -- | Some kind of size measure on types -> (t -> Int) -- | Equality constraints to solve - -> [(t, t)] + -> [(t, t, Range)] -> (Map v t, Bag (UnifyErr v t)) -solveConstraints reduce frees subst detect size = \cs -> - let (vcs, errs) = foldMap (uncurry reduce) cs - asg = Map.fromListWith (<>) (map (second pure) (toList vcs)) +solveConstraints reduce frees subst detect size = \tupcs -> + let cs = map (uncurry3 Constr) tupcs :: [Constr t t] + (vcs, errs) = foldMap reduce' cs + asg = Map.fromListWith (<>) (map (second pure . splitConstr) (toList vcs)) (errs', asg') = loop asg [] - errs'' = fmap (uncurry UEUnequal) errs <> errs' - in trace ("[solver] Solving:" ++ concat ["\n- " ++ show a ++ " == " ++ show b | (a, b) <- cs]) $ + errs'' = fmap constrUnequal errs <> errs' + in trace ("[solver] Solving:" ++ concat ["\n- " ++ pretty a ++ " == " ++ pretty b ++ " {" ++ pretty r ++ "}" | Constr a b r <- cs]) $ trace ("[solver] Result: (with " ++ show (length errs'') ++ " errors)" ++ - concat ["\n- " ++ show v ++ " = " ++ show t | (v, t) <- Map.assocs asg']) + concat ["\n- " ++ show v ++ " = " ++ pretty t | (v, t) <- Map.assocs asg']) (asg', errs'') where - loop :: Map v (Bag t) -> [(v, t)] -> (Bag (UnifyErr v t), Map v t) + reduce' :: Constr t t -> (Bag (Constr v t), Bag (Constr t t)) + reduce' (Constr t1 t2 r) = bimap (fmap (uncurry3 Constr)) (fmap (uncurry3 Constr)) $ reduce t1 t2 r + + loop :: Map v (Bag (RConstr t)) -> [(v, t)] -> (Bag (UnifyErr v t), Map v t) loop m eqlog = do - traceM $ "[solver] Step:" ++ concat ["\n- " ++ show v ++ " == " ++ show t | (v, t) <- Map.assocs m] + traceM $ "[solver] Step:" ++ concat + [case toList rhss of + [] -> "\n- " ++ show v ++ " " + RConstr t r : rest -> + "\n- " ++ show v ++ " == " ++ pretty t ++ " {" ++ pretty r ++ "}" ++ + concat ["\n " ++ replicate (length (show v)) ' ' ++ " == " ++ pretty t' ++ " {" ++ pretty r' ++ "}" + | RConstr t' r' <- rest] + | (v, rhss) <- Map.assocs m] + m' <- Map.traverseWithKey - (\v ts -> - let ts' = bagFromList (dedup (toList ts)) + (\v rhss -> + let rhss' = bagFromList (dedupRCs (toList rhss)) -- filter out recursive equations - (recs, nonrecs) = bagPartition (\t -> if v `elem` frees t then Just t else Nothing) ts' + (recs, nonrecs) = bagPartition (\c@(RConstr t _) -> if v `elem` frees t then Just c else Nothing) rhss' -- filter out trivial equations (v = v) - (_, nonrecs') = bagPartition (detect >=> guard . (== v)) nonrecs - in (UERecursive v <$> recs, nonrecs')) + (_, nonrecs') = bagPartition (detect . rconType >=> guard . (== v)) nonrecs + in (constrRecursive . unsplitConstr v <$> recs, nonrecs')) m - let msmallestvar = -- var with its smallest RHS, if such a var exists - minimumByMay (comparing (size . snd)) - . map (second (minimumBy (comparing size))) + let msmallestvar :: Maybe (v, RConstr t) -- var with its smallest RHS, if such a var exists + msmallestvar = + minimumByMay (comparing (size . rconType . snd)) + . map (second (minimumBy (comparing (size . rconType)))) . filter (not . null . snd) $ Map.assocs m' case msmallestvar of Nothing -> return $ applyLog eqlog mempty - Just (var, smallrhs) -> do - let (_, otherrhss) = bagPartition (guard . (== smallrhs)) (m' Map.! var) - let (newcs, errs) = foldMap (reduce smallrhs) (dedup (toList otherrhss)) - (fmap (uncurry UEUnequal) errs, ()) -- write the errors + Just (var, RConstr smallrhs _) -> do + let (_, otherrhss) = bagPartition (guard . (== smallrhs) . rconType) (m' Map.! var) + let (newcs, errs) = foldMap (reduce' . unsplitConstr smallrhs) (dedupRCs (toList otherrhss)) + (fmap constrUnequal errs, ()) -- write the errors let m'' = Map.unionWith (<>) - (Map.map (fmap (subst var smallrhs)) (Map.delete var m')) - (Map.fromListWith (<>) (map (second pure) (toList newcs))) + (Map.map (fmap @Bag (fmap @RConstr (subst var smallrhs))) + (Map.delete var m')) + (Map.fromListWith (<>) (map (second pure . splitConstr) (toList newcs))) loop m'' ((var, smallrhs) : eqlog) applyLog :: [(v, t)] -> Map v t -> Map v t applyLog ((v, t) : l) m = applyLog l $ Map.insert v t (Map.map (subst v t) m) applyLog [] m = m - dedup :: Ord t => [t] -> [t] - dedup = uniq . sort + -- If there are multiple sources for the same cosntraint, only one of them is kept. + dedupRCs :: Ord t => [RConstr t] -> [RConstr t] + dedupRCs = map head . groupBy ((==) `on` rconType) . sortBy (comparing rconType) - uniq :: Eq a => [a] -> [a] - uniq (x:y:xs) | x == y = uniq (x : xs) - | otherwise = x : uniq (y : xs) - uniq l = l +minimumByMay :: Foldable t' => (a -> a -> Ordering) -> t' a -> Maybe a +minimumByMay cmp = foldl' min' Nothing + where min' mx y = Just $! case mx of + Nothing -> y + Just x | GT <- cmp x y -> y + | otherwise -> x - minimumByMay :: Foldable t' => (a -> a -> Ordering) -> t' a -> Maybe a - minimumByMay cmp = foldl' min' Nothing - where min' mx y = Just $! case mx of - Nothing -> y - Just x | GT <- cmp x y -> y - | otherwise -> x +uncurry3 :: (a -> b -> c -> d) -> (a, b, c) -> d +uncurry3 f (x, y, z) = f x y z -- cgit v1.2.3-70-g09d2