{-# LANGUAGE ScopedTypeVariables #-}
module HSVIS.Typecheck.Solve where

import Control.Monad (guard, (>=>))
import Data.Bifunctor (second)
import Data.Foldable (toList, foldl')
import Data.List (minimumBy, sort)
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import Data.Ord (comparing)
import Debug.Trace

import Data.Bag


-- | A reason why a set of equality constraints has no solution.
data UnifyErr v t
  = UEUnequal t t
    -- ^ Two types that need to be equal but that @reduce@ reported as
    -- provably distinct.
  | UERecursive v t
    -- ^ A variable that would have to equal a type in which it occurs.
  deriving (Show)

-- | Solve a list of equality constraints between types. The solver first
-- reduces every constraint to constraints on variables. It then repeatedly
-- eliminates the variable whose smallest right-hand side is the smallest
-- overall.
--
-- Returns a pair of:
--
-- 1. An assignment of the variables that were solved. Unconstrained
--    variables (and variables whose only constraint was the trivial
--    @v = v@) do not appear in it. Solutions found later are substituted
--    into the ones found earlier, so solved variables do not remain free in
--    the assigned types. This assumes @reduce@ only constrains variables
--    occurring in its arguments.
-- 2. A bag of unification errors.
--
-- The procedure was successful if the bag of errors is empty.
solveConstraints
  :: forall v t. (Ord v, Ord t, Show v, Show t)
  => (t -> t -> (Bag (v, t), Bag (t, t)))
     -- ^ reduce: take two types and unify them, resulting in:
     --
     --   1. A bag of resulting constraints on variables;
     --   2. A bag of errors: pairs of two types that are provably distinct
     --      but need to be equal for the input types to unify.
  -> (t -> Bag v)
     -- ^ Free variables in a type.
  -> (v -> t -> t -> t)
     -- ^ @\\v repl term@: substitute @v@ by @repl@ in @term@.
  -> (t -> Maybe v)
     -- ^ Detect bare-variable types.
  -> (t -> Int)
     -- ^ Some kind of size measure on types.
  -> [(t, t)]
     -- ^ Equality constraints to solve.
  -> (Map v t, Bag (UnifyErr v t))
solveConstraints reduce frees subst detect size = \cs ->
  let (vcs, errs) = foldMap (uncurry reduce) cs
      asg = Map.fromListWith (<>) (map (second pure) (toList vcs))
      (errs', asg') = loop asg []
      errs'' = fmap (uncurry UEUnequal) errs <> errs'
  in trace ("[solver] Solving:" ++ concat ["\n- " ++ show a ++ " == " ++ show b | (a, b) <- cs]) $
     trace ("[solver] Result: (with " ++ show (length errs'') ++ " errors)"
            ++ concat ["\n- " ++ show v ++ " = " ++ show t | (v, t) <- Map.assocs asg'])
       (asg', errs'')
  where
    -- | Perform one elimination step per call:
    --
    -- 1. Clean up every variable's bag of right-hand sides.
    -- 2. Pick the variable with the smallest RHS.
    -- 3. Unify its other RHSs with that one.
    -- 4. Substitute the variable away in the rest of the map.
    -- 5. Recurse.
    --
    -- @eqlog@ records the eliminated variables with their RHS, newest first.
    -- The result lives in the writer monad @(,) (Bag (UnifyErr v t))@, so
    -- errors are accumulated in the first component.
    loop :: Map v (Bag t) -> [(v, t)] -> (Bag (UnifyErr v t), Map v t)
    loop m eqlog = do
      traceM $ "[solver] Step:"
               ++ concat ["\n- " ++ show v ++ " == " ++ show t | (v, t) <- Map.assocs m]
      m' <- Map.traverseWithKey cleanRhss m
      let msmallestvar =
            -- var with its smallest RHS, if such a var exists
            minimumByMay (comparing (size . snd))
              . map (second (minimumBy (comparing size)))
              . filter (not . null . snd)
              $ Map.assocs m'
      case msmallestvar of
        Nothing -> return $ applyLog eqlog mempty
        Just (var, smallrhs) -> do
          let (_, otherrhss) = bagPartition (guard . (== smallrhs)) (m' Map.! var)
              (newcs, errs) = foldMap (reduce smallrhs) (dedup (toList otherrhss))
          (fmap (uncurry UEUnequal) errs, ())  -- write the errors
          let m'' = Map.unionWith (<>)
                      (Map.map (fmap (subst var smallrhs)) (Map.delete var m'))
                      (Map.fromListWith (<>) (map (second pure) (toList newcs)))
          loop m'' ((var, smallrhs) : eqlog)

    -- | Clean up the RHSs of one variable @v@ in three steps:
    --
    -- 1. Deduplicate them.
    -- 2. Drop trivial equations @v = v@.
    -- 3. Split off recursive equations (where @v@ occurs in the RHS) as
    --    errors.
    cleanRhss :: v -> Bag t -> (Bag (UnifyErr v t), Bag t)
    cleanRhss v ts =
      let ts' = bagFromList (dedup (toList ts))
          -- Trivial equations must be removed before the recursion check.
          -- The free variables of a bare @v@ are expected to include @v@,
          -- so otherwise @v = v@ would be reported as a recursive equation.
          (_, nontrivial) = bagPartition (detect >=> guard . (== v)) ts'
          (recs, nonrecs) =
            bagPartition (\t -> if v `elem` frees t then Just t else Nothing) nontrivial
      in (UERecursive v <$> recs, nonrecs)

    -- | Turn the elimination log (newest first) into the final assignment.
    --
    -- The log must be replayed oldest-first. The RHS recorded at elimination
    -- step k has every variable eliminated before step k substituted out.
    -- It may still mention variables eliminated after step k. 'foldr'
    -- applies the last list element (the oldest entry) first. Each later
    -- variable is then substituted into all earlier RHSs before being
    -- inserted itself.
    applyLog :: [(v, t)] -> Map v t -> Map v t
    applyLog eqs m0 = foldr record m0 eqs
      where
        record (v, t) m = Map.insert v t (Map.map (subst v t) m)

    -- | Sort a list and remove duplicates.
    dedup :: Ord a => [a] -> [a]
    dedup = uniq . sort

    -- | Collapse each run of equal adjacent elements into one element.
    uniq :: Eq a => [a] -> [a]
    uniq (x : y : xs)
      | x == y = uniq (x : xs)
      | otherwise = x : uniq (y : xs)
    uniq l = l

    -- | Total variant of 'minimumBy'. Returns 'Nothing' on an empty
    -- container. On ties, the element encountered first by the left fold
    -- wins.
    minimumByMay :: Foldable f => (a -> a -> Ordering) -> f a -> Maybe a
    minimumByMay cmp = foldl' min' Nothing
      where
        min' mx y = Just $! case mx of
          Nothing -> y
          Just x
            | GT <- cmp x y -> y
            | otherwise -> x