From 5a0ce21e12e765125ad8068e919cf97b70df8257 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Wed, 28 Aug 2024 16:10:58 +0200 Subject: Implement sorting of floated expressions --- src/Data/Expr/SharingRecovery.hs | 172 +++++++++++++++++++++++++++++---------- 1 file changed, 130 insertions(+), 42 deletions(-) (limited to 'src/Data/Expr/SharingRecovery.hs') diff --git a/src/Data/Expr/SharingRecovery.hs b/src/Data/Expr/SharingRecovery.hs index cdb64eb..f9d27e6 100644 --- a/src/Data/Expr/SharingRecovery.hs +++ b/src/Data/Expr/SharingRecovery.hs @@ -17,7 +17,7 @@ module Data.Expr.SharingRecovery where import Control.Applicative ((<|>)) import Control.Monad.Trans.State.Strict -import Data.Bifunctor (second) +import Data.Bifunctor (first, second) import Data.Char (chr, ord) import Data.Functor.Const import Data.Functor.Identity @@ -25,12 +25,17 @@ import Data.Functor.Product import Data.Hashable import Data.HashMap.Strict (HashMap) import qualified Data.HashMap.Strict as HM +import Data.List (sortBy, intersperse) +import Data.Maybe (fromMaybe) +import Data.Ord (comparing) import Data.Some import Data.Type.Equality import GHC.StableName import Numeric.Natural import Unsafe.Coerce (unsafeCoerce) +-- import Debug.Trace + import Data.StableName.Extra @@ -41,6 +46,16 @@ import Data.StableName.Extra -- is a good opportunity to try to do better. +withMoreState :: Functor m => b -> StateT (s, b) m a -> StateT s m (a, b) +withMoreState b0 (StateT f) = + StateT $ \s -> (\(x, (s2, b)) -> ((x, b), s2)) <$> f (s, b0) + +withLessState :: Functor m => (s -> (s', b)) -> (s' -> b -> s) -> StateT s' m a -> StateT s m a +withLessState split restore (StateT f) = + StateT $ \s -> let (s', b) = split s + in second (flip restore b) <$> f s' + + class Functor1 f where fmap1 :: (forall b. g b -> h b) -> f g a -> f h a @@ -91,49 +106,85 @@ instance Eq (SomeNameFor typ f) where instance Hashable (SomeNameFor typ f) where hashWithSalt salt (SomeNameFor name) = hashWithSalt salt name --- | The number of times a particular name is visited in a preorder traversal --- of the PHOAS expression, excluding children of nodes upon second or later --- visit. That is to say: only the nodes that are visited in a preorder --- traversal that skips repeated subtrees, are counted. -type OccMap typ f = HashMap (SomeNameFor typ f) Natural +prettyPExpr :: Traversable1 f => Int -> PExpr typ f t -> ShowS +prettyPExpr d = \case + PStub (NameFor name) _ -> showString (showStableName name) + POp (NameFor name) _ args -> + let (argslist, _) = traverse1 (\arg -> ([Some arg], Const ())) args + argslist' = map (\(Some arg) -> prettyPExpr 0 arg) argslist + in showParen (d > 10) $ + showString ("<" ++ showStableName name ++ ">(") + . foldr (.) id (intersperse (showString ", ") argslist') + . showString ")" + PLam (NameFor name) _ _ (Tag tag) body -> + showParen (d > 0) $ + showString ("λ" ++ showStableName name ++ " x" ++ show tag ++ ". ") . prettyPExpr 0 body + PVar _ (Tag tag) -> showString ("x" ++ show tag) + +-- | For each name: +-- +-- 1. The number of times the name is visited in a preorder traversal of the +-- PHOAS expression, excluding children of nodes upon second or later visit. +-- That is to say: only the nodes that are visited in a preorder traversal +-- that skips repeated subtrees, are counted. +-- 2. The height of the expression indicated by the name. +-- +-- Missing names have not been seen yet, and have unknown height. +type OccMap typ f = HashMap (SomeNameFor typ f) (Natural, Natural) pruneExpr :: Traversable1 f => (forall v. PHOASExpr typ v f t) -> (OccMap typ f, PExpr typ f t) pruneExpr term = - let (term', (_, mp)) = runState (pruneExpr' term) (0, mempty) + let ((term', _), (_, mp)) = runState (pruneExpr' term) (0, mempty) in (mp, term') -pruneExpr' :: Traversable1 f => PHOASExpr typ Tag f t -> State (Natural, OccMap typ f) (PExpr typ f t) +-- | Returns pruned expression with its height. +pruneExpr' :: Traversable1 f => PHOASExpr typ Tag f t -> State (Natural, OccMap typ f) (PExpr typ f t, Natural) pruneExpr' = \case orig@(PHOASOp ty args) -> do let name = makeStableName' orig - seenBefore <- checkVisited name - if seenBefore - then pure $ PStub (NameFor name) ty - else POp (NameFor name) ty <$> traverse1 pruneExpr' args + mheight <- gets (fmap snd . HM.lookup (SomeNameFor (NameFor name)) . snd) + case mheight of + -- already visited + Just height -> do + modify (second (HM.adjust (first (+1)) (SomeNameFor (NameFor name)))) + pure (PStub (NameFor name) ty, height) + -- first visit + Nothing -> do + -- Traverse the arguments, collecting the maximum height in an + -- additional piece of state. + (args', maxhei) <- + withMoreState 0 $ + traverse1 (\arg -> do + (arg', hei) <- withLessState id (,) (pruneExpr' arg) + modify (second (hei `max`)) + return arg') + args + -- Record this node + modify (second (HM.insert (SomeNameFor (NameFor name)) (1, 1 + maxhei))) + pure (POp (NameFor name) ty args', 1 + maxhei) orig@(PHOASLam tyf tyarg f) -> do let name = makeStableName' orig - seenBefore <- checkVisited name - if seenBefore - then pure $ PStub (NameFor name) tyf - else do - tag <- state (\(i, mp) -> (Tag i, (i + 1, mp))) + mheight <- gets (fmap snd . HM.lookup (SomeNameFor (NameFor name)) . snd) + case mheight of + -- already visited + Just height -> do + modify (second (HM.adjust (first (+1)) (SomeNameFor (NameFor name)))) + pure (PStub (NameFor name) tyf, height) + -- first visit + Nothing -> do + tag <- Tag <$> gets fst + modify (first (+1)) let body = f tag - PLam (NameFor name) tyf tyarg tag <$> pruneExpr' body + (body', bodyhei) <- pruneExpr' body + modify (second (HM.insert (SomeNameFor (NameFor name)) (1, 1 + bodyhei))) + pure (PLam (NameFor name) tyf tyarg tag body', 1 + bodyhei) - PHOASVar ty tag -> pure $ PVar ty tag - where - checkVisited name = do - occmap <- gets snd - let (seenBefore, occmap') = - HM.alterF (\case Nothing -> (False, Just 1) - Just n -> (True, Just (n + 1))) - (SomeNameFor (NameFor name)) - occmap - modify (second (const occmap')) - return seenBefore + PHOASVar ty tag -> pure (PVar ty tag, 1) +-- TODO: Replace "lift" with "float" + -- | Lifted expression: a bunch of to-be let bound expressions on top of an -- LExpr'. Because LExpr' is really just PExpr with the recursive positions -- replaced by LExpr, LExpr should be seen as PExpr with a bunch of to-be let @@ -145,12 +196,35 @@ data LExpr' typ f t where -- TODO: this could be an instantiation of (a general LLam :: NameFor typ f (a -> b) -> typ (a -> b) -> typ a -> Tag a -> LExpr typ f b -> LExpr' typ f (a -> b) LVar :: typ a -> Tag a -> LExpr' typ f a +prettyLExpr :: Traversable1 f => Int -> LExpr typ f t -> ShowS +prettyLExpr d (LExpr [] e) = prettyLExpr' d e +prettyLExpr d (LExpr lifted e) = + showString "[" + . foldr (.) id (intersperse (showString ", ") (map (\(Some e') -> prettyLExpr 0 e') lifted)) + . showString "] " . prettyLExpr' d e + +prettyLExpr' :: Traversable1 f => Int -> LExpr' typ f t -> ShowS +prettyLExpr' d = \case + LStub (NameFor name) _ -> showString (showStableName name) + LOp (NameFor name) _ args -> + let (argslist, _) = traverse1 (\arg -> ([Some arg], Const ())) args + argslist' = map (\(Some arg) -> prettyLExpr 0 arg) argslist + in showParen (d > 10) $ + showString ("<" ++ showStableName name ++ ">(") + . foldr (.) id (intersperse (showString ", ") argslist') + . showString ")" + LLam (NameFor name) _ _ (Tag tag) body -> + showParen (d > 0) $ + showString ("λ" ++ showStableName name ++ " x" ++ show tag ++ ". ") . prettyLExpr 0 body + LVar _ (Tag tag) -> showString ("x" ++ show tag) + liftExpr :: Traversable1 f => OccMap typ f -> PExpr typ f t -> LExpr typ f t liftExpr totals term = snd (liftExpr' totals term) newtype FoundMap typ f = FoundMap - (HashMap (SomeNameFor typ f) (Natural -- how many times seen - ,Maybe (Some (LExpr typ f)))) -- the lifted subterm (once seen) + (HashMap (SomeNameFor typ f) + (Natural -- how many times seen + ,Maybe (Some (LExpr typ f), Natural))) -- the lifted subterm with its height (once seen) instance Semigroup (FoundMap typ f) where FoundMap m1 <> FoundMap m2 = FoundMap $ @@ -161,10 +235,13 @@ instance Monoid (FoundMap typ f) where liftExpr' :: Traversable1 f => OccMap typ f -> PExpr typ f t -> (FoundMap typ f, LExpr typ f t) liftExpr' _totals (PStub name ty) = - (FoundMap $ HM.singleton (SomeNameFor name) (1, Nothing) -- Just (Some (LExpr [] (LStub name))) + -- trace ("Found stub: " ++ (case name of NameFor n -> showStableName n)) $ + (FoundMap $ HM.singleton (SomeNameFor name) (1, Nothing) ,LExpr [] (LStub name ty)) -liftExpr' _totals (PVar ty tag) = (mempty, LExpr [] (LVar ty tag)) +liftExpr' _totals (PVar ty tag) = + -- trace ("Found var: " ++ show tag) $ + (mempty, LExpr [] (LVar ty tag)) liftExpr' totals term = let (FoundMap foundmap, name, termty, term') = case term of @@ -178,19 +255,24 @@ liftExpr' totals term = -- TODO: perhaps this HM.toList together with the foldr HM.delete can be a single traversal of the HashMap saturated = [case mterm of Just t -> (nm, t) - Nothing -> error "Name saturated but no term found" + Nothing -> case nm of + SomeNameFor (NameFor n) -> + error $ "Name saturated (count=" ++ show count ++ ", totalcount=" ++ show totalcount ++ ") but no term found: " ++ showStableName n | (nm, (count, mterm)) <- HM.toList foundmap - , count == HM.findWithDefault 0 nm totals] + , let totalcount = fromMaybe 0 (fst <$> HM.lookup nm totals) + , count == totalcount] foundmap' = foldr HM.delete foundmap (map fst saturated) - lterm = LExpr (map snd saturated) term' + lterm = LExpr (map fst (sortBy (comparing snd) (map snd saturated))) term' - in case HM.findWithDefault 0 (SomeNameFor name) totals of - 1 -> (FoundMap foundmap', lterm) - tot | tot > 1 -> (FoundMap (HM.insert (SomeNameFor name) (1, Just (Some lterm)) foundmap') - ,LExpr [] (LStub name termty)) - | otherwise -> error "Term does not exist, yet we have it in hand" + in case HM.findWithDefault (0, undefined) (SomeNameFor name) totals of + (1, _) -> (FoundMap foundmap', lterm) + (tot, height) + | tot > 1 -> -- trace ("Inserting " ++ (case name of NameFor n -> showStableName n) ++ " into foundmap") $ + (FoundMap (HM.insert (SomeNameFor name) (1, Just (Some lterm, height)) foundmap') + ,LExpr [] (LStub name termty)) + | otherwise -> error "Term does not exist, yet we have it in hand" -- | Untyped De Bruijn expression. No more names: there are lets now, and @@ -334,4 +416,10 @@ retypeExpr' env (UBVar ty idx) = sharingRecovery :: (Traversable1 f, TestEquality typ) => (forall v. PHOASExpr typ v f t) -> BExpr typ '[] f t -sharingRecovery e = retypeExpr $ lowerExpr $ uncurry liftExpr $ pruneExpr e +sharingRecovery e = + let (occmap, pexpr) = pruneExpr e + lexpr = liftExpr occmap pexpr + ubexpr = lowerExpr lexpr + in -- trace ("PExpr: " ++ prettyPExpr 0 pexpr "") $ + -- trace ("LExpr: " ++ prettyLExpr 0 lexpr "") $ + retypeExpr ubexpr -- cgit v1.2.3-70-g09d2