1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
|
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeOperators #-}
module Data.Expr.SharingRecovery.Internal where
import Control.Applicative ((<|>))
import Control.Monad.Trans.State.Strict
import Data.Bifunctor (first, second)
import Data.Char (chr, ord)
import Data.Functor.Const
import Data.Functor.Identity
import Data.Functor.Product
import Data.Hashable
import Data.HashMap.Strict (HashMap)
import qualified Data.HashMap.Strict as HM
import Data.List (sortBy, intersperse)
import Data.Maybe (fromMaybe)
import Data.Ord (comparing)
import Data.Some
import Data.Type.Equality
import GHC.StableName
import Numeric.Natural
import Unsafe.Coerce (unsafeCoerce)
-- import Debug.Trace
import Data.StableName.Extra
-- TODO: This implementation needs extensive documentation. 1. It is written
-- quite generically, meaning that the actual algorithm is easily obscured to
-- all but the most willing readers; 2. the original paper leaves something to
-- be desired in the domain of effective explanation for programmers, and this
-- is a good opportunity to try to do better.
-- | Run a computation that needs an extra state component of type @b@ on top
-- of the ambient state @s@. The extra component starts out as the given
-- initial value; its final value is returned next to the computation's result.
withMoreState :: Functor m => b -> StateT (s, b) m a -> StateT s m (a, b)
withMoreState extra0 action =
  StateT $ \outer -> regroup <$> runStateT action (outer, extra0)
  where
    -- Move the extra state component from the state side to the result side.
    regroup (result, (outer', extra)) = ((result, extra), outer')
-- | Run a computation that works on a smaller state @s'@ inside a computation
-- with a larger state @s@. The first argument splits the large state into the
-- small state plus a remainder; the second argument reassembles the large
-- state from the updated small state and that (unchanged) remainder.
withLessState :: Functor m => (s -> (s', b)) -> (s' -> b -> s) -> StateT s' m a -> StateT s m a
withLessState split restore action = StateT run
  where
    run outer =
      let (inner, kept) = split outer
          -- Lazy pattern: do not force the result pair any earlier than needed.
          rebuild ~(result, inner') = (result, restore inner' kept)
      in rebuild <$> runStateT action inner
-- | 'Functor' on the second-to-last type parameter.
--
-- The function that is mapped must work uniformly for every index @b@. If @f@
-- is also 'Traversable1', 'fmap1' can be left undefined: the default
-- implementation runs 'traverse1' in the 'Identity' applicative.
class Functor1 f where
  fmap1 :: (forall b. g b -> h b) -> f g a -> f h a
  default fmap1 :: Traversable1 f => (forall b. g b -> h b) -> f g a -> f h a
  fmap1 nat = runIdentity . traverse1 (Identity . nat)
-- | 'Traversable' on the second-to-last type parameter.
--
-- 'traverse1' applies an effectful function, which must work uniformly for
-- every index @b@, to each @g@-position of the structure and rebuilds the
-- structure with the results inside the applicative @m@.
class Functor1 f => Traversable1 f where
  traverse1 :: Applicative m => (forall b. g b -> m (h b)) -> f g a -> m (f h a)
-- | Expression in parametric higher-order abstract syntax form.
--
-- * @typ@ should be a singleton GADT that describes the @t@ type parameter. It
--   should implement 'TestEquality'. For example, for a simple language with
--   only @Int@, pairs and functions as types, the following would suffice:
--
--     @
--     data Typ t where
--       TInt :: Typ Int
--       TPair :: Typ a -> Typ b -> Typ (a, b)
--       TFun :: Typ a -> Typ b -> Typ (a -> b)
--     @
--
-- * @v@ is the type of variables in the expression. A PHOAS expression is
--   required to be parametric in the @v@ parameter; the only place you will
--   obtain a @v@ is inside a @PHOASLam@ function body.
--
-- * @f@ should be your type of operations for your language. It is indexed by
--   the type of subexpressions and the result type of the operation; thus, it
--   is a "base functor" indexed by one additional parameter (@t@). For
--   example, for a simple language that supports only integer constants,
--   integer addition, lambda abstraction and function application:
--
--     @
--     data Oper r t where
--       OConst :: Int -> Oper r Int
--       OAdd :: r Int -> r Int -> Oper r Int
--       OApp :: r (a -> b) -> r a -> Oper r b
--     @
--
--     Note that lambda abstraction is not an operation, because 'PHOASExpr'
--     already represents lambda abstraction as 'PHOASLam'. The reason lambdas
--     are part of 'PHOASExpr' is that 'sharingRecovery' must be able to inspect
--     lambdas and analyse their bodies.
--
--     Note furthermore that @Oper@ is /not/ a recursive type. Subexpressions
--     are again 'PHOASExpr's, and 'sharingRecovery' needs to be able to see
--     them. Hence, you should call back to @r@ instead of recursing
--     manually.
--
-- * @t@ is the result type of this expression.
data PHOASExpr typ v f t where
  -- | An operation of the object language (described by @f@) whose
  -- subexpressions are again 'PHOASExpr's, together with the type of its
  -- result.
  PHOASOp :: typ t -> f (PHOASExpr typ v f) t -> PHOASExpr typ v f t
  -- | Lambda abstraction: the type of the function, the type of its argument,
  -- and the body as a Haskell function of the bound variable.
  PHOASLam :: typ (a -> b) -> typ a -> (v a -> PHOASExpr typ v f b) -> PHOASExpr typ v f (a -> b)
  -- | Reference to a variable, together with its type.
  PHOASVar :: typ t -> v t -> PHOASExpr typ v f t
-- | Concrete variable type at which a 'PHOASExpr' is instantiated while it is
-- being analysed: a plain number, with a phantom index @t@ recording the type
-- of the variable. 'pruneExpr'' hands out a fresh number for every lambda it
-- enters.
newtype Tag t = Tag Natural
  deriving (Show, Eq)
  deriving (Hashable) via Natural
-- | Name of an expression node with result type @t@: the 'StableName' of the
-- 'PHOASExpr' node itself (instantiated at 'Tag' variables). Equality and
-- hashing are those of the underlying 'StableName'.
newtype NameFor typ f t = NameFor (StableName (PHOASExpr typ Tag f t))
  deriving (Eq)
  deriving (Hashable) via (StableName (PHOASExpr typ Tag f t))
-- | Two names that are equal as stable names are taken to name the very same
-- node, and hence to have the same type index.
instance TestEquality (NameFor typ f) where
  testEquality (NameFor n1) (NameFor n2) =
    if eqStableName n1 n2 then Just sameIndex else Nothing
    where
      -- Restricted version of unsafeCoerce that only allows punting proofs.
      -- The type checker cannot see that equal stable names imply equal type
      -- indices, so the proof is asserted rather than derived.
      sameIndex :: a :~: b
      sameIndex = unsafeCoerce Refl
-- | Pruned expression.
--
-- Note that variables do not, and will never, have a name: we don't bother
-- detecting sharing for variable references, because that would only introduce
-- a redundant variable indirection.
--
-- This is defined as a base functor; @r@ is the recursive position.
data PExpr r typ f t where
  -- | Reference to a named node whose subtree is not repeated here; only the
  -- name and the type of the node are kept.
  PStub :: NameFor typ f t -> typ t -> PExpr r typ f t
  -- | Named operation node, with its type and its (recursive) arguments.
  POp :: NameFor typ f t -> typ t -> f (r typ f) t -> PExpr r typ f t
  -- | Named lambda node: function type, argument type, the tag of the bound
  -- variable, and the (recursive) body.
  PLam :: NameFor typ f (a -> b) -> typ (a -> b) -> typ a -> Tag a -> r typ f b -> PExpr r typ f (a -> b)
  -- | Reference to a lambda-bound variable, identified by its tag.
  PVar :: typ a -> Tag a -> PExpr r typ f a
-- | Fixpoint of 'PExpr': a pruned expression whose recursive positions are
-- again pruned expressions, without any further annotation. This is what
-- 'pruneExpr' produces and 'floatExpr' consumes.
newtype PExpr0 typ f t = PExpr0 (PExpr PExpr0 typ f t)
-- | A 'NameFor' with its type index hidden, so that names of nodes of
-- different types can be used as keys of one and the same map.
data SomeNameFor typ f = forall t. SomeNameFor {-# UNPACK #-} !(NameFor typ f t)

-- | Names are compared by their underlying stable names, regardless of the
-- hidden type index.
instance Eq (SomeNameFor typ f) where
  lhs == rhs = case (lhs, rhs) of
    (SomeNameFor (NameFor a), SomeNameFor (NameFor b)) -> eqStableName a b

-- | Hashes exactly like the wrapped 'NameFor'.
instance Hashable (SomeNameFor typ f) where
  hashWithSalt salt (SomeNameFor name) = salt `hashWithSalt` name
-- | Pretty-print a pruned expression at the given precedence level, for
-- debugging. Ties the recursive knot of 'prettyPExpr' for 'PExpr0'.
prettyPExpr0 :: Traversable1 f => Int -> PExpr0 typ f t -> ShowS
prettyPExpr0 prec (PExpr0 inner) = prettyPExpr prettyPExpr0 prec inner
-- | Pretty-print one layer of a 'PExpr' at the given precedence level, using
-- the supplied printer for the recursive positions. Named nodes are shown
-- with their stable name; variables are shown as @x@ followed by their tag.
prettyPExpr :: Traversable1 f => (forall a. Int -> r typ f a -> ShowS) -> Int -> PExpr r typ f t -> ShowS
prettyPExpr recur prec expr = case expr of
  PVar _ (Tag tag) -> showString ("x" ++ show tag)
  PStub (NameFor name) _ -> showString (showStableName name)
  PLam (NameFor name) _ _ (Tag tag) body ->
    let header = "λ" ++ showStableName name ++ " x" ++ show tag ++ ". "
    in showParen (prec > 0) (showString header . recur 0 body)
  POp (NameFor name) _ args ->
    let -- Collect the arguments in traversal order, using the pair (writer)
        -- applicative; the rebuilt structure is irrelevant and discarded.
        children = fst (traverse1 (\arg -> ([Some arg], Const ())) args)
        shownChildren = map (\(Some arg) -> recur 0 arg) children
        opening = showString ("<" ++ showStableName name ++ ">(")
        separated = foldr (.) id (intersperse (showString ", ") shownChildren)
    in showParen (prec > 10) (opening . separated . showString ")")
-- | For each name, a pair of:
--
-- 1. The number of times the name is visited in a preorder traversal of the
--    PHOAS expression, excluding children of nodes upon second or later visit.
--    That is to say: only the nodes that are visited in a preorder traversal
--    that skips repeated subtrees, are counted.
-- 2. The height of the expression indicated by the name.
--
-- Missing names have not been seen yet, and have unknown height.
type OccMap typ f = HashMap (SomeNameFor typ f) (Natural, Natural)
-- | Prune a PHOAS expression: instantiate its variables with 'Tag's and
-- replace every repeated occurrence of a node by a stub. Returns the
-- occurrence map that was built up along the way together with the pruned
-- expression.
pruneExpr :: Traversable1 f => (forall v. PHOASExpr typ v f t) -> (OccMap typ f, PExpr0 typ f t)
pruneExpr term = (occurrences, pruned)
  where
    -- Start with tag counter 0 and an empty occurrence map; the height of the
    -- whole term and the final tag counter are not needed.
    ((pruned, _height), (_nextTag, occurrences)) = runState (pruneExpr' term) (0, mempty)
-- | Returns pruned expression with its height.
-- State: (ID generator, occurrence map being accumulated)
--
-- Operation and lambda nodes are identified by the stable name of the PHOAS
-- node itself. The first visit of such a node descends into it and records it
-- with count 1 and its height; every later visit only increments the count
-- and yields a 'PStub'. Variables carry no name and have height 1.
pruneExpr' :: Traversable1 f => PHOASExpr typ Tag f t -> State (Natural, OccMap typ f) (PExpr0 typ f t, Natural)
pruneExpr' = \case
  node@(PHOASOp ty args) ->
    let name = NameFor (makeStableName' node)
    in visitNamed name ty $ do
         -- Traverse the arguments, collecting the maximum height in an
         -- additional piece of state.
         (args', tallest) <-
           withMoreState 0 $
             traverse1
               (\arg -> do
                   -- drop the extra state for the recursive call
                   (arg', h) <- withLessState id (,) (pruneExpr' arg)
                   modify (second (max h))  -- modify the extra state
                   pure arg')
               args
         pure (POp name ty args', 1 + tallest)

  node@(PHOASLam tyf tyarg mkBody) ->
    let name = NameFor (makeStableName' node)
    in visitNamed name tyf $ do
         -- Take a fresh tag for the bound variable and instantiate the body.
         tag <- Tag <$> gets fst
         modify (first (+1))
         (body', bodyHeight) <- pruneExpr' (mkBody tag)
         pure (PLam name tyf tyarg tag body', 1 + bodyHeight)

  PHOASVar ty tag -> pure (PExpr0 (PVar ty tag), 1)
  where
    -- Shared bookkeeping for named nodes. If the name is already in the
    -- occurrence map, bump its count and return a stub with the recorded
    -- height. Otherwise run the given first-visit action, which yields the
    -- pruned node and its height, and record the node afterwards.
    visitNamed
      :: NameFor typ' f' t'
      -> typ' t'
      -> State (Natural, OccMap typ' f') (PExpr PExpr0 typ' f' t', Natural)
      -> State (Natural, OccMap typ' f') (PExpr0 typ' f' t', Natural)
    visitNamed name ty firstVisit = do
      let key = SomeNameFor name
      seen <- gets (fmap snd . HM.lookup key . snd)
      case seen of
        -- already visited
        Just height -> do
          modify (second (HM.adjust (first (+1)) key))
          pure (PExpr0 (PStub name ty), height)
        -- first visit
        Nothing -> do
          (pruned, height) <- firstVisit
          -- Record this node
          modify (second (HM.insert key (1, height)))
          pure (PExpr0 pruned, height)
-- | Floated expression: again a 'PExpr' (it's a fixpoint over the same base
-- functor), but now with a bunch of to-be let bound expressions on top of
-- every node.
--
-- The list holds the terms to be let-bound around the node, in binding order:
-- 'lowerExpr'' turns the first element into the outermost let-binding.
data LExpr typ f t = LExpr [Some (LExpr typ f)] (PExpr LExpr typ f t)
-- | Pretty-print a floated expression at the given precedence level, for
-- debugging. If the node has floated terms, they are shown first as a
-- comma-separated list in square brackets, followed by the node itself.
prettyLExpr :: Traversable1 f => Int -> LExpr typ f t -> ShowS
prettyLExpr prec (LExpr floated body) = case floated of
  [] -> shownBody
  _ ->
    showString "["
      . foldr (.) id (intersperse (showString ", ") (map showFloated floated))
      . showString "] "
      . shownBody
  where
    shownBody = prettyPExpr prettyLExpr prec body
    showFloated (Some e) = prettyLExpr 0 e
-- | Float the shared subterms of a pruned expression, given the occurrence
-- map computed by 'pruneExpr'. This is 'floatExpr'' with its 'FoundMap'
-- result discarded.
floatExpr :: Traversable1 f => OccMap typ f -> PExpr0 typ f t -> LExpr typ f t
floatExpr totals = snd . floatExpr' totals
-- | What has been encountered so far in a subtree while floating, per name.
newtype FoundMap typ f = FoundMap
  (HashMap (SomeNameFor typ f)
           (Natural                                -- how many times seen
           ,Maybe (Some (LExpr typ f), Natural)))  -- the floated subterm with its height (once seen)

-- | Combining two maps adds up the occurrence counts per name and keeps the
-- floated subterm if either side has one (preferring the left one).
instance Semigroup (FoundMap typ f) where
  FoundMap lhs <> FoundMap rhs = FoundMap (HM.unionWith merge lhs rhs)
    where
      merge (seenL, termL) (seenR, termR) = (seenL + seenR, termL <|> termR)

-- | The empty map: nothing has been found yet.
instance Monoid (FoundMap typ f) where
  mempty = FoundMap mempty
-- | Worker for 'floatExpr'. Returns the floated expression together with a
-- 'FoundMap' of the names encountered in it that have not been let-bound yet.
--
-- Stubs report one occurrence of their name; variables report nothing. For an
-- operation or lambda node, the children are processed first; then every name
-- whose occurrence count in this subtree equals its total count in the
-- occurrence map is bound at this node. A node that itself occurs more than
-- once in total is replaced by a stub, and the node is handed upwards in the
-- 'FoundMap' instead.
floatExpr' :: forall typ f t. Traversable1 f => OccMap typ f -> PExpr0 typ f t -> (FoundMap typ f, LExpr typ f t)
floatExpr' totals (PExpr0 term) = case term of
  PStub name ty ->
    -- trace ("Found stub: " ++ (case name of NameFor n -> showStableName n)) $
    (FoundMap (HM.singleton (SomeNameFor name) (1, Nothing)), LExpr [] (PStub name ty))
  PVar ty tag ->
    -- trace ("Found var: " ++ show tag) $
    (mempty, LExpr [] (PVar ty tag))
  POp name ty args ->
    let (found, args') = traverse1 (floatExpr' totals) args
    in placeNode found name ty (POp name ty args')
  PLam name tyf tyarg tag body ->
    let (found, body') = floatExpr' totals body
    in placeNode found name tyf (PLam name tyf tyarg tag body')
  where
    -- Given what was found in the children of a named node and the node
    -- rebuilt around its floated children, bind the saturated names here and
    -- decide whether the node stays in place or is floated further up.
    placeNode :: FoundMap typ f -> NameFor typ f t -> typ t -> PExpr LExpr typ f t -> (FoundMap typ f, LExpr typ f t)
    placeNode (FoundMap foundmap) name termty term' =
      let -- TODO: perhaps this HM.toList together with the foldr HM.delete can be a single traversal of the HashMap
          -- Names of which all occurrences lie within this subtree.
          saturated =
            [ maybe (missingTerm nm count totalcount) ((,) nm) mterm
            | (nm, (count, mterm)) <- HM.toList foundmap
            , let totalcount = maybe 0 fst (HM.lookup nm totals)
            , count == totalcount
            ]
          foundmap' = foldr HM.delete foundmap (map fst saturated)
          -- Bind the saturated terms here, lowest height first.
          floated = map fst (sortBy (comparing snd) (map snd saturated))
          lterm = LExpr floated term'
          key = SomeNameFor name
      in case HM.lookup key totals of
           -- Occurs only once in total: the node stays where it is.
           Just (1, _) -> (FoundMap foundmap', lterm)
           -- Occurs more than once: leave a stub and pass the node upwards.
           Just (tot, height)
             | tot > 1 -> -- trace ("Inserting " ++ (case name of NameFor n -> showStableName n) ++ " into foundmap") $
                 (FoundMap (HM.insert key (1, Just (Some lterm, height)) foundmap')
                 ,LExpr [] (PStub name termty))
           _ -> error "Term does not exist, yet we have it in hand"

    -- All occurrences of a name were seen, yet none of them was its definition.
    missingTerm :: SomeNameFor typ f -> Natural -> Natural -> a
    missingTerm (SomeNameFor (NameFor n)) count totalcount =
      error $ "Name saturated (count=" ++ show count ++ ", totalcount=" ++ show totalcount ++ ") but no term found: " ++ showStableName n
-- | Untyped De Bruijn expression. No more names: there are lets now, and
-- variable references are De Bruijn indices. These indices are not type-safe
-- yet, though.
data UBExpr typ f t where
  -- | Operation with its result type and its arguments.
  UBOp :: typ t -> f (UBExpr typ f) t -> UBExpr typ f t
  -- | Lambda abstraction: function type, argument type, and body. The body
  -- has one more variable in scope.
  UBLam :: typ (a -> b) -> typ a -> UBExpr typ f b -> UBExpr typ f (a -> b)
  -- | Let-binding: type of the bound expression, the bound expression, and
  -- the body. The body has one more variable in scope.
  UBLet :: typ a -> UBExpr typ f a -> UBExpr typ f b -> UBExpr typ f b
  -- | De Bruijn index
  UBVar :: typ t -> Int -> UBExpr typ f t
-- | Convert a floated expression to an untyped De Bruijn expression: the
-- floated terms become let-bindings, and stubs and tagged variables become De
-- Bruijn indices. Starts with nothing in scope.
lowerExpr :: Functor1 f => LExpr typ f t -> UBExpr typ f t
lowerExpr expr = lowerExpr' HM.empty HM.empty 0 expr
-- | A 'Tag' with its type index hidden, so that tags of variables of
-- different types can be used as keys of one and the same map.
data SomeTag = forall t. SomeTag (Tag t)

-- | Tags are compared by their number only.
instance Eq SomeTag where
  lhs == rhs = case (lhs, rhs) of
    (SomeTag (Tag n), SomeTag (Tag m)) -> n == m

-- | Hashes exactly like the wrapped 'Tag'.
instance Hashable SomeTag where
  hashWithSalt salt (SomeTag tag) = salt `hashWithSalt` tag
-- | Worker for 'lowerExpr'.
--
-- Bindings are tracked by De Bruijn /level/ (counted from the outermost
-- binding, starting at 0). A reference is emitted as a De Bruijn /index/
-- (counted from the innermost binding), which is obtained from the level as
-- @(number of variables in scope at the reference) - level - 1@.
--
-- The floated terms of the node are turned into a chain of let-bindings
-- around the node, the first floated term being the outermost binding. Each
-- right-hand side can refer to the floated terms bound before it.
--
-- Calls 'error' if a stub or variable refers to a binding that is not in
-- scope, or if a floated term is a bare stub or variable.
lowerExpr' :: forall typ f t. Functor1 f
           => HashMap (SomeNameFor typ f) Int  -- ^ node |-> De Bruijn level of defining binding
           -> HashMap SomeTag Int              -- ^ tag |-> De Bruijn level of defining binding
           -> Int                              -- ^ Number of variables already in scope
           -> LExpr typ f t -> UBExpr typ f t
lowerExpr' namelvl taglvl curlvl (LExpr floated ex) =
  let (namelvl', prefix) = buildPrefix namelvl curlvl floated
      -- Every floated term becomes one let-binding around @ex@, so this is the
      -- number of variables in scope at @ex@ itself.
      curlvl' = curlvl + length floated
      -- Level -> index at the position of @ex@. This must be relative to
      -- @curlvl'@, not @curlvl@: the reference sits inside the let-prefix.
      -- (The two coincide when there are no floated terms.)
      levelToIndex lvl = curlvl' - lvl - 1
  in prefix $
       case ex of
         PStub name ty ->
           case HM.lookup (SomeNameFor name) namelvl' of
             Just lvl -> UBVar ty (levelToIndex lvl)
             Nothing -> error "Name variable out of scope"
         POp _ ty args ->
           UBOp ty (fmap1 (lowerExpr' namelvl' taglvl curlvl') args)
         PLam _ tyf tyarg tag body ->
           -- The lambda-bound variable gets the next free level.
           UBLam tyf tyarg (lowerExpr' namelvl' (HM.insert (SomeTag tag) curlvl' taglvl) (curlvl' + 1) body)
         PVar ty tag ->
           case HM.lookup (SomeTag tag) taglvl of
             Just lvl -> UBVar ty (levelToIndex lvl)
             Nothing -> error "Tag variable out of scope"
  where
    -- Build the chain of let-bindings for the floated terms. Takes the name
    -- map so far, the level that the next binding will get, and the remaining
    -- floated terms; returns the extended name map and a function that wraps
    -- a body in the let-bindings.
    buildPrefix :: forall b.
                   HashMap (SomeNameFor typ f) Int
                -> Int
                -> [Some (LExpr typ f)]
                -> (HashMap (SomeNameFor typ f) Int, UBExpr typ f b -> UBExpr typ f b)
    buildPrefix namelvl' _ [] = (namelvl', id)
    buildPrefix namelvl' lvl (Some rhs@(LExpr _ rhs') : rhss) =
      let name = case rhs' of
                   PStub n _ -> n
                   POp n _ _ -> n
                   PLam n _ _ _ _ -> n
                   PVar _ _ -> error "Recovering sharing of a tag is useless"
          ty = case rhs' of
                 PStub{} -> error "Recovering sharing of a stub is useless"
                 POp _ t _ -> t
                 PLam _ t _ _ _ -> t
                 PVar t _ -> t
          -- The right-hand side is lowered with @lvl@ variables in scope: it
          -- sees the earlier floated terms, but not itself or later ones.
          prefix = UBLet ty (lowerExpr' namelvl' taglvl lvl rhs)
      in second (prefix .) $ buildPrefix (HM.insert (SomeNameFor name) lvl namelvl') (lvl + 1) rhss
-- | A typed De Bruijn index: a position in the type-level list @env@ at which
-- the type @t@ is found.
data Idx env t where
  -- | The head of the environment.
  IZ :: Idx (t : env) t
  -- | Skip the head of the environment and index into its tail.
  IS :: Idx env t -> Idx (s : env) t
deriving instance Show (Idx env t)
-- | An environment holding one value of type @f t@ for every type @t@ in the
-- type-level list @env@. The most recently pushed value corresponds to the
-- head of the list.
data Env env f where
  -- | The empty environment.
  ETop :: Env '[] f
  -- | Extend an environment with one more value.
  EPush :: Env env f -> f t -> Env (t : env) f
-- | Look up the value at a typed index in an environment. This cannot fail:
-- the type of the index guarantees that the environment is long enough.
envLookup :: Idx env t -> Env env f -> f t
envLookup idx env = case (idx, env) of
  (IZ, EPush _ x) -> x
  (IS idx', EPush rest _) -> envLookup idx' rest
-- | Untyped lookup in an 'Env'.
--
-- Index 0 is the most recently pushed entry. On success, returns the value
-- found together with the typed index that corresponds to the given untyped
-- one. Returns 'Nothing' if the environment runs out before the index is
-- reached.
envLookupU :: Int -> Env env f -> Maybe (Some (Product f (Idx env)))
envLookupU = walk id
  where
    -- The first argument converts an index into the remaining environment to
    -- an index into the full, original environment.
    walk :: (forall a. Idx env a -> Idx env' a) -> Int -> Env env f -> Maybe (Some (Product f (Idx env')))
    walk !reindex !n = \case
      ETop -> Nothing
      EPush rest x
        | n == 0 -> Just (Some (Pair x (reindex IZ)))
        | otherwise -> walk (reindex . IS) (n - 1) rest
-- | Typed De Bruijn expression. This is the result of sharing recovery. It is
-- not higher-order any more, and furthermore has explicit let-bindings ('BLet')
-- that denote the sharing inside the term. This is a normal AST.
--
-- * @env@ is a type-level list containing the types of all variables in scope
--   in the expression. The bottom-most variable (i.e. the one defined most
--   recently) is at the head of the list. 'Idx' is a De Bruijn index that
--   indexes into this list, to ensure that the whole expression is well-typed
--   and well-scoped.
--
-- * @typ@, @f@ and @t@ are exactly as in 'PHOASExpr'.
data BExpr typ env f t where
  -- | Operation with its result type and its arguments.
  BOp :: typ t -> f (BExpr typ env f) t -> BExpr typ env f t
  -- | Lambda abstraction: function type, argument type, and body. The body
  -- has the argument as an additional variable in scope.
  BLam :: typ (a -> b) -> typ a -> BExpr typ (a : env) f b -> BExpr typ env f (a -> b)
  -- | Let-binding: type of the bound expression, the bound expression, and
  -- the body. The body has the bound expression as an additional variable in
  -- scope.
  BLet :: typ a -> BExpr typ env f a -> BExpr typ (a : env) f b -> BExpr typ env f b
  -- | Variable reference: its type and its typed De Bruijn index.
  BVar :: typ t -> Idx env t -> BExpr typ env f t
-- A 'Show' instance needs 'Show' for the type representations at every index,
-- and 'Show' for operations whenever their subexpressions can be shown.
deriving instance (forall a. Show (typ a), forall a r. (forall b. Show (r b)) => Show (f r a))
               => Show (BExpr typ env f t)
-- | Pretty-print a closed typed De Bruijn expression.
--
-- The first argument pretty-prints an operation of the language: it receives
-- a printer for subexpressions (taking a precedence level), the precedence
-- level of the context, and the operation itself. It must work in any monad,
-- because the printer threads a supply of variable names behind the scenes.
prettyBExpr :: (forall m env' a. Monad m => (forall b. Int -> BExpr typ env' f b -> m ShowS)
                                         -> Int -> f (BExpr typ env' f) a -> m ShowS)
            -> BExpr typ '[] f t -> String
prettyBExpr prettyOp e = render ""
  where
    -- Start with an empty environment, precedence 0 and name counter 0.
    render = evalState (prettyBExpr' prettyOp ETop 0 e) 0
-- | Worker for 'prettyBExpr'. Takes the operation printer, the names of the
-- variables in scope, the precedence level of the context, and the expression
-- to print. The state is a counter from which fresh variable names are
-- generated.
prettyBExpr' :: (forall m env' a. Monad m => (forall b. Int -> BExpr typ env' f b -> m ShowS)
                                          -> Int -> f (BExpr typ env' f) a -> m ShowS)
             -> Env env (Const String) -> Int -> BExpr typ env f t -> State Int ShowS
prettyBExpr' prettyOp env prec expr = case expr of
  BVar _ idx ->
    pure (showString (getConst (envLookup idx env)))
  BOp _ args ->
    prettyOp (prettyBExpr' prettyOp env) prec args
  BLam _ _ body -> do
    binder <- freshName
    shownBody <- prettyBExpr' prettyOp (EPush env (Const binder)) 0 body
    pure $ showParen (prec > 0) $ showString ("λ" ++ binder ++ ". ") . shownBody
  BLet _ rhs body -> do
    -- The name is generated before the right-hand side is printed.
    binder <- freshName
    shownRhs <- prettyBExpr' prettyOp env 0 rhs
    shownBody <- prettyBExpr' prettyOp (EPush env (Const binder)) 0 body
    pure $ showParen (prec > 0) $
      showString ("let " ++ binder ++ " = ") . shownRhs . showString " in " . shownBody
  where
    -- Take the next number from the counter and turn it into a name.
    freshName :: State Int String
    freshName = do
      i <- get
      put (i + 1)
      pure (nameFor i)

    -- The first 26 names are the single letters a..z; after that, x26, x27, ...
    nameFor :: Int -> String
    nameFor i
      | i < 26 = [chr (ord 'a' + i)]
      | otherwise = 'x' : show i
-- | Convert a closed untyped De Bruijn expression to a typed one, checking
-- every variable reference against the binding it points to. Starts with an
-- empty environment.
retypeExpr :: (Functor1 f, TestEquality typ) => UBExpr typ f t -> BExpr typ '[] f t
retypeExpr expr = retypeExpr' ETop expr
-- | Worker for 'retypeExpr'. The environment holds the types of the variables
-- in scope. Calls 'error' if a variable index is out of range, or if the type
-- recorded at the reference differs from the type of the binding.
retypeExpr' :: (Functor1 f, TestEquality typ) => Env env typ -> UBExpr typ f t -> BExpr typ env f t
retypeExpr' env = \case
  UBOp ty args -> BOp ty (fmap1 (retypeExpr' env) args)
  UBLam tyf tyarg body -> BLam tyf tyarg (retypeExpr' (EPush env tyarg) body)
  UBLet ty rhs body -> BLet ty (retypeExpr' env rhs) (retypeExpr' (EPush env ty) body)
  UBVar ty idx ->
    case envLookupU idx env of
      Nothing -> error "Untyped De Bruijn index out of range"
      Just (Some (Pair defty tidx))
        -- The typed index is only usable once both types are known to agree.
        | Just Refl <- testEquality ty defty -> BVar ty tidx
        | otherwise -> error "Type mismatch in untyped De Bruijn expression"
-- | By observing internal sharing using 'StableName's (in
-- 'System.IO.Unsafe.unsafePerformIO'), convert an expression in higher-order
-- abstract syntax form to a well-typed well-scoped De Bruijn expression with
-- explicit let-bindings.
--
-- The conversion is a pipeline of four passes: 'pruneExpr', 'floatExpr',
-- 'lowerExpr' and 'retypeExpr'.
sharingRecovery :: (Traversable1 f, TestEquality typ) => (forall v. PHOASExpr typ v f t) -> BExpr typ '[] f t
sharingRecovery e =
  -- trace ("PExpr: " ++ prettyPExpr 0 pexpr "") $
  -- trace ("LExpr: " ++ prettyLExpr 0 lexpr "") $
  retypeExpr ubexpr
  where
    (occmap, pexpr) = pruneExpr e
    lexpr = floatExpr occmap pexpr
    ubexpr = lowerExpr lexpr
|