1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
|
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeOperators #-}
module Data.Expr.SharingRecovery where
import Control.Applicative ((<|>))
import Control.Monad.Trans.State.Strict
import Data.Bifunctor (second)
import Data.GADT.Compare
import Data.Hashable
import Data.HashMap.Strict (HashMap)
import qualified Data.HashMap.Strict as HM
import Data.Some
import Data.Type.Equality
import GHC.StableName
import Numeric.Natural
import Unsafe.Coerce (unsafeCoerce)
import Data.StableName.Extra
-- | Higher-order functors: type constructors @f@ whose first argument is
-- itself an indexed type constructor that can be swapped for another one.
class Functor1 f where
  -- | Turn an @f g a@ into an @f h a@ using an index-preserving
  -- transformation from @g@ to @h@. The outer index @a@ is left unchanged.
  fmap1 :: (forall x. g x -> h x) -> f g a -> f h a
-- | Effectful counterpart of 'Functor1': the transformation may run in an
-- arbitrary 'Applicative'.
class Functor1 f => Traversable1 f where
  -- | Like 'fmap1', but the transformation produces its results inside the
  -- applicative @m@, and so does the rebuilt @f h a@.
  traverse1
    :: Applicative m
    => (forall x. g x -> m (h x))
    -> f g a
    -> m (f h a)
-- | Expression in parametric higher-order abstract syntax (PHOAS) form.
--
-- Type parameters:
--
-- * @v@: the representation of variables (a @v t@ stands for a variable of
--   type @t@);
-- * @f@: the signature of operations, applied to the expression type itself;
-- * @t@: the type index of the expression.
data PHOASExpr v f t where
  -- | An operation from the signature @f@, instantiated at 'PHOASExpr'.
  PHOASOp  :: f (PHOASExpr v f) t -> PHOASExpr v f t
  -- | A lambda abstraction, represented by a host-language function from
  -- expressions to expressions.
  PHOASLam :: (PHOASExpr v f a -> PHOASExpr v f b) -> PHOASExpr v f (a -> b)
  -- | A variable occurrence.
  PHOASVar :: v t -> PHOASExpr v f t
-- | A natural-number tag with a phantom type index @t@. 'pruneExpr'' hands
-- these out, numbered from a counter, for lambda-bound variables.
newtype Tag t = Tag Natural
  deriving (Eq, Show)
-- | A type-indexed name for an expression node, wrapping a 'StableName' of a
-- @'PHOASExpr' 'Tag' f t@ value.
--
-- Both 'Eq' and 'Hashable' are those of the wrapped 'StableName'.
newtype NameFor f t = NameFor (StableName (PHOASExpr Tag f t))
  deriving (Eq)
  -- Derive via exactly the wrapped representation, so that the instance does
  -- not depend on coercing between 'StableName's of two unrelated types.
  deriving (Hashable) via (StableName (PHOASExpr Tag f t))
-- | Type-level equality test on names: when the two underlying stable names
-- are equal, the type indices are declared equal as well.
instance GEq (NameFor f) where
  geq (NameFor lhs) (NameFor rhs) =
    if eqStableName lhs rhs
      then Just punnedRefl
      else Nothing
    where
      -- Restricted version of 'unsafeCoerce' that can only fabricate equality
      -- proofs. NOTE(review): this is sound only if equal stable names always
      -- come from values of the same type -- confirm against how names are
      -- created.
      punnedRefl :: a :~: b
      punnedRefl = unsafeCoerce Refl
-- | Pruned expression: a first-order tree, as produced by 'pruneExpr', in
-- which operation and lambda nodes carry their 'NameFor' and lambdas bind an
-- explicit 'Tag' instead of being host-language functions.
data PExpr f t where
  -- | Placeholder holding only the name of a node; 'pruneExpr'' emits it in
  -- place of a node it has already visited.
  PStub :: NameFor f t -> PExpr f t
  -- | A named operation with pruned arguments.
  POp   :: NameFor f t -> f (PExpr f) t -> PExpr f t
  -- | A named lambda: the tag of the bound variable, and the pruned body.
  PLam  :: NameFor f (a -> b) -> Tag a -> PExpr f b -> PExpr f (a -> b)
  -- | An occurrence of the variable with the given tag.
  PVar  :: Tag a -> PExpr f a
-- | A 'NameFor' whose type index is hidden, so that names of nodes of
-- different types can serve as keys of one and the same map.
data SomeNameFor f
  = forall t. SomeNameFor {-# UNPACK #-} !(NameFor f t)
-- | Two hidden-index names are equal exactly when their underlying stable
-- names are; the type indices play no role in the comparison.
instance Eq (SomeNameFor f) where
  SomeNameFor (NameFor lhs) == SomeNameFor (NameFor rhs) =
    lhs `eqStableName` rhs
-- | Hashing delegates to the 'Hashable' instance of the wrapped 'NameFor'.
instance Hashable (SomeNameFor f) where
  hashWithSalt s (SomeNameFor nm) = s `hashWithSalt` nm
-- | Occurrence counts keyed by node name, as collected by 'pruneExpr' and
-- consulted (as the expected totals) by 'liftExpr'.
type OccMap f = HashMap (SomeNameFor f) Natural
-- | Convert a closed PHOAS term into a 'PExpr', instantiating its variable
-- type at 'Tag'. Returns the occurrence counts gathered during the traversal
-- together with the pruned term.
pruneExpr :: Traversable1 f => (forall v. PHOASExpr v f t) -> (OccMap f, PExpr f t)
pruneExpr term = (occurrences, pruned)
  where
    -- The state threaded through 'pruneExpr'' is (next tag number, counts);
    -- it starts at tag 0 with an empty map, and the final counter is dropped.
    (pruned, (_nextTag, occurrences)) = runState (pruneExpr' term) (0, mempty)
-- | Worker for 'pruneExpr'. The state is a pair of the next unused 'Tag'
-- number and the occurrence counts gathered so far.
--
-- Every operation node and every lambda node is named after its stable name
-- and counted in the 'OccMap'. The first visit of a node descends into it;
-- any later visit of the same node yields a 'PStub' and does not descend
-- again. Variables are passed through unchanged.
pruneExpr' :: Traversable1 f => PHOASExpr Tag f t -> State (Natural, OccMap f) (PExpr f t)
pruneExpr' orig@(PHOASOp args) = do
  let name = NameFor (makeStableName' orig)
  seenBefore <- recordOccurrence name
  if seenBefore
    then pure $ PStub name
    else POp name <$> traverse1 pruneExpr' args
pruneExpr' orig@(PHOASLam f) = do
  let name = NameFor (makeStableName' orig)
  -- Lambdas have to be counted as well: 'liftExpr'' looks up the total of
  -- every 'PLam' name and treats a missing entry as a fatal error.
  seenBefore <- recordOccurrence name
  if seenBefore
    then pure $ PStub name
    else do
      -- Allocate a fresh tag for the bound variable, then open the
      -- host-language function by applying it to that variable.
      tag <- state (\(i, mp) -> (Tag i, (i + 1, mp)))
      PLam name tag <$> pruneExpr' (f (PHOASVar tag))
pruneExpr' (PHOASVar tag) = pure $ PVar tag

-- | Bump the occurrence count of the given name (starting at 1 for a name
-- not yet in the map) and report whether the name had been recorded before.
-- The tag counter in the state is left untouched.
recordOccurrence :: NameFor f t -> State (Natural, OccMap f) Bool
recordOccurrence name = state $ \(next, occmap) ->
  let (seenBefore, occmap') =
        HM.alterF
          (\case
             Nothing -> (False, Just 1)
             Just n -> (True, Just (n + 1)))
          (SomeNameFor name)
          occmap
  in (seenBefore, (next, occmap'))
-- | Lifted expression: a list of to-be-let-bound expressions (each with its
-- type index hidden behind 'Some') on top of an 'LExpr''.
data LExpr f t = LExpr [Some (LExpr f)] (LExpr' f t)
-- | The node underneath the bindings of an 'LExpr'. Mirrors 'PExpr', except
-- that sub-terms are 'LExpr's and may therefore carry bindings of their own.
--
-- TODO: this could be an instantiation of (a generalisation of) PExpr
data LExpr' f t where
  -- | Reference, by name, to a term that is bound elsewhere.
  LStub :: NameFor f t -> LExpr' f t
  -- | A named operation with lifted arguments.
  LOp   :: NameFor f t -> f (LExpr f) t -> LExpr' f t
  -- | A named lambda: the tag of the bound variable, and the lifted body.
  LLam  :: NameFor f (a -> b) -> Tag a -> LExpr f b -> LExpr' f (a -> b)
  -- | An occurrence of the variable with the given tag.
  LVar  :: Tag a -> LExpr' f a
-- | Lift a pruned expression into an 'LExpr', given the total occurrence
-- count of every named node. The bookkeeping map that 'liftExpr'' returns
-- alongside the result is discarded.
liftExpr :: Traversable1 f => OccMap f -> PExpr f t -> LExpr f t
liftExpr totals = snd . liftExpr' totals
-- | Bookkeeping accumulated by 'liftExpr'' while walking a term: for every
-- name encountered, a pair of counters described on the fields below.
newtype FoundMap f = FoundMap
  (HashMap
     (SomeNameFor f)
     ( Natural                   -- how many times seen
     , Maybe (Some (LExpr f)) )) -- the lifted subterm (once seen)
-- | Merging two maps adds up the counts of names present in both and keeps
-- the first available lifted subterm, preferring the left operand's.
instance Semigroup (FoundMap f) where
  FoundMap lhs <> FoundMap rhs = FoundMap (HM.unionWith merge lhs rhs)
    where
      merge (seenL, termL) (seenR, termR) = (seenL + seenR, termL <|> termR)
-- | The neutral element is the map in which no name has been seen.
instance Monoid (FoundMap f) where
  mempty = FoundMap mempty
-- | Worker for 'liftExpr'. Besides the lifted term, returns a 'FoundMap'
-- describing the names encountered in the term that have not been turned into
-- bindings yet.
--
-- * A 'PStub' counts as one sighting of its name and contributes no term.
-- * A 'PVar' contributes nothing.
-- * For a 'POp' or 'PLam', the children are lifted first (their maps being
--   merged with '<>'), after which 'bindSaturated' decides what gets bound at
--   this node and whether the node itself is replaced by a stub.
--
-- The name of every 'POp' and 'PLam' must have a non-zero count in @totals@;
-- otherwise 'bindSaturated' calls 'error'.
liftExpr' :: Traversable1 f => OccMap f -> PExpr f t -> (FoundMap f, LExpr f t)
liftExpr' _totals (PStub name) =
  -- A stub only references a term; it must not pose as the lifted subterm.
  -- Recording 'Nothing' keeps the result independent of the order in which
  -- maps are merged, and lets the "no term found" check in 'bindSaturated'
  -- catch a name that saturates without its definition.
  ( FoundMap $ HM.singleton (SomeNameFor name) (1, Nothing)
  , LExpr [] (LStub name) )
liftExpr' _totals (PVar tag) = (mempty, LExpr [] (LVar tag))
liftExpr' totals (POp name args) =
  let (found, args') = traverse1 (liftExpr' totals) args
  in bindSaturated totals name found (LOp name args')
liftExpr' totals (PLam name tag body) =
  let (found, body') = liftExpr' totals body
  in bindSaturated totals name found (LLam name tag body')

-- | Finish lifting one named node, given the 'FoundMap' of its children and
-- the node itself rebuilt on top of its lifted children.
--
-- Every name whose accumulated count equals its total in the 'OccMap' is
-- removed from the map and its term is attached to this node as a binding.
-- The bindings appear in the iteration order of 'HM.toList'.
-- NOTE(review): that order is not a dependency order, so a consumer that needs
-- one presumably has to sort the bindings itself -- confirm when writing it.
--
-- Then, depending on the total for this node's own name:
--
-- * 1: the node, with its bindings, is returned in place;
-- * more than 1: the node, with its bindings, is stored in the map under its
--   name with a count of 1, and an 'LStub' is returned in its place;
-- * 0 (name not in the map): 'error'.
bindSaturated :: OccMap f -> NameFor f t -> FoundMap f -> LExpr' f t -> (FoundMap f, LExpr f t)
bindSaturated totals name (FoundMap foundmap) term' =
  let saturated =
        [ case mterm of
            Just t -> (nm, t)
            Nothing -> error "Name saturated but no term found"
        | (nm, (count, mterm)) <- HM.toList foundmap
        , count == HM.findWithDefault 0 nm totals ]
      foundmap' = foldr HM.delete foundmap (map fst saturated)
      lterm = LExpr (map snd saturated) term'
  in case HM.findWithDefault 0 (SomeNameFor name) totals of
       1 -> (FoundMap foundmap', lterm)
       tot
         | tot > 1 ->
             ( FoundMap (HM.insert (SomeNameFor name) (1, Just (Some lterm)) foundmap')
             , LExpr [] (LStub name) )
         | otherwise -> error "Term does not exist, yet we have it in hand"
-- TODO: lower LExpr into a normal expression with let bindings. Every LStub
-- should correspond to some let-bound expression higher up in the tree (if it
-- does not, that's a bug), and should become a De Bruijn variable reference to
-- said let-bound expression. Lambdas should also get proper De Bruijn indices
-- instead of tags, and LVar is also a normal variable (referring to a
-- lambda-abstracted argument).
|