aboutsummaryrefslogtreecommitdiff
path: root/src/HSVIS/Typecheck/Solve.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/HSVIS/Typecheck/Solve.hs')
-rw-r--r--src/HSVIS/Typecheck/Solve.hs131
1 files changed, 88 insertions, 43 deletions
diff --git a/src/HSVIS/Typecheck/Solve.hs b/src/HSVIS/Typecheck/Solve.hs
index 184937c..5f51abe 100644
--- a/src/HSVIS/Typecheck/Solve.hs
+++ b/src/HSVIS/Typecheck/Solve.hs
@@ -1,10 +1,15 @@
+{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE ScopedTypeVariables #-}
-module HSVIS.Typecheck.Solve where
+{-# LANGUAGE TypeApplications #-}
+module HSVIS.Typecheck.Solve (
+ solveConstraints,
+ UnifyErr(..),
+) where
import Control.Monad (guard, (>=>))
-import Data.Bifunctor (second)
+import Data.Bifunctor (Bifunctor(..))
import Data.Foldable (toList, foldl')
-import Data.List (sort)
+import Data.List (sortBy, minimumBy, groupBy)
import Data.Ord (comparing)
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
@@ -12,26 +17,53 @@ import qualified Data.Map.Strict as Map
import Debug.Trace
import Data.Bag
-import Data.List (minimumBy)
+import HSVIS.Diagnostic (Range(..))
+import HSVIS.Pretty
+import Data.Function (on)
data UnifyErr v t
- = UEUnequal t t
- | UERecursive v t
+ = UEUnequal t t Range
+ | UERecursive v t Range
deriving (Show)
+data Constr a b = Constr a b Range
+ deriving (Show)
+
+instance Bifunctor Constr where
+ bimap f g (Constr x y r) = Constr (f x) (g y) r
+
+-- right-hand side of a constraint
+data RConstr b = RConstr b Range
+ deriving (Show, Functor)
+
+splitConstr :: Constr a b -> (a, RConstr b)
+splitConstr (Constr x y r) = (x, RConstr y r)
+
+unsplitConstr :: a -> RConstr b -> Constr a b
+unsplitConstr x (RConstr y r) = Constr x y r
+
+constrUnequal :: Constr t t -> UnifyErr v t
+constrUnequal (Constr x y r) = UEUnequal x y r
+
+constrRecursive :: Constr v t -> UnifyErr v t
+constrRecursive (Constr x y r) = UERecursive x y r
+
+rconType :: RConstr b -> b
+rconType (RConstr t _) = t
+
-- | Returns a pair of:
-- 1. A set of unification errors;
-- 2. An assignment of the variables that had any constraints on them.
-- The producedure was successful if the set of errors is empty. Note that
-- unconstrained variables do not appear in the output.
solveConstraints
- :: forall v t. (Ord v, Ord t, Show v, Show t)
+ :: forall v t. (Ord v, Ord t, Show v, Pretty t)
-- | reduce: take two types and unify them, resulting in:
-- 1. A bag of resulting constraints on variables;
-- 2. A bag of errors: pairs of two types that are provably distinct but
-- need to be equal for the input types to unify.
- => (t -> t -> (Bag (v, t), Bag (t, t)))
+ => (t -> t -> Range -> (Bag (v, t, Range), Bag (t, t, Range)))
-- | Free variables in a type
-> (t -> Bag v)
-- | \v repl term -> Substitute v by repl in term
@@ -41,63 +73,76 @@ solveConstraints
-- | Some kind of size measure on types
-> (t -> Int)
-- | Equality constraints to solve
- -> [(t, t)]
+ -> [(t, t, Range)]
-> (Map v t, Bag (UnifyErr v t))
-solveConstraints reduce frees subst detect size = \cs ->
- let (vcs, errs) = foldMap (uncurry reduce) cs
- asg = Map.fromListWith (<>) (map (second pure) (toList vcs))
+solveConstraints reduce frees subst detect size = \tupcs ->
+ let cs = map (uncurry3 Constr) tupcs :: [Constr t t]
+ (vcs, errs) = foldMap reduce' cs
+ asg = Map.fromListWith (<>) (map (second pure . splitConstr) (toList vcs))
(errs', asg') = loop asg []
- errs'' = fmap (uncurry UEUnequal) errs <> errs'
- in trace ("[solver] Solving:" ++ concat ["\n- " ++ show a ++ " == " ++ show b | (a, b) <- cs]) $
+ errs'' = fmap constrUnequal errs <> errs'
+ in trace ("[solver] Solving:" ++ concat ["\n- " ++ pretty a ++ " == " ++ pretty b ++ " {" ++ pretty r ++ "}" | Constr a b r <- cs]) $
trace ("[solver] Result: (with " ++ show (length errs'') ++ " errors)" ++
- concat ["\n- " ++ show v ++ " = " ++ show t | (v, t) <- Map.assocs asg'])
+ concat ["\n- " ++ show v ++ " = " ++ pretty t | (v, t) <- Map.assocs asg'])
(asg', errs'')
where
- loop :: Map v (Bag t) -> [(v, t)] -> (Bag (UnifyErr v t), Map v t)
+ reduce' :: Constr t t -> (Bag (Constr v t), Bag (Constr t t))
+ reduce' (Constr t1 t2 r) = bimap (fmap (uncurry3 Constr)) (fmap (uncurry3 Constr)) $ reduce t1 t2 r
+
+ loop :: Map v (Bag (RConstr t)) -> [(v, t)] -> (Bag (UnifyErr v t), Map v t)
loop m eqlog = do
- traceM $ "[solver] Step:" ++ concat ["\n- " ++ show v ++ " == " ++ show t | (v, t) <- Map.assocs m]
+ traceM $ "[solver] Step:" ++ concat
+ [case toList rhss of
+ [] -> "\n- " ++ show v ++ " <no RHSs>"
+ RConstr t r : rest ->
+ "\n- " ++ show v ++ " == " ++ pretty t ++ " {" ++ pretty r ++ "}" ++
+ concat ["\n " ++ replicate (length (show v)) ' ' ++ " == " ++ pretty t' ++ " {" ++ pretty r' ++ "}"
+ | RConstr t' r' <- rest]
+ | (v, rhss) <- Map.assocs m]
+
m' <- Map.traverseWithKey
- (\v ts ->
- let ts' = bagFromList (dedup (toList ts))
+ (\v rhss ->
+ let rhss' = bagFromList (dedupRCs (toList rhss))
-- filter out recursive equations
- (recs, nonrecs) = bagPartition (\t -> if v `elem` frees t then Just t else Nothing) ts'
+ (recs, nonrecs) = bagPartition (\c@(RConstr t _) -> if v `elem` frees t then Just c else Nothing) rhss'
-- filter out trivial equations (v = v)
- (_, nonrecs') = bagPartition (detect >=> guard . (== v)) nonrecs
- in (UERecursive v <$> recs, nonrecs'))
+ (_, nonrecs') = bagPartition (detect . rconType >=> guard . (== v)) nonrecs
+ in (constrRecursive . unsplitConstr v <$> recs, nonrecs'))
m
- let msmallestvar = -- var with its smallest RHS, if such a var exists
- minimumByMay (comparing (size . snd))
- . map (second (minimumBy (comparing size)))
+ let msmallestvar :: Maybe (v, RConstr t) -- var with its smallest RHS, if such a var exists
+ msmallestvar =
+ minimumByMay (comparing (size . rconType . snd))
+ . map (second (minimumBy (comparing (size . rconType))))
. filter (not . null . snd)
$ Map.assocs m'
case msmallestvar of
Nothing -> return $ applyLog eqlog mempty
- Just (var, smallrhs) -> do
- let (_, otherrhss) = bagPartition (guard . (== smallrhs)) (m' Map.! var)
- let (newcs, errs) = foldMap (reduce smallrhs) (dedup (toList otherrhss))
- (fmap (uncurry UEUnequal) errs, ()) -- write the errors
+ Just (var, RConstr smallrhs _) -> do
+ let (_, otherrhss) = bagPartition (guard . (== smallrhs) . rconType) (m' Map.! var)
+ let (newcs, errs) = foldMap (reduce' . unsplitConstr smallrhs) (dedupRCs (toList otherrhss))
+ (fmap constrUnequal errs, ()) -- write the errors
let m'' = Map.unionWith (<>)
- (Map.map (fmap (subst var smallrhs)) (Map.delete var m'))
- (Map.fromListWith (<>) (map (second pure) (toList newcs)))
+ (Map.map (fmap @Bag (fmap @RConstr (subst var smallrhs)))
+ (Map.delete var m'))
+ (Map.fromListWith (<>) (map (second pure . splitConstr) (toList newcs)))
loop m'' ((var, smallrhs) : eqlog)
applyLog :: [(v, t)] -> Map v t -> Map v t
applyLog ((v, t) : l) m = applyLog l $ Map.insert v t (Map.map (subst v t) m)
applyLog [] m = m
- dedup :: Ord t => [t] -> [t]
- dedup = uniq . sort
+ -- If there are multiple sources for the same cosntraint, only one of them is kept.
+ dedupRCs :: Ord t => [RConstr t] -> [RConstr t]
+ dedupRCs = map head . groupBy ((==) `on` rconType) . sortBy (comparing rconType)
- uniq :: Eq a => [a] -> [a]
- uniq (x:y:xs) | x == y = uniq (x : xs)
- | otherwise = x : uniq (y : xs)
- uniq l = l
+minimumByMay :: Foldable t' => (a -> a -> Ordering) -> t' a -> Maybe a
+minimumByMay cmp = foldl' min' Nothing
+ where min' mx y = Just $! case mx of
+ Nothing -> y
+ Just x | GT <- cmp x y -> y
+ | otherwise -> x
- minimumByMay :: Foldable t' => (a -> a -> Ordering) -> t' a -> Maybe a
- minimumByMay cmp = foldl' min' Nothing
- where min' mx y = Just $! case mx of
- Nothing -> y
- Just x | GT <- cmp x y -> y
- | otherwise -> x
+uncurry3 :: (a -> b -> c -> d) -> (a, b, c) -> d
+uncurry3 f (x, y, z) = f x y z