diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-08-01 22:04:52 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-08-01 22:04:52 +0200 |
commit | 250e3beae7a961fc740f775a563c303b4cc390fe (patch) | |
tree | ccbb8a090cdb082d86c0651935eeb986e2cddcea /src/Data/Expr |
Initial
Diffstat (limited to 'src/Data/Expr')
-rw-r--r-- | src/Data/Expr/SharingRecovery.hs | 157 |
1 files changed, 157 insertions, 0 deletions
diff --git a/src/Data/Expr/SharingRecovery.hs b/src/Data/Expr/SharingRecovery.hs new file mode 100644 index 0000000..118df1c --- /dev/null +++ b/src/Data/Expr/SharingRecovery.hs @@ -0,0 +1,157 @@ +{-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE DerivingVia #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE TypeOperators #-} +module Data.Expr.SharingRecovery where + +import Control.Applicative ((<|>)) +import Control.Monad.Trans.State.Strict +import Data.Bifunctor (second) +import Data.GADT.Compare +import Data.Hashable +import Data.HashMap.Strict (HashMap) +import qualified Data.HashMap.Strict as HM +import Data.Some +import Data.Type.Equality +import GHC.StableName +import Numeric.Natural +import Unsafe.Coerce (unsafeCoerce) + +import Data.StableName.Extra + + +class Functor1 f where + fmap1 :: (forall b. g b -> h b) -> f g a -> f h a + +class Functor1 f => Traversable1 f where + traverse1 :: Applicative m => (forall b. g b -> m (h b)) -> f g a -> m (f h a) + +-- | Expression in parametric higher-order abstract syntax form +data PHOASExpr v f t where + PHOASOp :: f (PHOASExpr v f) t -> PHOASExpr v f t + PHOASLam :: (PHOASExpr v f a -> PHOASExpr v f b) -> PHOASExpr v f (a -> b) + PHOASVar :: v t -> PHOASExpr v f t + +newtype Tag t = Tag Natural + deriving (Show, Eq) + +newtype NameFor f t = NameFor (StableName (PHOASExpr Tag f t)) + deriving (Eq) + deriving (Hashable) via (StableName (f (PHOASExpr Tag f) t)) + +instance GEq (NameFor f) where + geq (NameFor n1) (NameFor n2) + | eqStableName n1 n2 = Just unsafeCoerceRefl + | otherwise = Nothing + where + unsafeCoerceRefl :: a :~: b -- restricted version of unsafeCoerce that only allows punting proofs + unsafeCoerceRefl = unsafeCoerce Refl + +-- | Pruned expression +data PExpr f t where + PStub :: NameFor f t -> PExpr f t + POp :: NameFor f t -> f (PExpr f) t -> PExpr f t + PLam :: NameFor f (a -> b) -> Tag a -> PExpr f b -> PExpr f (a -> b) + PVar :: Tag a -> PExpr f a + +data SomeNameFor f = forall t. SomeNameFor {-# UNPACK #-} !(NameFor f t) + +instance Eq (SomeNameFor f) where + SomeNameFor (NameFor n1) == SomeNameFor (NameFor n2) = eqStableName n1 n2 + +instance Hashable (SomeNameFor f) where + hashWithSalt salt (SomeNameFor name) = hashWithSalt salt name + +type OccMap f = HashMap (SomeNameFor f) Natural + +pruneExpr :: Traversable1 f => (forall v. PHOASExpr v f t) -> (OccMap f, PExpr f t) +pruneExpr term = + let (term', (_, mp)) = runState (pruneExpr' term) (0, mempty) + in (mp, term') + +pruneExpr' :: Traversable1 f => PHOASExpr Tag f t -> State (Natural, OccMap f) (PExpr f t) +pruneExpr' orig@(PHOASOp args) = do + let name = makeStableName' orig + occmap <- gets snd + let (seenBefore, occmap') = + HM.alterF (\case Nothing -> (False, Just 1) + Just n -> (True, Just (n + 1))) + (SomeNameFor (NameFor name)) + occmap + modify (second (const occmap')) + if seenBefore + then pure $ PStub (NameFor name) + else POp (NameFor name) <$> traverse1 pruneExpr' args + +pruneExpr' orig@(PHOASLam f) = do + let name = makeStableName' orig + tag <- state (\(i, mp) -> (Tag i, (i + 1, mp))) + let body = f (PHOASVar tag) + PLam (NameFor name) tag <$> pruneExpr' body + +pruneExpr' (PHOASVar tag) = pure $ PVar tag + + +-- | Lifted expression: a bunch of to-be let bound expressions on top of an LExpr' +data LExpr f t = LExpr [Some (LExpr f)] (LExpr' f t) +data LExpr' f t where -- TODO: this could be an instantiation of (a generalisation of) PExpr + LStub :: NameFor f t -> LExpr' f t + LOp :: NameFor f t -> f (LExpr f) t -> LExpr' f t + LLam :: NameFor f (a -> b) -> Tag a -> LExpr f b -> LExpr' f (a -> b) + LVar :: Tag a -> LExpr' f a + +liftExpr :: Traversable1 f => OccMap f -> PExpr f t -> LExpr f t +liftExpr totals term = + let (_, e) = liftExpr' totals term + in e + +newtype FoundMap f = FoundMap + (HashMap (SomeNameFor f) (Natural -- how many times seen + ,Maybe (Some (LExpr f)))) -- the lifted subterm (once seen) + +instance Semigroup (FoundMap f) where + FoundMap m1 <> FoundMap m2 = FoundMap $ + HM.unionWith (\(n1, me1) (n2, me2) -> (n1 + n2, me1 <|> me2)) m1 m2 + +instance Monoid (FoundMap f) where + mempty = FoundMap HM.empty + +liftExpr' :: Traversable1 f => OccMap f -> PExpr f t -> (FoundMap f, LExpr f t) +liftExpr' _totals (PStub name) = + (FoundMap $ HM.singleton (SomeNameFor name) (1, Just (Some (LExpr [] (LStub name)))) + ,LExpr [] (LStub name)) + +liftExpr' _totals (PVar tag) = (mempty, LExpr [] (LVar tag)) + +liftExpr' totals term = + let (FoundMap foundmap, name, term') = case term of + POp n args -> let (fm, args') = traverse1 (liftExpr' totals) args + in (fm, n, LOp n args') + PLam n tag body -> let (fm, body') = liftExpr' totals body + in (fm, n, LLam n tag body') + + saturated = [case mterm of + Just t -> (nm, t) + Nothing -> error "Name saturated but no term found" + | (nm, (count, mterm)) <- HM.toList foundmap + , count == HM.findWithDefault 0 nm totals] + + foundmap' = foldr HM.delete foundmap (map fst saturated) + + lterm = LExpr (map snd saturated) term' + + in case HM.findWithDefault 0 (SomeNameFor name) totals of + 1 -> (FoundMap foundmap', lterm) + tot | tot > 1 -> (FoundMap (HM.insert (SomeNameFor name) (1, Just (Some lterm)) foundmap') + ,LExpr [] (LStub name)) + | otherwise -> error "Term does not exist, yet we have it in hand" + + +-- TODO: lower LExpr into a normal expression with let bindings. Every LStub +-- should correspond to some let-bound expression higher up in the tree (if it +-- does not, that's a bug), and should become a De Bruijn variable reference to +-- said let-bound expression. Lambdas should also get proper De Bruijn indices +-- instead of tags, and LVar is also a normal variable (referring to a +-- lambda-abstracted argument). |