{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
module Data.Expr.SharingRecovery where

import Control.Applicative ((<|>))
import Control.Monad.Trans.State.Strict
import Data.Bifunctor (second)
import Data.GADT.Compare
import Data.Hashable
import Data.HashMap.Strict (HashMap)
import qualified Data.HashMap.Strict as HM
import Data.Some
import Data.Type.Equality
import GHC.StableName
import Numeric.Natural
import Unsafe.Coerce (unsafeCoerce)

import Data.StableName.Extra


-- TODO: This is not yet done, see the bottom of this file

-- TODO: This implementation needs extensive documentation. 1. It is written
-- quite generically, meaning that the actual algorithm is easily obscured to
-- all but the most willing readers; 2. the original paper leaves something to
-- be desired in the domain of effective explanation for programmers, and this
-- is a good opportunity to try to do better.


-- | Functor over the first (indexed) type argument: a type-preserving mapping
-- @forall b. g b -> h b@ is applied to every @g@ inside an @f g a@.
class Functor1 f where
  fmap1 :: (forall b. g b -> h b) -> f g a -> f h a

-- | Effectful counterpart of 'Functor1': the per-position mapping runs in an
-- 'Applicative' @m@ and the effects are combined into one @m (f h a)@.
class Functor1 f => Traversable1 f where
  traverse1 :: Applicative m => (forall b. g b -> m (h b)) -> f g a -> m (f h a)

-- | Expression in parametric higher-order abstract syntax form
--
-- * @typ@ is the type-witness functor attached to every node.
-- * @v@ is the representation of variables.
-- * @f@ describes the operations; its recursive positions hold 'PHOASExpr's.
data PHOASExpr typ v f t where
  -- | An operation node: its type witness and its @f@-shaped payload.
  PHOASOp
    :: typ t
    -> f (PHOASExpr typ v f) t
    -> PHOASExpr typ v f t
  -- | A lambda: witness of the function type, witness of the argument type,
  -- and the body as a Haskell function on expressions.
  PHOASLam
    :: typ (a -> b)
    -> typ a
    -> (PHOASExpr typ v f a -> PHOASExpr typ v f b)
    -> PHOASExpr typ v f (a -> b)
  -- | A variable reference: its type witness and the variable itself.
  PHOASVar
    :: typ t
    -> v t
    -> PHOASExpr typ v f t

-- | A variable identifier: a natural number with a phantom type index.
newtype Tag t = Tag Natural
  deriving (Show, Eq)

-- | The identity of an expression node of type @t@, represented by the
-- 'StableName' of the 'PHOASExpr' (with variables instantiated to 'Tag').
newtype NameFor typ f t = NameFor (StableName (PHOASExpr typ Tag f t))
  deriving (Eq)
  deriving (Hashable) via (StableName (PHOASExpr typ Tag f t))

-- | Names are compared with 'eqStableName', which works across different type
-- indices. When the names agree, the type indices are declared equal.
--
-- NOTE(review): the coercion presumably relies on equal stable names implying
-- one and the same underlying expression (and hence one and the same type) --
-- confirm against the "GHC.StableName" documentation.
instance GEq (NameFor typ f) where
  geq (NameFor lhs) (NameFor rhs) =
    if eqStableName lhs rhs
      then Just unsafeCoerceRefl
      else Nothing
    where
      -- restricted version of unsafeCoerce that only allows punting proofs
      unsafeCoerceRefl :: a :~: b
      unsafeCoerceRefl = unsafeCoerce Refl

-- | Pruned expression.
--
-- Note that variables do not, and will never, have a name: we don't bother
-- detecting sharing for variable references, because that would only introduce
-- a redundant variable indirection.
data PExpr typ f t where
  -- | Placeholder carrying only the name and type of a node.
  PStub
    :: NameFor typ f t
    -> typ t
    -> PExpr typ f t
  -- | A named operation node with pruned children.
  POp
    :: NameFor typ f t
    -> typ t
    -> f (PExpr typ f) t
    -> PExpr typ f t
  -- | A named lambda: function type, argument type, the tag of the bound
  -- variable, and the pruned body.
  PLam
    :: NameFor typ f (a -> b)
    -> typ (a -> b)
    -> typ a
    -> Tag a
    -> PExpr typ f b
    -> PExpr typ f (a -> b)
  -- | A (nameless) variable reference.
  PVar
    :: typ a
    -> Tag a
    -> PExpr typ f a

-- | A 'NameFor' with its type index hidden, so that names of nodes of
-- different types can live in one container.
data SomeNameFor typ f = forall t. SomeNameFor {-# UNPACK #-} !(NameFor typ f t)

-- | Equality of the wrapped stable names, irrespective of the hidden index.
instance Eq (SomeNameFor typ f) where
  (==) (SomeNameFor (NameFor a)) (SomeNameFor (NameFor b)) = eqStableName a b

-- | Hashes exactly like the wrapped 'NameFor'.
instance Hashable (SomeNameFor typ f) where
  hashWithSalt s (SomeNameFor nm) = hashWithSalt s nm

-- | The number of times a particular name is visited in a preorder traversal
-- of the PHOAS expression, excluding children of nodes upon second or later
-- visit. That is to say: only the nodes that are visited in a preorder
-- traversal that skips repeated subtrees, are counted.
type OccMap typ f = HashMap (SomeNameFor typ f) Natural

-- | Prune a closed PHOAS expression.
--
-- The expression is instantiated with 'Tag' as its variable type and handed
-- to 'pruneExpr'', starting from tag counter @0@ and an empty occurrence map.
-- Returns the final occurrence map together with the pruned expression.
pruneExpr :: Traversable1 f => (forall v. PHOASExpr typ v f t) -> (OccMap typ f, PExpr typ f t)
pruneExpr term = (occurrences, pruned)
  where
    -- Lazy pattern binding: nothing runs until one of the results is demanded.
    (pruned, (_, occurrences)) = runState (pruneExpr' term) (0, mempty)

-- | Worker for 'pruneExpr'. The state is the next unused tag number paired
-- with the occurrence map built so far.
--
-- * Operations and lambdas are identified by their stable name. Every visit
--   bumps the counter for that name. On the first visit the children are
--   traversed; on any later visit only a 'PStub' is produced.
-- * A lambda body is obtained by applying the function to a 'PHOASVar'
--   carrying a freshly generated tag.
-- * Variables are copied over as they are and are not counted.
pruneExpr' :: Traversable1 f => PHOASExpr typ Tag f t -> State (Natural, OccMap typ f) (PExpr typ f t)
pruneExpr' expr = case expr of
  node@(PHOASOp ty args) -> do
    let stable = makeStableName' node
    revisit <- checkVisited stable
    if revisit
      then pure (PStub (NameFor stable) ty)
      else POp (NameFor stable) ty <$> traverse1 pruneExpr' args

  node@(PHOASLam tyf tyarg mkBody) -> do
    let stable = makeStableName' node
    revisit <- checkVisited stable
    if revisit
      then pure (PStub (NameFor stable) tyf)
      else do
        -- Hand out the current counter value as the tag and advance it.
        tag <- state (\(next, occ) -> (Tag next, (next + 1, occ)))
        PLam (NameFor stable) tyf tyarg tag <$> pruneExpr' (mkBody (PHOASVar tyarg tag))

  PHOASVar ty tag ->
    pure (PVar ty tag)
  where
    -- Record one more visit of the given name in the occurrence map, and
    -- report whether the name had been recorded before this visit.
    checkVisited name = state $ \st ->
      let (seen, occ') =
            HM.alterF (maybe (False, Just 1) (\n -> (True, Just (n + 1))))
                      (SomeNameFor (NameFor name))
                      (snd st)
      in (seen, (fst st, occ'))

-- | Lifted expression: a bunch of to-be let bound expressions on top of an
-- LExpr'. Because LExpr' is really just PExpr with the recursive positions
-- replaced by LExpr, LExpr should be seen as PExpr with a bunch of to-be let
-- bound expressions on top of every node.
data LExpr typ f t = LExpr [Some (LExpr typ f)] (LExpr' typ f t)

-- | A single node of a lifted expression; mirrors the constructors of PExpr
-- with the recursive positions holding 'LExpr'.
--
-- TODO: this could be an instantiation of (a generalisation of) PExpr
data LExpr' typ f t where
  LStub :: NameFor typ f t -> typ t -> LExpr' typ f t
  LOp :: NameFor typ f t -> typ t -> f (LExpr typ f) t -> LExpr' typ f t
  LLam :: NameFor typ f (a -> b) -> typ (a -> b) -> typ a -> Tag a -> LExpr typ f b -> LExpr' typ f (a -> b)
  LVar :: typ a -> Tag a -> LExpr' typ f a

-- | Lift a pruned expression, given the total occurrence count of every name.
-- Runs 'liftExpr'' and discards the map of names that remained unsaturated.
liftExpr :: Traversable1 f => OccMap typ f -> PExpr typ f t -> LExpr typ f t
liftExpr totals term = snd (liftExpr' totals term)

-- | Per-name bookkeeping collected while lifting a subtree.
newtype FoundMap typ f = FoundMap
  (HashMap (SomeNameFor typ f)
     (Natural  -- how many times seen
     ,Maybe (Some (LExpr typ f))))  -- the lifted subterm (once seen)

-- | Combining two maps adds the counts per name and keeps the first lifted
-- subterm that is present.
instance Semigroup (FoundMap typ f) where
  FoundMap m1 <> FoundMap m2 = FoundMap $
    HM.unionWith (\(n1, me1) (n2, me2) -> (n1 + n2, me1 <|> me2)) m1 m2

instance Monoid (FoundMap typ f) where
  mempty = FoundMap HM.empty

-- | Worker for 'liftExpr'. Returns the lifted expression together with the
-- names found in this subtree that have not yet reached their total count.
--
-- * A stub contributes one occurrence of its name, without a term.
-- * A variable contributes nothing.
-- * For an operation or lambda the children are lifted first (for operations
--   via 'traverse1' in the pair applicative, which merges the children's maps
--   with the 'Monoid' instance). Every name whose count now equals its total
--   is "saturated": its term is attached to this node as a to-be-let-bound
--   expression and the name is dropped from the map. If the node's own total
--   is 1 it stays in place; if it is larger, the node is recorded in the map
--   (count 1, with its term) and replaced by a stub.
--
-- Calls 'error' when a saturated name has no recorded term, or when the node's
-- own name has a total of 0.
liftExpr' :: Traversable1 f => OccMap typ f -> PExpr typ f t -> (FoundMap typ f, LExpr typ f t)
liftExpr' _totals (PStub name ty) =
  (FoundMap $ HM.singleton (SomeNameFor name) (1, Nothing)  -- Just (Some (LExpr [] (LStub name)))
  ,LExpr [] (LStub name ty))
liftExpr' _totals (PVar ty tag) = (mempty, LExpr [] (LVar ty tag))
liftExpr' totals term =
  let (FoundMap foundmap, name, termty, term') = case term of
        POp n ty args ->
          let (fm, args') = traverse1 (liftExpr' totals) args
          in (fm, n, ty, LOp n ty args')
        PLam n tyf tyarg tag body ->
          let (fm, body') = liftExpr' totals body
          in (fm, n, tyf, LLam n tyf tyarg tag body')

      -- TODO: perhaps this HM.toList together with the foldr HM.delete can be a single traversal of the HashMap
      saturated = [case mterm of
                     Just t -> (nm, t)
                     Nothing -> error "Name saturated but no term found"
                  | (nm, (count, mterm)) <- HM.toList foundmap
                  , count == HM.findWithDefault 0 nm totals]

      foundmap' = foldr HM.delete foundmap (map fst saturated)

      lterm = LExpr (map snd saturated) term'
  in case HM.findWithDefault 0 (SomeNameFor name) totals of
       1 -> (FoundMap foundmap', lterm)
       tot | tot > 1 -> (FoundMap (HM.insert (SomeNameFor name) (1, Just (Some lterm)) foundmap')
                        ,LExpr [] (LStub name termty))
           | otherwise -> error "Term does not exist, yet we have it in hand"

-- | The type witness stored in the top node of a lifted expression.
--
-- Total: every constructor of LExpr', including 'LStub', carries its witness.
lexprTypeOf :: LExpr typ f t -> typ t
lexprTypeOf (LExpr _ e) = case e of
  LStub _ t -> t
  LOp _ t _ -> t
  LLam _ t _ _ _ -> t
  LVar t _ -> t

-- TODO: lower LExpr into a normal expression with let bindings. Every LStub
-- should correspond to some let-bound expression higher up in the tree (if it
-- does not, that's a bug), and should become a De Bruijn variable reference to
-- said let-bound expression. Lambdas should also get proper De Bruijn indices
-- instead of tags, and LVar is also a normal variable (referring to a
-- lambda-abstracted argument).

-- | Untyped De Bruijn expression. No more names: there are lets now, and
-- variable references are De Bruijn indices. These indices are not type-safe
-- yet, though.
data UBExpr typ f t where
  UBOp :: typ t -> f (UBExpr typ f) t -> UBExpr typ f t
  UBLam :: typ a -> UBExpr typ f b -> UBExpr typ f (a -> b)
  UBLet :: typ a -> UBExpr typ f a -> UBExpr typ f b -> UBExpr typ f b
  -- De Bruijn index
  UBVar :: typ t -> Int -> UBExpr typ f t

-- | Lower a lifted expression, starting with no names in scope at level 0.
--
-- The @Traversable1 f@ constraint is required because 'lowerExpr'' demands it.
lowerExpr :: Traversable1 f => LExpr typ f t -> UBExpr typ f t
lowerExpr = lowerExpr' mempty 0

-- | Worker for 'lowerExpr'. Arguments:
--
-- 1. name |-> De Bruijn level of the variable defining that name
-- 2. Number of variables already in scope
--
-- NOTE(review): this function is unfinished (see the TODO at the top of the
-- file). As it stands:
--
-- * the 'LOp' case contains typed holes;
-- * the 'LLam' and 'LVar' cases are missing;
-- * @prefix@ and @curlvl'@ are computed but not used yet;
-- * @namelvl@ is never extended with the names of the lifted bindings, so
--   every 'LStub' lookup currently ends in the "Variable out of scope" error.
lowerExpr' :: forall typ f t. Traversable1 f => HashMap (SomeNameFor typ f) Int -> Int -> LExpr typ f t -> UBExpr typ f t
lowerExpr' namelvl curlvl (LExpr lifted ex) =
  let prefix = buildPrefix curlvl lifted
      curlvl' = curlvl + length lifted
  in case ex of
       LStub name ty ->
         case HM.lookup (SomeNameFor name) namelvl of
           -- Convert the De Bruijn level of the binding into an index.
           Just lvl -> UBVar ty (curlvl - lvl - 1)
           Nothing -> error "Variable out of scope"
       LOp name ty args ->
         UBOp ty (_ $ traverse1 _ args)
  where
    -- Wrap an expression in one 'UBLet' per lifted binding; the i-th binding
    -- (counting from 0) is lowered at level @lvl + i@.
    buildPrefix :: forall b. Int -> [Some (LExpr typ f)] -> UBExpr typ f b -> UBExpr typ f b
    buildPrefix _ [] = id
    buildPrefix lvl (Some rhs : rhss) =
      UBLet (lexprTypeOf rhs) (lowerExpr' namelvl lvl rhs) . buildPrefix (lvl + 1) rhss

-- | Typed De Bruijn index into a type-level list @env@: 'IZ' refers to the
-- head of the list, 'IS' skips one entry.
data Idx env t where
  IZ :: Idx (t : env) t
  IS :: Idx env t -> Idx (s : env) t