{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Data.Expr.SharingRecovery where

import Control.Applicative ((<|>))
import Control.Monad.Trans.State.Strict
import Data.Bifunctor (first, second)
import Data.Char (chr, ord)
import Data.Functor.Const
import Data.Functor.Identity
import Data.Functor.Product
import Data.Hashable
import Data.HashMap.Strict (HashMap)
import qualified Data.HashMap.Strict as HM
import Data.List (sortBy, intersperse)
import Data.Maybe (fromMaybe)
import Data.Ord (comparing)
import Data.Some
import Data.Type.Equality
import GHC.StableName
import Numeric.Natural
import Unsafe.Coerce (unsafeCoerce)

-- import Debug.Trace

import Data.StableName.Extra


-- TODO: This implementation needs extensive documentation. 1. It is written
-- quite generically, meaning that the actual algorithm is easily obscured to
-- all but the most willing readers; 2. the original paper leaves something to
-- be desired in the domain of effective explanation for programmers, and this
-- is a good opportunity to try to do better.

-- | Run a computation that uses an extra state component of type @b@,
-- starting that component at the given initial value. The final value of the
-- extra component is returned next to the computation's result; the outer
-- state is threaded through unchanged in shape.
withMoreState :: Functor m => b -> StateT (s, b) m a -> StateT s m (a, b)
withMoreState b0 (StateT run) = StateT $ \s -> regroup <$> run (s, b0)
  where
    -- Move the extra component from the state side to the result side.
    regroup (x, (s', b)) = ((x, b), s')

-- | Run a computation on a smaller state @s'@ inside a larger state @s@. The
-- first function splits the large state into the small state and a remainder;
-- the second function rebuilds the large state from the updated small state
-- and that (untouched) remainder.
withLessState :: Functor m => (s -> (s', b)) -> (s' -> b -> s) -> StateT s' m a -> StateT s m a
withLessState split restore (StateT run) = StateT step
  where
    step s =
      let (inner, rest) = split s
      in second (`restore` rest) <$> run inner

-- | Functors over natural transformations: map the type constructor that
-- fills the recursive positions of @f@.
class Functor1 f where
  fmap1 :: (forall b. g b -> h b) -> f g a -> f h a

  -- The default goes through 'traverse1' in the 'Identity' applicative.
  default fmap1 :: Traversable1 f => (forall b. g b -> h b) -> f g a -> f h a
  fmap1 f x = runIdentity (traverse1 (Identity . f) x)

-- | The effectful counterpart of 'Functor1'.
class Functor1 f => Traversable1 f where
  traverse1 :: Applicative m => (forall b. g b -> m (h b)) -> f g a -> m (f h a)

-- | Expression in parametric higher-order abstract syntax form
data PHOASExpr typ v f t where
  PHOASOp :: typ t -> f (PHOASExpr typ v f) t -> PHOASExpr typ v f t
  PHOASLam :: typ (a -> b) -> typ a -> (v a -> PHOASExpr typ v f b) -> PHOASExpr typ v f (a -> b)
  PHOASVar :: typ t -> v t -> PHOASExpr typ v f t

-- | A numbered variable tag; 'pruneExpr'' hands a fresh one to every lambda
-- it enters for the first time.
newtype Tag t = Tag Natural
  deriving (Show, Eq)
  deriving (Hashable) via Natural

-- | The stable name of a PHOAS node (instantiated at 'Tag'), used as that
-- node's identity.
newtype NameFor typ f t = NameFor (StableName (PHOASExpr typ Tag f t))
  deriving (Eq)
  deriving (Hashable) via (StableName (PHOASExpr typ Tag f t))

instance TestEquality (NameFor typ f) where
  testEquality (NameFor n1) (NameFor n2)
    | eqStableName n1 n2 = Just unsafeCoerceRefl
    | otherwise = Nothing
    where
      -- restricted version of unsafeCoerce that only allows punting proofs
      unsafeCoerceRefl :: a :~: b
      unsafeCoerceRefl = unsafeCoerce Refl

-- | Pruned expression.
--
-- Note that variables do not, and will never, have a name: we don't bother
-- detecting sharing for variable references, because that would only introduce
-- a redundant variable indirection.
data PExpr typ f t where
  PStub :: NameFor typ f t -> typ t -> PExpr typ f t
  POp :: NameFor typ f t -> typ t -> f (PExpr typ f) t -> PExpr typ f t
  PLam :: NameFor typ f (a -> b) -> typ (a -> b) -> typ a -> Tag a -> PExpr typ f b -> PExpr typ f (a -> b)
  PVar :: typ a -> Tag a -> PExpr typ f a

-- | A 'NameFor' with its type index hidden, so that names of differently
-- typed nodes can be keys of one map.
data SomeNameFor typ f = forall t. SomeNameFor {-# UNPACK #-} !(NameFor typ f t)

instance Eq (SomeNameFor typ f) where
  SomeNameFor (NameFor n1) == SomeNameFor (NameFor n2) = eqStableName n1 n2

instance Hashable (SomeNameFor typ f) where
  hashWithSalt salt (SomeNameFor name) = hashWithSalt salt name

-- | Debug rendering of a pruned expression. The 'Int' is a 'showsPrec'-style
-- precedence: operator nodes are parenthesised above 10, lambdas above 0.
prettyPExpr :: Traversable1 f => Int -> PExpr typ f t -> ShowS
prettyPExpr d expr = case expr of
  PStub (NameFor name) _ ->
    showString (showStableName name)
  POp (NameFor name) _ args ->
    -- Collect the children in traversal order using the writer-like
    -- applicative on pairs; the rebuilt structure is thrown away.
    let children = fst (traverse1 (\arg -> ([Some arg], Const ())) args)
        rendered = map (\(Some arg) -> prettyPExpr 0 arg) children
    in showParen (d > 10) $
         showString ("<" ++ showStableName name ++ ">(")
         . foldr (.) id (intersperse (showString ", ") rendered)
         . showString ")"
  PLam (NameFor name) _ _ (Tag tag) body ->
    showParen (d > 0) $
      showString ("λ" ++ showStableName name ++ " x" ++ show tag ++ ". ")
      . prettyPExpr 0 body
  PVar _ (Tag tag) ->
    showString ("x" ++ show tag)

-- | For each name:
--
-- 1. The number of times the name is visited in a preorder traversal of the
--    PHOAS expression, excluding children of nodes upon second or later visit.
--    That is to say: only the nodes that are visited in a preorder traversal
--    that skips repeated subtrees, are counted.
-- 2. The height of the expression indicated by the name.
--
-- Missing names have not been seen yet, and have unknown height.
type OccMap typ f = HashMap (SomeNameFor typ f) (Natural, Natural)

-- | Prune a closed PHOAS expression: instantiate its variables at 'Tag' and
-- replace every repeated visit of a node by a 'PStub'. Returns the occurrence
-- map together with the pruned expression.
pruneExpr :: Traversable1 f => (forall v. PHOASExpr typ v f t) -> (OccMap typ f, PExpr typ f t)
pruneExpr term = (occmap, pruned)
  where
    -- The state starts with tag counter 0 and an empty occurrence map; the
    -- top-level height and the final tag counter are not needed.
    ((pruned, _), (_, occmap)) = runState (pruneExpr' term) (0, mempty)

-- | Returns pruned expression with its height.
--
-- The state is the next unused tag number paired with the occurrence map
-- built so far.
pruneExpr' :: Traversable1 f => PHOASExpr typ Tag f t -> State (Natural, OccMap typ f) (PExpr typ f t, Natural)
pruneExpr' = \case
  orig@(PHOASOp ty args) -> do
    let name = NameFor (makeStableName' orig)
        key = SomeNameFor name
    seen <- gets (fmap snd . HM.lookup key . snd)
    case seen of
      -- already visited: bump the count and leave a stub behind
      Just height -> do
        modify (second (HM.adjust (first (+1)) key))
        pure (PStub name ty, height)
      -- first visit
      Nothing -> do
        -- Traverse the arguments, collecting the maximum height in an
        -- additional piece of state.
        (args', maxhei) <-
          withMoreState 0 $
            traverse1
              (\arg -> do
                  (arg', hei) <- withLessState id (,) (pruneExpr' arg)
                  modify (second (max hei))
                  pure arg')
              args
        -- Record this node
        modify (second (HM.insert key (1, 1 + maxhei)))
        pure (POp name ty args', 1 + maxhei)

  orig@(PHOASLam tyf tyarg f) -> do
    let name = NameFor (makeStableName' orig)
        key = SomeNameFor name
    seen <- gets (fmap snd . HM.lookup key . snd)
    case seen of
      -- already visited: bump the count and leave a stub behind
      Just height -> do
        modify (second (HM.adjust (first (+1)) key))
        pure (PStub name tyf, height)
      -- first visit: take a fresh tag and prune the body applied to it
      Nothing -> do
        tag <- gets (Tag . fst)
        modify (first (+1))
        (body', bodyhei) <- pruneExpr' (f tag)
        modify (second (HM.insert key (1, 1 + bodyhei)))
        pure (PLam name tyf tyarg tag body', 1 + bodyhei)

  -- Variables are never named or counted; they have height 1.
  PHOASVar ty tag -> pure (PVar ty tag, 1)

-- TODO: Replace "lift" with "float"

-- | Lifted expression: a bunch of to-be let bound expressions on top of an
-- LExpr'. Because LExpr' is really just PExpr with the recursive positions
-- replaced by LExpr, LExpr should be seen as PExpr with a bunch of to-be let
-- bound expressions on top of every node.
data LExpr typ f t = LExpr [Some (LExpr typ f)] (LExpr' typ f t)

-- | One node of a lifted expression; mirrors 'PExpr' with 'LExpr' in the
-- recursive positions.
data LExpr' typ f t where  -- TODO: this could be an instantiation of (a generalisation of) PExpr
  LStub :: NameFor typ f t -> typ t -> LExpr' typ f t
  LOp :: NameFor typ f t -> typ t -> f (LExpr typ f) t -> LExpr' typ f t
  LLam :: NameFor typ f (a -> b) -> typ (a -> b) -> typ a -> Tag a -> LExpr typ f b -> LExpr' typ f (a -> b)
  LVar :: typ a -> Tag a -> LExpr' typ f a

-- | Debug rendering of a lifted expression. A non-empty list of lifted
-- bindings is printed in square brackets in front of the node.
prettyLExpr :: Traversable1 f => Int -> LExpr typ f t -> ShowS
prettyLExpr d (LExpr [] e) = prettyLExpr' d e
prettyLExpr d (LExpr lifted e) =
  showString "["
  . foldr (.) id (intersperse (showString ", ") (map (\(Some e') -> prettyLExpr 0 e') lifted))
  . showString "] "
  . prettyLExpr' d e

-- | Debug rendering of a single lifted node. The 'Int' is a
-- 'showsPrec'-style precedence: operator nodes are parenthesised above 10,
-- lambdas above 0.
prettyLExpr' :: Traversable1 f => Int -> LExpr' typ f t -> ShowS
prettyLExpr' d = \case
  LStub (NameFor name) _ -> showString (showStableName name)
  LOp (NameFor name) _ args ->
    -- Collect the children in traversal order via the applicative on pairs.
    let (argslist, _) = traverse1 (\arg -> ([Some arg], Const ())) args
        argslist' = map (\(Some arg) -> prettyLExpr 0 arg) argslist
    in showParen (d > 10) $
         showString ("<" ++ showStableName name ++ ">(")
         . foldr (.) id (intersperse (showString ", ") argslist')
         . showString ")"
  LLam (NameFor name) _ _ (Tag tag) body ->
    showParen (d > 0) $
      showString ("λ" ++ showStableName name ++ " x" ++ show tag ++ ". ") . prettyLExpr 0 body
  LVar _ (Tag tag) -> showString ("x" ++ show tag)

-- | Attach every shared subterm as a to-be-let-bound expression to a node of
-- the pruned expression, given the total occurrence counts. This is
-- 'liftExpr'' with the map of still-pending names discarded.
liftExpr :: Traversable1 f => OccMap typ f -> PExpr typ f t -> LExpr typ f t
liftExpr totals term = snd (liftExpr' totals term)

-- | Names encountered in a subtree that have not yet been let-bound.
newtype FoundMap typ f = FoundMap
  (HashMap (SomeNameFor typ f)
     (Natural  -- how many times seen
     ,Maybe (Some (LExpr typ f), Natural)))  -- the lifted subterm with its height (once seen)

-- | Combining two maps adds the occurrence counts and keeps the first
-- available lifted subterm.
instance Semigroup (FoundMap typ f) where
  FoundMap m1 <> FoundMap m2 = FoundMap $
    HM.unionWith (\(n1, me1) (n2, me2) -> (n1 + n2, me1 <|> me2)) m1 m2

instance Monoid (FoundMap typ f) where
  mempty = FoundMap HM.empty

-- | Worker for 'liftExpr'. Returns the lifted expression together with the
-- names found in it whose count has not yet reached the total in the
-- 'OccMap'.
--
-- At every operator or lambda node, each name whose count in the subtree
-- equals its total count is "saturated": it is removed from the map and its
-- term is attached to this node, ordered by increasing height. If the node
-- itself occurs more than once in total, it is replaced by a stub and put
-- into the map instead.
--
-- Calls 'error' if a saturated name has no recorded term, or if the node's
-- own name is missing from the 'OccMap' (or recorded there with count 0).
liftExpr' :: Traversable1 f => OccMap typ f -> PExpr typ f t -> (FoundMap typ f, LExpr typ f t)
liftExpr' _totals (PStub name ty) =
  -- trace ("Found stub: " ++ (case name of NameFor n -> showStableName n)) $
  (FoundMap $ HM.singleton (SomeNameFor name) (1, Nothing)
  ,LExpr [] (LStub name ty))
liftExpr' _totals (PVar ty tag) =
  -- trace ("Found var: " ++ show tag) $
  (mempty, LExpr [] (LVar ty tag))
liftExpr' totals term =
  let -- PStub and PVar are handled by the equations above, so only these two
      -- constructors can arrive here.
      (FoundMap foundmap, name, termty, term') = case term of
        POp n ty args ->
          let (fm, args') = traverse1 (liftExpr' totals) args
          in (fm, n, ty, LOp n ty args')
        PLam n tyf tyarg tag body ->
          let (fm, body') = liftExpr' totals body
          in (fm, n, tyf, LLam n tyf tyarg tag body')

      -- TODO: perhaps this HM.toList together with the foldr HM.delete can be a single traversal of the HashMap
      saturated = [case mterm of
                     Just t -> (nm, t)
                     Nothing -> case nm of
                       SomeNameFor (NameFor n) ->
                         error $ "Name saturated (count=" ++ show count ++ ", totalcount=" ++ show totalcount ++ ") but no term found: " ++ showStableName n
                  | (nm, (count, mterm)) <- HM.toList foundmap
                  , let totalcount = fromMaybe 0 (fst <$> HM.lookup nm totals)
                  , count == totalcount]

      foundmap' = foldr HM.delete foundmap (map fst saturated)

      -- Lower terms come first, so that a binding precedes its users.
      lterm = LExpr (map fst (sortBy (comparing snd) (map snd saturated))) term'

  in case HM.lookup (SomeNameFor name) totals of
       -- Occurs exactly once: keep the node where it is.
       Just (1, _) -> (FoundMap foundmap', lterm)
       -- Occurs more than once: leave a stub and float the term upwards.
       Just (tot, height)
         | tot > 1 ->
             -- trace ("Inserting " ++ (case name of NameFor n -> showStableName n) ++ " into foundmap") $
             (FoundMap (HM.insert (SomeNameFor name) (1, Just (Some lterm, height)) foundmap')
             ,LExpr [] (LStub name termty))
       -- Missing from the map, or recorded with count 0.
       _ -> error "Term does not exist, yet we have it in hand"

-- | Untyped De Bruijn expression. No more names: there are lets now, and
-- variable references are De Bruijn indices. These indices are not type-safe
-- yet, though.
data UBExpr typ f t where
  UBOp :: typ t -> f (UBExpr typ f) t -> UBExpr typ f t
  UBLam :: typ (a -> b) -> typ a -> UBExpr typ f b -> UBExpr typ f (a -> b)
  UBLet :: typ a -> UBExpr typ f a -> UBExpr typ f b -> UBExpr typ f b
  -- | De Bruijn index
  UBVar :: typ t -> Int -> UBExpr typ f t

-- | Turn the lifted bindings into lets and all names and tags into De Bruijn
-- indices, starting from an empty scope.
lowerExpr :: Functor1 f => LExpr typ f t -> UBExpr typ f t
lowerExpr = lowerExpr' mempty mempty 0

-- | A 'Tag' with its type index hidden, usable as a map key.
data SomeTag = forall t. SomeTag (Tag t)

instance Eq SomeTag where
  SomeTag (Tag n) == SomeTag (Tag m) = n == m

instance Hashable SomeTag where
  hashWithSalt salt (SomeTag tag) = hashWithSalt salt tag

-- | Worker for 'lowerExpr'.
--
-- The lifted bindings of the node become a chain of lets around it, each one
-- level deeper than the previous. A variable bound at level @lvl@ and
-- referenced with @n@ variables in scope gets De Bruijn index @n - lvl - 1@.
--
-- Calls 'error' when a stub or tag has no binding in scope, and when a lifted
-- binding is itself a stub or a variable.
lowerExpr' :: forall typ f t. Functor1 f
           => HashMap (SomeNameFor typ f) Int  -- ^ node |-> De Bruijn level of defining binding
           -> HashMap SomeTag Int              -- ^ tag |-> De Bruijn level of defining binding
           -> Int                              -- ^ Number of variables already in scope
           -> LExpr typ f t
           -> UBExpr typ f t
lowerExpr' namelvl taglvl curlvl (LExpr lifted ex) =
  let (namelvl', prefix) = buildPrefix namelvl curlvl lifted
      -- Number of variables in scope underneath the let prefix.
      curlvl' = curlvl + length lifted
  in prefix $
       case ex of
         LStub name ty ->
           case HM.lookup (SomeNameFor name) namelvl' of
             -- The reference sits underneath 'prefix', so the index must be
             -- computed from curlvl', not curlvl. The two only differ when
             -- 'lifted' is non-empty.
             Just lvl -> UBVar ty (curlvl' - lvl - 1)
             Nothing -> error "Name variable out of scope"
         LOp _ ty args ->
           UBOp ty (fmap1 (lowerExpr' namelvl' taglvl curlvl') args)
         LLam _ tyf tyarg tag body ->
           -- The lambda's own variable takes level curlvl'.
           UBLam tyf tyarg (lowerExpr' namelvl' (HM.insert (SomeTag tag) curlvl' taglvl) (curlvl' + 1) body)
         LVar ty tag ->
           case HM.lookup (SomeTag tag) taglvl of
             -- As for LStub: count the variables in scope under 'prefix'.
             Just lvl -> UBVar ty (curlvl' - lvl - 1)
             Nothing -> error "Tag variable out of scope"
  where
    -- Build the let chain for the lifted bindings. The first binding gets the
    -- given level, and each right-hand side is lowered with the bindings
    -- before it in scope. Returns the extended name map and a function that
    -- wraps a body in the lets.
    buildPrefix :: forall b.
                   HashMap (SomeNameFor typ f) Int
                -> Int
                -> [Some (LExpr typ f)]
                -> (HashMap (SomeNameFor typ f) Int, UBExpr typ f b -> UBExpr typ f b)
    buildPrefix namelvl' _ [] = (namelvl', id)
    buildPrefix namelvl' lvl (Some rhs@(LExpr _ rhs') : rhss) =
      let name = case rhs' of
                   LStub n _ -> n
                   LOp n _ _ -> n
                   LLam n _ _ _ _ -> n
                   LVar _ _ -> error "Recovering sharing of a tag is useless"
          ty = case rhs' of
                 LStub{} -> error "Recovering sharing of a stub is useless"
                 LOp _ t _ -> t
                 LLam _ t _ _ _ -> t
                 LVar t _ -> t
          prefix = UBLet ty (lowerExpr' namelvl' taglvl lvl rhs)
      in (prefix .) <$> buildPrefix (HM.insert (SomeNameFor name) lvl namelvl') (lvl + 1) rhss

-- | A typed De Bruijn index.
data Idx env t where
  IZ :: Idx (t : env) t
  IS :: Idx env t -> Idx (s : env) t
deriving instance Show (Idx env t)

-- | An environment holding one @f t@ for every type @t@ in the type-level
-- list @env@; the most recently pushed entry is at index 'IZ'.
data Env env f where
  ETop :: Env '[] f
  EPush :: Env env f -> f t -> Env (t : env) f

-- | Typed lookup in an 'Env'.
envLookup :: Idx env t -> Env env f -> f t
envLookup IZ (EPush _ x) = x
envLookup (IS i) (EPush e _) = envLookup i e

-- | Untyped lookup in an 'Env'.
--
-- Index 0 is the most recently pushed entry. On success the entry is returned
-- paired with the typed index that points at it; an index that runs past the
-- end of the environment gives 'Nothing'.
envLookupU :: Int -> Env env f -> Maybe (Some (Product f (Idx env)))
envLookupU = go id
  where
    -- The first argument converts an index into the remaining environment
    -- to an index into the full one.
    go :: (forall a. Idx env a -> Idx env' a) -> Int -> Env env f -> Maybe (Some (Product f (Idx env')))
    go !_ !_ ETop = Nothing
    go weaken n (EPush rest x)
      | n == 0 = Just (Some (Pair x (weaken IZ)))
      | otherwise = go (weaken . IS) (n - 1) rest

-- | Typed De Bruijn expression. This is the final result of sharing recovery.
data BExpr typ env f t where
  BOp :: typ t -> f (BExpr typ env f) t -> BExpr typ env f t
  BLam :: typ (a -> b) -> typ a -> BExpr typ (a : env) f b -> BExpr typ env f (a -> b)
  BLet :: typ a -> BExpr typ env f a -> BExpr typ (a : env) f b -> BExpr typ env f b
  BVar :: typ t -> Idx env t -> BExpr typ env f t
deriving instance (forall a. Show (typ a), forall a r. (forall b. Show (r b)) => Show (f r a))
                  => Show (BExpr typ env f t)

-- | Render a closed expression as a 'String'. The caller supplies the
-- printer for operator nodes; that printer receives a function for printing
-- subexpressions at a given precedence.
prettyBExpr :: (forall m env' a. Monad m => (forall b. Int -> BExpr typ env' f b -> m ShowS)
                                         -> Int -> f (BExpr typ env' f) a -> m ShowS)
            -> BExpr typ '[] f t -> String
prettyBExpr prettyOp e = evalState (prettyBExpr' prettyOp ETop 0 e) 0 ""

-- | Worker for 'prettyBExpr'. The environment holds the name chosen for each
-- variable in scope, and the 'Int' state is the number of names generated so
-- far. Lambdas and lets are parenthesised when the precedence is above 0.
prettyBExpr' :: (forall m env' a. Monad m => (forall b. Int -> BExpr typ env' f b -> m ShowS)
                                          -> Int -> f (BExpr typ env' f) a -> m ShowS)
             -> Env env (Const String) -> Int -> BExpr typ env f t -> State Int ShowS
prettyBExpr' prettyOp env d = \case
  BOp _ args ->
    prettyOp (prettyBExpr' prettyOp env) d args
  BLam _ _ body -> do
    name <- genName
    body' <- prettyBExpr' prettyOp (EPush env (Const name)) 0 body
    pure (showParen (d > 0) (showString ("λ" ++ name ++ ". ") . body'))
  BLet _ rhs body -> do
    -- The name is drawn before the right-hand side is printed, but only the
    -- body gets to see it in its environment.
    name <- genName
    rhs' <- prettyBExpr' prettyOp env 0 rhs
    body' <- prettyBExpr' prettyOp (EPush env (Const name)) 0 body
    pure $
      showParen (d > 0) $
        showString ("let " ++ name ++ " = ")
        . rhs'
        . showString " in "
        . body'
  BVar _ idx ->
    pure (showString (getConst (envLookup idx env)))
  where
    -- The first 26 names are the letters a..z; after that, "x" followed by
    -- the counter value.
    genName :: State Int String
    genName = do
      i <- state (\n -> (n, n + 1))
      pure $
        if i < 26
          then [chr (ord 'a' + i)]
          else 'x' : show i

-- | Convert a closed untyped De Bruijn expression into a typed one, checking
-- every variable reference against the type of its binding.
retypeExpr :: (Functor1 f, TestEquality typ) => UBExpr typ f t -> BExpr typ '[] f t
retypeExpr = retypeExpr' ETop

-- | Worker for 'retypeExpr'; the environment records the type of every
-- variable in scope.
--
-- Calls 'error' if an index is out of range or if the type stored at a
-- variable differs from the type of its binding.
retypeExpr' :: (Functor1 f, TestEquality typ) => Env env typ -> UBExpr typ f t -> BExpr typ env f t
retypeExpr' env = \case
  UBOp ty args ->
    BOp ty (fmap1 (retypeExpr' env) args)
  UBLam tyf tyarg body ->
    BLam tyf tyarg (retypeExpr' (EPush env tyarg) body)
  UBLet ty rhs body ->
    BLet ty (retypeExpr' env rhs) (retypeExpr' (EPush env ty) body)
  UBVar ty idx ->
    case envLookupU idx env of
      Nothing -> error "Untyped De Bruijn index out of range"
      Just (Some (Pair defty tidx)) ->
        case testEquality ty defty of
          Nothing -> error "Type mismatch in untyped De Bruijn expression"
          Just Refl -> BVar ty tidx

-- | The complete pipeline: prune the PHOAS expression, lift shared subterms,
-- lower to untyped De Bruijn form and finally recover typed indices.
sharingRecovery :: (Traversable1 f, TestEquality typ) => (forall v. PHOASExpr typ v f t) -> BExpr typ '[] f t
sharingRecovery e =
  -- trace ("PExpr: " ++ prettyPExpr 0 pexpr "") $
  -- trace ("LExpr: " ++ prettyLExpr 0 lexpr "") $
  retypeExpr ubexpr
  where
    (occmap, pexpr) = pruneExpr e
    lexpr = liftExpr occmap pexpr
    ubexpr = lowerExpr lexpr