{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeOperators #-}
-- | Observable sharing recovery for expressions in PHOAS form.
--
-- The pipeline in this module is: 'pruneExpr' (walk the term, count the
-- occurrences of every operation/lambda node by 'StableName', and replace
-- repeat visits by stubs), followed by 'liftExpr' (float each shared subterm
-- up to the node where all of its occurrences have been seen, and attach it
-- there as a to-be-let-bound expression).
module Data.Expr.SharingRecovery where

import Control.Applicative ((<|>))
import Control.Monad.Trans.State.Strict
import Data.Bifunctor (second)
import Data.GADT.Compare
import Data.Hashable
import Data.HashMap.Strict (HashMap)
import qualified Data.HashMap.Strict as HM
import Data.Some
import Data.Type.Equality
import GHC.StableName
import Numeric.Natural
import Unsafe.Coerce (unsafeCoerce)

import Data.StableName.Extra

-- | Functors over an inner functor @g@: 'fmap1' maps a function that is
-- polymorphic in the index over every @g@ position of an @f g a@.
class Functor1 f where
  fmap1 :: (forall b. g b -> h b) -> f g a -> f h a

-- | Effectful traversal of every @g@ position of an @f g a@.
class Functor1 f => Traversable1 f where
  traverse1 :: Applicative m => (forall b. g b -> m (h b)) -> f g a -> m (f h a)

-- | Expression in parametric higher-order abstract syntax form. Operations
-- come from the base functor @f@, binders are Haskell functions, and
-- variables carry a payload of type @v t@.
data PHOASExpr v f t where
  PHOASOp :: f (PHOASExpr v f) t -> PHOASExpr v f t
  PHOASLam :: (PHOASExpr v f a -> PHOASExpr v f b) -> PHOASExpr v f (a -> b)
  PHOASVar :: v t -> PHOASExpr v f t

-- | Numeric label given to a lambda-bound variable while pruning.
newtype Tag t = Tag Natural
  deriving (Show, Eq)

-- | Identity of a node of a 'PHOASExpr' term, given by a 'StableName'.
-- Two names compare equal only if they were made from the same heap object.
newtype NameFor f t = NameFor (StableName (PHOASExpr Tag f t))
  deriving (Eq)
  -- Hash through exactly the type this newtype wraps.
  deriving (Hashable) via (StableName (PHOASExpr Tag f t))

-- | Names are compared by stable-name identity. When two names refer to the
-- same heap object, the type-equality proof is manufactured with
-- 'unsafeCoerce'.
--
-- NOTE(review): this is only sound if one heap object is never observed at
-- two different indices @t@. A shared, dictionary-free polymorphic value (for
-- instance a top-level constant) could in principle violate that; confirm
-- such values cannot reach 'pruneExpr'.
instance GEq (NameFor f) where
  geq (NameFor n1) (NameFor n2)
    | eqStableName n1 n2 = Just unsafeCoerceRefl
    | otherwise = Nothing
    where
      unsafeCoerceRefl :: a :~: b
      -- restricted version of unsafeCoerce that only allows punting proofs
      unsafeCoerceRefl = unsafeCoerce Refl

-- | Pruned expression, as produced by 'pruneExpr'. The first visit of a named
-- node keeps its structure ('POp' / 'PLam'). Every later visit of the same
-- node is replaced by a 'PStub' that carries only its name.
data PExpr f t where
  PStub :: NameFor f t -> PExpr f t
  POp :: NameFor f t -> f (PExpr f) t -> PExpr f t
  PLam :: NameFor f (a -> b) -> Tag a -> PExpr f b -> PExpr f (a -> b)
  PVar :: Tag a -> PExpr f a

-- | A 'NameFor' with its type index hidden, so it can serve as a map key.
data SomeNameFor f = forall t. SomeNameFor {-# UNPACK #-} !(NameFor f t)

instance Eq (SomeNameFor f) where
  SomeNameFor (NameFor n1) == SomeNameFor (NameFor n2) = eqStableName n1 n2

instance Hashable (SomeNameFor f) where
  hashWithSalt salt (SomeNameFor name) = hashWithSalt salt name

-- | How many times each named node was reached during pruning.
type OccMap f = HashMap (SomeNameFor f) Natural

-- | Walk a closed PHOAS term. Lambda-bound variables are numbered from 0, and
-- the walk counts how often each operation and lambda node is reached.
-- Returns the occurrence counts together with the pruned term.
pruneExpr :: Traversable1 f => (forall v. PHOASExpr v f t) -> (OccMap f, PExpr f t)
pruneExpr term =
  let (term', (_, mp)) = runState (pruneExpr' term) (0, mempty)
  in (mp, term')

-- | Allocate the next unused variable tag. The counter is forced so that it
-- does not build up a chain of @+ 1@ thunks.
freshTag :: State (Natural, OccMap f) (Tag a)
freshTag = state $ \(i, occmap) ->
  let i' = i + 1
  in i' `seq` (Tag i, (i', occmap))

-- | Count one more visit of the given node, and report whether the node had
-- already been visited before this one.
recordOccurrence :: NameFor f t -> State (Natural, OccMap f) Bool
recordOccurrence name = do
  occmap <- gets snd
  let (seenBefore, occmap') = HM.alterF bump (SomeNameFor name) occmap
  modify' (second (const occmap'))
  pure seenBefore
  where
    bump = \case
      Nothing -> (False, Just 1)
      Just n -> (True, Just (n + 1))

-- | Worker for 'pruneExpr'. The state holds the next free tag and the
-- occurrence counts gathered so far.
--
-- Operations and lambdas are both counted. 'liftExpr'' looks up the name of
-- every 'POp' and 'PLam' in the resulting 'OccMap', so a lambda missing from
-- the map would make it fail.
pruneExpr' :: Traversable1 f => PHOASExpr Tag f t -> State (Natural, OccMap f) (PExpr f t)
pruneExpr' orig@(PHOASOp args) = do
  let name = NameFor (makeStableName' orig)
  seenBefore <- recordOccurrence name
  if seenBefore
    then pure (PStub name)
    else POp name <$> traverse1 pruneExpr' args
pruneExpr' orig@(PHOASLam f) = do
  let name = NameFor (makeStableName' orig)
  seenBefore <- recordOccurrence name
  if seenBefore
    then pure (PStub name)
    else do
      -- Instantiate the binder with a fresh variable, then prune the body.
      tag <- freshTag
      PLam name tag <$> pruneExpr' (f (PHOASVar tag))
pruneExpr' (PHOASVar tag) = pure (PVar tag)

-- | Lifted expression: a list of to-be-let-bound expressions on top of an
-- 'LExpr''.
data LExpr f t = LExpr [Some (LExpr f)] (LExpr' f t)

data LExpr' f t where
  -- TODO: this could be an instantiation of (a generalisation of) PExpr
  LStub :: NameFor f t -> LExpr' f t
  LOp :: NameFor f t -> f (LExpr f) t -> LExpr' f t
  LLam :: NameFor f (a -> b) -> Tag a -> LExpr f b -> LExpr' f (a -> b)
  LVar :: Tag a -> LExpr' f a

-- | Float the shared subterms of a pruned term up to let-binding positions,
-- using the occurrence counts produced by 'pruneExpr'.
liftExpr :: Traversable1 f => OccMap f -> PExpr f t -> LExpr f t
liftExpr totals term = snd (liftExpr' totals term)

-- | Shared subterms found so far in a subtree, keyed by name.
newtype FoundMap f = FoundMap
  (HashMap (SomeNameFor f)
           ( Natural                   -- how many times seen
           , Maybe (Some (LExpr f)) )) -- the lifted subterm (once found)

-- | Counts are added. When both sides carry a term, the left one is kept.
instance Semigroup (FoundMap f) where
  FoundMap m1 <> FoundMap m2 =
    FoundMap $ HM.unionWith (\(n1, me1) (n2, me2) -> (n1 + n2, me1 <|> me2)) m1 m2
instance Monoid (FoundMap f) where
  mempty = FoundMap HM.empty

-- | Worker for 'liftExpr'. Returns the shared subterms found in this subtree
-- that have not yet been given a binding place, together with the lifted
-- form of the subtree.
liftExpr' :: Traversable1 f => OccMap f -> PExpr f t -> (FoundMap f, LExpr f t)
liftExpr' _totals (PStub name) =
  -- A stub is one more occurrence of a term whose full definition lives
  -- elsewhere. It adds to the count but carries no term of its own. If the
  -- stub itself were recorded as the term, the left-biased merge of
  -- 'FoundMap' could choose it as the binding whenever it was merged before
  -- the full occurrence.
  ( FoundMap (HM.singleton (SomeNameFor name) (1, Nothing))
  , LExpr [] (LStub name) )
liftExpr' _totals (PVar tag) = (mempty, LExpr [] (LVar tag))
liftExpr' totals (POp name args) =
  let (found, args') = traverse1 (liftExpr' totals) args
  in liftNamed totals name found (LOp name args')
liftExpr' totals (PLam name tag body) =
  let (found, body') = liftExpr' totals body
  in liftNamed totals name found (LLam name tag body')

-- | Common handling of a named node ('POp' or 'PLam') once its children have
-- been lifted.
--
-- A shared name is "saturated" when its occurrences in this subtree add up to
-- its total count. Every saturated name is let-bound right here, on top of
-- this node. Then, if this node is itself shared (total count > 1), it is
-- passed upwards as a found term and replaced by a stub. Otherwise it stays
-- in place.
--
-- NOTE(review): saturated bindings come out in 'HM.toList' order, and one
-- binding may contain a stub for another binding placed at the same node. A
-- lowering pass must therefore not assume this list is dependency-ordered.
liftNamed :: OccMap f -> NameFor f t -> FoundMap f -> LExpr' f t -> (FoundMap f, LExpr f t)
liftNamed totals name (FoundMap found) term' =
  case HM.findWithDefault 0 (SomeNameFor name) totals of
    1 -> (FoundMap remaining, lterm)
    tot
      | tot > 1 ->
          ( FoundMap (HM.insert (SomeNameFor name) (1, Just (Some lterm)) remaining)
          , LExpr [] (LStub name) )
      | otherwise -> error "Term does not exist, yet we have it in hand"
  where
    saturated =
      [ case mterm of
          Just t -> (nm, t)
          Nothing -> error "Name saturated but no term found"
      | (nm, (count, mterm)) <- HM.toList found
      , count == HM.findWithDefault 0 nm totals ]
    remaining = foldr HM.delete found (map fst saturated)
    lterm = LExpr (map snd saturated) term'

-- TODO: lower LExpr into a normal expression with let bindings. Every LStub
-- should correspond to some let-bound expression higher up in the tree (if it
-- does not, that's a bug), and should become a De Bruijn variable reference to
-- said let-bound expression. Lambdas should also get proper De Bruijn indices
-- instead of tags, and LVar is also a normal variable (referring to a
-- lambda-abstracted argument).