aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-08-01 22:04:52 +0200
committerTom Smeding <tom@tomsmeding.com>2024-08-01 22:04:52 +0200
commit250e3beae7a961fc740f775a563c303b4cc390fe (patch)
treeccbb8a090cdb082d86c0651935eeb986e2cddcea /src
Initial
Diffstat (limited to 'src')
-rw-r--r--src/Data/Expr/SharingRecovery.hs157
-rw-r--r--src/Data/StableName/Extra.hs17
2 files changed, 174 insertions, 0 deletions
diff --git a/src/Data/Expr/SharingRecovery.hs b/src/Data/Expr/SharingRecovery.hs
new file mode 100644
index 0000000..118df1c
--- /dev/null
+++ b/src/Data/Expr/SharingRecovery.hs
@@ -0,0 +1,157 @@
+{-# LANGUAGE DeriveFunctor #-}
+{-# LANGUAGE DerivingVia #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE TypeOperators #-}
+module Data.Expr.SharingRecovery where
+
+import Control.Applicative ((<|>))
+import Control.Monad.Trans.State.Strict
+import Data.Bifunctor (second)
+import Data.GADT.Compare
+import Data.Hashable
+import Data.HashMap.Strict (HashMap)
+import qualified Data.HashMap.Strict as HM
+import Data.Some
+import Data.Type.Equality
+import GHC.StableName
+import Numeric.Natural
+import Unsafe.Coerce (unsafeCoerce)
+
+import Data.StableName.Extra
+
+
+class Functor1 f where
+ fmap1 :: (forall b. g b -> h b) -> f g a -> f h a
+
+class Functor1 f => Traversable1 f where
+ traverse1 :: Applicative m => (forall b. g b -> m (h b)) -> f g a -> m (f h a)
+
+-- | Expression in parametric higher-order abstract syntax form
+data PHOASExpr v f t where
+ PHOASOp :: f (PHOASExpr v f) t -> PHOASExpr v f t
+ PHOASLam :: (PHOASExpr v f a -> PHOASExpr v f b) -> PHOASExpr v f (a -> b)
+ PHOASVar :: v t -> PHOASExpr v f t
+
+newtype Tag t = Tag Natural
+ deriving (Show, Eq)
+
+newtype NameFor f t = NameFor (StableName (PHOASExpr Tag f t))
+ deriving (Eq)
+ deriving (Hashable) via (StableName (f (PHOASExpr Tag f) t))
+
+instance GEq (NameFor f) where
+ geq (NameFor n1) (NameFor n2)
+ | eqStableName n1 n2 = Just unsafeCoerceRefl
+ | otherwise = Nothing
+ where
+ unsafeCoerceRefl :: a :~: b -- restricted version of unsafeCoerce that only allows punting proofs
+ unsafeCoerceRefl = unsafeCoerce Refl
+
+-- | Pruned expression
+data PExpr f t where
+ PStub :: NameFor f t -> PExpr f t
+ POp :: NameFor f t -> f (PExpr f) t -> PExpr f t
+ PLam :: NameFor f (a -> b) -> Tag a -> PExpr f b -> PExpr f (a -> b)
+ PVar :: Tag a -> PExpr f a
+
+data SomeNameFor f = forall t. SomeNameFor {-# UNPACK #-} !(NameFor f t)
+
+instance Eq (SomeNameFor f) where
+ SomeNameFor (NameFor n1) == SomeNameFor (NameFor n2) = eqStableName n1 n2
+
+instance Hashable (SomeNameFor f) where
+ hashWithSalt salt (SomeNameFor name) = hashWithSalt salt name
+
+type OccMap f = HashMap (SomeNameFor f) Natural
+
+pruneExpr :: Traversable1 f => (forall v. PHOASExpr v f t) -> (OccMap f, PExpr f t)
+pruneExpr term =
+ let (term', (_, mp)) = runState (pruneExpr' term) (0, mempty)
+ in (mp, term')
+
+pruneExpr' :: Traversable1 f => PHOASExpr Tag f t -> State (Natural, OccMap f) (PExpr f t)
+pruneExpr' orig@(PHOASOp args) = do
+ let name = makeStableName' orig
+ occmap <- gets snd
+ let (seenBefore, occmap') =
+ HM.alterF (\case Nothing -> (False, Just 1)
+ Just n -> (True, Just (n + 1)))
+ (SomeNameFor (NameFor name))
+ occmap
+ modify (second (const occmap'))
+ if seenBefore
+ then pure $ PStub (NameFor name)
+ else POp (NameFor name) <$> traverse1 pruneExpr' args
+
+pruneExpr' orig@(PHOASLam f) = do
+ let name = makeStableName' orig
+ tag <- state (\(i, mp) -> (Tag i, (i + 1, mp)))
+ let body = f (PHOASVar tag)
+ PLam (NameFor name) tag <$> pruneExpr' body
+
+pruneExpr' (PHOASVar tag) = pure $ PVar tag
+
+
+-- | Lifted expression: a bunch of to-be let bound expressions on top of an LExpr'
+data LExpr f t = LExpr [Some (LExpr f)] (LExpr' f t)
+data LExpr' f t where -- TODO: this could be an instantiation of (a generalisation of) PExpr
+ LStub :: NameFor f t -> LExpr' f t
+ LOp :: NameFor f t -> f (LExpr f) t -> LExpr' f t
+ LLam :: NameFor f (a -> b) -> Tag a -> LExpr f b -> LExpr' f (a -> b)
+ LVar :: Tag a -> LExpr' f a
+
+liftExpr :: Traversable1 f => OccMap f -> PExpr f t -> LExpr f t
+liftExpr totals term =
+ let (_, e) = liftExpr' totals term
+ in e
+
+newtype FoundMap f = FoundMap
+ (HashMap (SomeNameFor f) (Natural -- how many times seen
+ ,Maybe (Some (LExpr f)))) -- the lifted subterm (once seen)
+
+instance Semigroup (FoundMap f) where
+ FoundMap m1 <> FoundMap m2 = FoundMap $
+ HM.unionWith (\(n1, me1) (n2, me2) -> (n1 + n2, me1 <|> me2)) m1 m2
+
+instance Monoid (FoundMap f) where
+ mempty = FoundMap HM.empty
+
+liftExpr' :: Traversable1 f => OccMap f -> PExpr f t -> (FoundMap f, LExpr f t)
+liftExpr' _totals (PStub name) =
+ (FoundMap $ HM.singleton (SomeNameFor name) (1, Just (Some (LExpr [] (LStub name))))
+ ,LExpr [] (LStub name))
+
+liftExpr' _totals (PVar tag) = (mempty, LExpr [] (LVar tag))
+
+liftExpr' totals term =
+ let (FoundMap foundmap, name, term') = case term of
+ POp n args -> let (fm, args') = traverse1 (liftExpr' totals) args
+ in (fm, n, LOp n args')
+ PLam n tag body -> let (fm, body') = liftExpr' totals body
+ in (fm, n, LLam n tag body')
+
+ saturated = [case mterm of
+ Just t -> (nm, t)
+ Nothing -> error "Name saturated but no term found"
+ | (nm, (count, mterm)) <- HM.toList foundmap
+ , count == HM.findWithDefault 0 nm totals]
+
+ foundmap' = foldr HM.delete foundmap (map fst saturated)
+
+ lterm = LExpr (map snd saturated) term'
+
+ in case HM.findWithDefault 0 (SomeNameFor name) totals of
+ 1 -> (FoundMap foundmap', lterm)
+ tot | tot > 1 -> (FoundMap (HM.insert (SomeNameFor name) (1, Just (Some lterm)) foundmap')
+ ,LExpr [] (LStub name))
+ | otherwise -> error "Term does not exist, yet we have it in hand"
+
+
+-- TODO: lower LExpr into a normal expression with let bindings. Every LStub
+-- should correspond to some let-bound expression higher up in the tree (if it
+-- does not, that's a bug), and should become a De Bruijn variable reference to
+-- said let-bound expression. Lambdas should also get proper De Bruijn indices
+-- instead of tags, and LVar is also a normal variable (referring to a
+-- lambda-abstracted argument).
diff --git a/src/Data/StableName/Extra.hs b/src/Data/StableName/Extra.hs
new file mode 100644
index 0000000..f568740
--- /dev/null
+++ b/src/Data/StableName/Extra.hs
@@ -0,0 +1,17 @@
+{-# LANGUAGE BangPatterns #-}
+{-# OPTIONS_GHC -fno-full-laziness -fno-cse #-}
+module Data.StableName.Extra (
+ StableName,
+ makeStableName',
+) where
+
+import GHC.StableName
+import System.IO.Unsafe
+
+
+-- | This function evaluates its argument to WHNF and returns a stable name for
+-- the evaluation result. This function is not referentially transparent and is
+-- implemented using 'unsafePerformIO'.
+{-# NOINLINE makeStableName' #-}
+makeStableName' :: a -> StableName a
+makeStableName' !x = unsafePerformIO (makeStableName x)