diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-03-14 23:21:53 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-03-14 23:21:53 +0100 |
commit | e7bed242ba52e6d3233928f2c6189e701cfa5e4c (patch) | |
tree | 4bdda2b7bc702c87d97f89946362e6b719126831 /src/HSVIS/Typecheck/Solve.hs | |
parent | e8f09ff3f9d40922238d646c8fbcbacf9cfdfb62 (diff) |
Some typechecker work
Diffstat (limited to 'src/HSVIS/Typecheck/Solve.hs')
-rw-r--r-- | src/HSVIS/Typecheck/Solve.hs | 103 |
1 files changed, 103 insertions, 0 deletions
diff --git a/src/HSVIS/Typecheck/Solve.hs b/src/HSVIS/Typecheck/Solve.hs new file mode 100644 index 0000000..184937c --- /dev/null +++ b/src/HSVIS/Typecheck/Solve.hs @@ -0,0 +1,103 @@ +{-# LANGUAGE ScopedTypeVariables #-} +module HSVIS.Typecheck.Solve where + +import Control.Monad (guard, (>=>)) +import Data.Bifunctor (second) +import Data.Foldable (toList, foldl') +import Data.List (sort) +import Data.Ord (comparing) +import Data.Map.Strict (Map) +import qualified Data.Map.Strict as Map + +import Debug.Trace + +import Data.Bag +import Data.List (minimumBy) + + +data UnifyErr v t + = UEUnequal t t + | UERecursive v t + deriving (Show) + +-- | Returns a pair of: +-- 1. A set of unification errors; +-- 2. An assignment of the variables that had any constraints on them. +-- The producedure was successful if the set of errors is empty. Note that +-- unconstrained variables do not appear in the output. +solveConstraints + :: forall v t. (Ord v, Ord t, Show v, Show t) + -- | reduce: take two types and unify them, resulting in: + -- 1. A bag of resulting constraints on variables; + -- 2. A bag of errors: pairs of two types that are provably distinct but + -- need to be equal for the input types to unify. + => (t -> t -> (Bag (v, t), Bag (t, t))) + -- | Free variables in a type + -> (t -> Bag v) + -- | \v repl term -> Substitute v by repl in term + -> (v -> t -> t -> t) + -- | Detect bare-variable types + -> (t -> Maybe v) + -- | Some kind of size measure on types + -> (t -> Int) + -- | Equality constraints to solve + -> [(t, t)] + -> (Map v t, Bag (UnifyErr v t)) +solveConstraints reduce frees subst detect size = \cs -> + let (vcs, errs) = foldMap (uncurry reduce) cs + asg = Map.fromListWith (<>) (map (second pure) (toList vcs)) + (errs', asg') = loop asg [] + errs'' = fmap (uncurry UEUnequal) errs <> errs' + in trace ("[solver] Solving:" ++ concat ["\n- " ++ show a ++ " == " ++ show b | (a, b) <- cs]) $ + trace ("[solver] Result: (with " ++ show (length errs'') ++ " errors)" ++ + concat ["\n- " ++ show v ++ " = " ++ show t | (v, t) <- Map.assocs asg']) + (asg', errs'') + where + loop :: Map v (Bag t) -> [(v, t)] -> (Bag (UnifyErr v t), Map v t) + loop m eqlog = do + traceM $ "[solver] Step:" ++ concat ["\n- " ++ show v ++ " == " ++ show t | (v, t) <- Map.assocs m] + m' <- Map.traverseWithKey + (\v ts -> + let ts' = bagFromList (dedup (toList ts)) + -- filter out recursive equations + (recs, nonrecs) = bagPartition (\t -> if v `elem` frees t then Just t else Nothing) ts' + -- filter out trivial equations (v = v) + (_, nonrecs') = bagPartition (detect >=> guard . (== v)) nonrecs + in (UERecursive v <$> recs, nonrecs')) + m + + let msmallestvar = -- var with its smallest RHS, if such a var exists + minimumByMay (comparing (size . snd)) + . map (second (minimumBy (comparing size))) + . filter (not . null . snd) + $ Map.assocs m' + + case msmallestvar of + Nothing -> return $ applyLog eqlog mempty + Just (var, smallrhs) -> do + let (_, otherrhss) = bagPartition (guard . (== smallrhs)) (m' Map.! var) + let (newcs, errs) = foldMap (reduce smallrhs) (dedup (toList otherrhss)) + (fmap (uncurry UEUnequal) errs, ()) -- write the errors + let m'' = Map.unionWith (<>) + (Map.map (fmap (subst var smallrhs)) (Map.delete var m')) + (Map.fromListWith (<>) (map (second pure) (toList newcs))) + loop m'' ((var, smallrhs) : eqlog) + + applyLog :: [(v, t)] -> Map v t -> Map v t + applyLog ((v, t) : l) m = applyLog l $ Map.insert v t (Map.map (subst v t) m) + applyLog [] m = m + + dedup :: Ord t => [t] -> [t] + dedup = uniq . sort + + uniq :: Eq a => [a] -> [a] + uniq (x:y:xs) | x == y = uniq (x : xs) + | otherwise = x : uniq (y : xs) + uniq l = l + + minimumByMay :: Foldable t' => (a -> a -> Ordering) -> t' a -> Maybe a + minimumByMay cmp = foldl' min' Nothing + where min' mx y = Just $! case mx of + Nothing -> y + Just x | GT <- cmp x y -> y + | otherwise -> x |