aboutsummaryrefslogtreecommitdiff
path: root/src/HSVIS/Typecheck
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-03-14 23:21:53 +0100
committerTom Smeding <tom@tomsmeding.com>2024-03-14 23:21:53 +0100
commite7bed242ba52e6d3233928f2c6189e701cfa5e4c (patch)
tree4bdda2b7bc702c87d97f89946362e6b719126831 /src/HSVIS/Typecheck
parente8f09ff3f9d40922238d646c8fbcbacf9cfdfb62 (diff)
Some typechecker work
Diffstat (limited to 'src/HSVIS/Typecheck')
-rw-r--r--src/HSVIS/Typecheck/Solve.hs103
1 files changed, 103 insertions, 0 deletions
diff --git a/src/HSVIS/Typecheck/Solve.hs b/src/HSVIS/Typecheck/Solve.hs
new file mode 100644
index 0000000..184937c
--- /dev/null
+++ b/src/HSVIS/Typecheck/Solve.hs
@@ -0,0 +1,103 @@
+{-# LANGUAGE ScopedTypeVariables #-}
+module HSVIS.Typecheck.Solve where
+
+import Control.Monad (guard, (>=>))
+import Data.Bifunctor (second)
+import Data.Foldable (toList, foldl')
+import Data.List (sort)
+import Data.Ord (comparing)
+import Data.Map.Strict (Map)
+import qualified Data.Map.Strict as Map
+
+import Debug.Trace
+
+import Data.Bag
+import Data.List (minimumBy)
+
+
+data UnifyErr v t
+ = UEUnequal t t
+ | UERecursive v t
+ deriving (Show)
+
+-- | Returns a pair of:
+-- 1. A set of unification errors;
+-- 2. An assignment of the variables that had any constraints on them.
+-- The producedure was successful if the set of errors is empty. Note that
+-- unconstrained variables do not appear in the output.
+solveConstraints
+ :: forall v t. (Ord v, Ord t, Show v, Show t)
+ -- | reduce: take two types and unify them, resulting in:
+ -- 1. A bag of resulting constraints on variables;
+ -- 2. A bag of errors: pairs of two types that are provably distinct but
+ -- need to be equal for the input types to unify.
+ => (t -> t -> (Bag (v, t), Bag (t, t)))
+ -- | Free variables in a type
+ -> (t -> Bag v)
+ -- | \v repl term -> Substitute v by repl in term
+ -> (v -> t -> t -> t)
+ -- | Detect bare-variable types
+ -> (t -> Maybe v)
+ -- | Some kind of size measure on types
+ -> (t -> Int)
+ -- | Equality constraints to solve
+ -> [(t, t)]
+ -> (Map v t, Bag (UnifyErr v t))
+solveConstraints reduce frees subst detect size = \cs ->
+ let (vcs, errs) = foldMap (uncurry reduce) cs
+ asg = Map.fromListWith (<>) (map (second pure) (toList vcs))
+ (errs', asg') = loop asg []
+ errs'' = fmap (uncurry UEUnequal) errs <> errs'
+ in trace ("[solver] Solving:" ++ concat ["\n- " ++ show a ++ " == " ++ show b | (a, b) <- cs]) $
+ trace ("[solver] Result: (with " ++ show (length errs'') ++ " errors)" ++
+ concat ["\n- " ++ show v ++ " = " ++ show t | (v, t) <- Map.assocs asg'])
+ (asg', errs'')
+ where
+ loop :: Map v (Bag t) -> [(v, t)] -> (Bag (UnifyErr v t), Map v t)
+ loop m eqlog = do
+ traceM $ "[solver] Step:" ++ concat ["\n- " ++ show v ++ " == " ++ show t | (v, t) <- Map.assocs m]
+ m' <- Map.traverseWithKey
+ (\v ts ->
+ let ts' = bagFromList (dedup (toList ts))
+ -- filter out recursive equations
+ (recs, nonrecs) = bagPartition (\t -> if v `elem` frees t then Just t else Nothing) ts'
+ -- filter out trivial equations (v = v)
+ (_, nonrecs') = bagPartition (detect >=> guard . (== v)) nonrecs
+ in (UERecursive v <$> recs, nonrecs'))
+ m
+
+ let msmallestvar = -- var with its smallest RHS, if such a var exists
+ minimumByMay (comparing (size . snd))
+ . map (second (minimumBy (comparing size)))
+ . filter (not . null . snd)
+ $ Map.assocs m'
+
+ case msmallestvar of
+ Nothing -> return $ applyLog eqlog mempty
+ Just (var, smallrhs) -> do
+ let (_, otherrhss) = bagPartition (guard . (== smallrhs)) (m' Map.! var)
+ let (newcs, errs) = foldMap (reduce smallrhs) (dedup (toList otherrhss))
+ (fmap (uncurry UEUnequal) errs, ()) -- write the errors
+ let m'' = Map.unionWith (<>)
+ (Map.map (fmap (subst var smallrhs)) (Map.delete var m'))
+ (Map.fromListWith (<>) (map (second pure) (toList newcs)))
+ loop m'' ((var, smallrhs) : eqlog)
+
+ applyLog :: [(v, t)] -> Map v t -> Map v t
+ applyLog ((v, t) : l) m = applyLog l $ Map.insert v t (Map.map (subst v t) m)
+ applyLog [] m = m
+
+ dedup :: Ord t => [t] -> [t]
+ dedup = uniq . sort
+
+ uniq :: Eq a => [a] -> [a]
+ uniq (x:y:xs) | x == y = uniq (x : xs)
+ | otherwise = x : uniq (y : xs)
+ uniq l = l
+
+ minimumByMay :: Foldable t' => (a -> a -> Ordering) -> t' a -> Maybe a
+ minimumByMay cmp = foldl' min' Nothing
+ where min' mx y = Just $! case mx of
+ Nothing -> y
+ Just x | GT <- cmp x y -> y
+ | otherwise -> x