{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
module HSVIS.Typecheck.Solve (
  solveConstraints,
  UnifyErr(..),
) where

import Control.Monad (guard, (>=>))
import Data.Bifunctor (Bifunctor(..))
import Data.Foldable (toList, foldl')
import Data.List (sortBy, minimumBy, groupBy)
import Data.Ord (comparing)
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import Debug.Trace

import Data.Bag
import HSVIS.Diagnostic (Range(..))
import HSVIS.Pretty
import Data.Function (on)


-- | An error found while solving equality constraints.
data UnifyErr v t
  = UEUnequal t t Range
    -- ^ The two types must be equal, but were found to be provably distinct.
  | UERecursive v t Range
    -- ^ The variable must equal a type in which it itself occurs free.
  deriving (Show)

-- | An equality constraint @a == b@, together with the source range it
-- originates from.
data Constr a b = Constr a b Range
  deriving (Show)

instance Bifunctor Constr where
  bimap f g (Constr x y r) = Constr (f x) (g y) r

-- | Right-hand side of a constraint: the type plus the source range of the
-- constraint it belongs to.
data RConstr b = RConstr b Range
  deriving (Show, Functor)

-- | Split a constraint into its left-hand side and its right-hand side.
splitConstr :: Constr a b -> (a, RConstr b)
splitConstr (Constr x y r) = (x, RConstr y r)

-- | Inverse of 'splitConstr'.
unsplitConstr :: a -> RConstr b -> Constr a b
unsplitConstr x (RConstr y r) = Constr x y r

-- | Report a constraint between two types as an inequality error.
constrUnequal :: Constr t t -> UnifyErr v t
constrUnequal (Constr x y r) = UEUnequal x y r

-- | Report a variable constraint as a recursive-equation error.
constrRecursive :: Constr v t -> UnifyErr v t
constrRecursive (Constr x y r) = UERecursive x y r

-- | The type on the right-hand side of a constraint.
rconType :: RConstr b -> b
rconType (RConstr t _) = t

-- | Solve a set of equality constraints between types.
--
-- Returns a pair of:
--
-- 1. An assignment for the variables the solver eliminated. Unconstrained
--    variables do not appear in it, and neither do variables whose only
--    constraints were trivial (@v == v@) or recursive. Each right-hand side
--    has all other assigned variables substituted away.
-- 2. A bag of unification errors.
--
-- The procedure was successful if the bag of errors is empty.
--
-- The solver repeatedly picks the variable with the smallest right-hand side
-- (according to the size measure). It fixes that variable to that right-hand
-- side and unifies the variable's other right-hand sides with it using
-- @reduce@. It then substitutes the variable away in all remaining
-- constraints.
solveConstraints
  :: forall v t. (Ord v, Ord t, Show v, Pretty t)
  -- | reduce: take two types and unify them, resulting in:
  --   1. A bag of resulting constraints on variables;
  --   2. A bag of errors: pairs of two types that are provably distinct but
  --      need to be equal for the input types to unify.
  => (t -> t -> Range -> (Bag (v, t, Range), Bag (t, t, Range)))
  -- | Free variables in a type
  -> (t -> Bag v)
  -- | \v repl term -> Substitute v by repl in term
  -> (v -> t -> t -> t)
  -- | Detect bare-variable types
  -> (t -> Maybe v)
  -- | Some kind of size measure on types
  -> (t -> Int)
  -- | Equality constraints to solve
  -> [(t, t, Range)]
  -> (Map v t, Bag (UnifyErr v t))
solveConstraints reduce frees subst detect size = \tupcs ->
  let cs = map (uncurry3 Constr) tupcs :: [Constr t t]
      (vcs, errs) = foldMap reduce' cs
      asg = Map.fromListWith (<>) (map (second pure . splitConstr) (toList vcs))
      (errs', asg') = loop asg []
      errs'' = fmap constrUnequal errs <> errs'
  in trace ("[solver] Solving:"
            ++ concat ["\n- " ++ pretty a ++ " == " ++ pretty b ++ " {" ++ pretty r ++ "}"
                      | Constr a b r <- cs]) $
     trace ("[solver] Result: (with " ++ show (length errs'') ++ " errors)"
            ++ concat ["\n- " ++ show v ++ " = " ++ pretty t
                      | (v, t) <- Map.assocs asg']) $
       (asg', errs'')
  where
    -- Run 'reduce' on a constraint and repackage both of its outputs as
    -- constraints.
    reduce' :: Constr t t -> (Bag (Constr v t), Bag (Constr t t))
    reduce' (Constr t1 t2 r) =
      bimap (fmap (uncurry3 Constr)) (fmap (uncurry3 Constr)) $ reduce t1 t2 r

    -- Main solver loop. It runs in the writer-like tuple monad
    -- @(,) (Bag (UnifyErr v t))@, which accumulates errors.
    --
    -- @m@ maps each not-yet-eliminated variable to its right-hand sides.
    -- @eqlog@ records the eliminated variables with their chosen values,
    -- most recently eliminated first.
    loop :: Map v (Bag (RConstr t)) -> [(v, t)] -> (Bag (UnifyErr v t), Map v t)
    loop m eqlog = do
      traceM $ "[solver] Step:" ++ concat
        [ case toList rhss of
            [] -> "\n- " ++ show v ++ " "
            RConstr t r : rest ->
              "\n- " ++ show v ++ " == " ++ pretty t ++ " {" ++ pretty r ++ "}"
              ++ concat ["\n " ++ replicate (length (show v)) ' ' ++ " == " ++ pretty t' ++ " {" ++ pretty r' ++ "}"
                        | RConstr t' r' <- rest]
        | (v, rhss) <- Map.assocs m ]

      -- Clean up every variable's right-hand sides. Recursive equations are
      -- reported as errors.
      m' <- Map.traverseWithKey
              (\v rhss ->
                 let rhss' = bagFromList (dedupRCs (toList rhss))
                     -- Filter out trivial equations (v = v) FIRST. Assuming
                     -- 'frees' reports v as free in the bare variable type v
                     -- (as a free-variable function should), the recursion
                     -- check below would otherwise wrongly report a trivial
                     -- equation as recursive.
                     (_, nontrivs) = bagPartition (detect . rconType >=> guard . (== v)) rhss'
                     -- Then filter out recursive equations (v occurs in the
                     -- right-hand side); these become errors.
                     (recs, nonrecs) =
                       bagPartition (\c@(RConstr t _) -> if v `elem` frees t then Just c else Nothing)
                                    nontrivs
                 in (constrRecursive . unsplitConstr v <$> recs, nonrecs))
              m

      -- The variable with its smallest right-hand side, if any variable still
      -- has a right-hand side left.
      let msmallestvar :: Maybe (v, RConstr t)
          msmallestvar =
            minimumByMay (comparing (size . rconType . snd))
              . map (second (minimumBy (comparing (size . rconType))))
              . filter (not . null . snd)
              $ Map.assocs m'

      case msmallestvar of
        Nothing ->
          -- eqlog is newest-first, but applyLog must see it oldest-first
          -- (see applyLog).
          return $ applyLog (reverse eqlog) mempty
        Just (var, RConstr smallrhs _) -> do
          -- Unify every other right-hand side of var with the chosen one.
          let (_, otherrhss) = bagPartition (guard . (== smallrhs) . rconType) (m' Map.! var)
          let (newcs, errs) = foldMap (reduce' . unsplitConstr smallrhs) (dedupRCs (toList otherrhss))
          (fmap constrUnequal errs, ())  -- write the errors
          -- Eliminate var: substitute it away in all remaining right-hand
          -- sides, and add the variable constraints that reduce produced.
          let m'' = Map.unionWith (<>)
                      (Map.map (fmap @Bag (fmap @RConstr (subst var smallrhs))) (Map.delete var m'))
                      (Map.fromListWith (<>) (map (second pure . splitConstr) (toList newcs)))
          loop m'' ((var, smallrhs) : eqlog)

    -- Turn the log of eliminated variables into the final assignment. The
    -- log must be given OLDEST-FIRST.
    --
    -- Every value chosen later has all earlier-eliminated variables already
    -- substituted out. Earlier values, however, may still mention variables
    -- eliminated later. Processing oldest-first, and substituting each new
    -- equation into the values collected so far, therefore gives right-hand
    -- sides that are fully resolved with respect to the assigned variables.
    applyLog :: [(v, t)] -> Map v t -> Map v t
    applyLog ((v, t) : l) m = applyLog l $ Map.insert v t (Map.map (subst v t) m)
    applyLog [] m = m

-- | If there are multiple sources for the same constraint, only one of them is
-- kept. 'sortBy' is stable, so the kept source is the first one in input order.
-- The result is sorted by type.
dedupRCs :: Ord t => [RConstr t] -> [RConstr t]
dedupRCs = concatMap (take 1) . groupBy ((==) `on` rconType) . sortBy (comparing rconType)

-- | Like 'minimumBy', but returns 'Nothing' for an empty container instead of
-- failing. Among equal minima, the first one encountered is returned.
minimumByMay :: Foldable t' => (a -> a -> Ordering) -> t' a -> Maybe a
minimumByMay cmp = foldl' min' Nothing
  where
    min' mx y = Just $! case mx of
      Nothing -> y
      Just x | GT <- cmp x y -> y
             | otherwise -> x

-- | Uncurry a function of three arguments.
uncurry3 :: (a -> b -> c -> d) -> (a, b, c) -> d
uncurry3 f (x, y, z) = f x y z