aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Expr/SharingRecovery.hs
blob: cdb64eb0ef759ceb31e622f5d6e8ba59dbb715df (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Data.Expr.SharingRecovery where

import Control.Applicative ((<|>))
import Control.Monad.Trans.State.Strict
import Data.Bifunctor (second)
import Data.Char (chr, ord)
import Data.Functor.Const
import Data.Functor.Identity
import Data.Functor.Product
import Data.Hashable
import Data.HashMap.Strict (HashMap)
import qualified Data.HashMap.Strict as HM
import Data.Some
import Data.Type.Equality
import GHC.StableName
import Numeric.Natural
import Unsafe.Coerce (unsafeCoerce)

import Data.StableName.Extra


-- TODO: This implementation needs extensive documentation. 1. It is written
-- quite generically, meaning that the actual algorithm is easily obscured to
-- all but the most willing readers; 2. the original paper leaves something to
-- be desired in the domain of effective explanation for programmers, and this
-- is a good opportunity to try to do better.


-- | Higher-order analogue of 'Functor': @f g a@ is a structure containing
-- values of the form @g b@ (for various indices @b@), and 'fmap1' replaces the
-- @g@ by another indexed type @h@ without touching the indices.
class Functor1 f where
  -- | Apply an index-preserving transformation to every @g b@ in the structure.
  fmap1 :: (forall b. g b -> h b) -> f g a -> f h a

  -- When the type is also 'Traversable1', 'fmap1' comes for free by running
  -- 'traverse1' in the 'Identity' applicative.
  default fmap1 :: Traversable1 f => (forall b. g b -> h b) -> f g a -> f h a
  fmap1 f x = runIdentity (traverse1 (Identity . f) x)

-- | Higher-order analogue of 'Traversable': like 'Functor1', but the
-- transformation may have effects in an 'Applicative' @m@.
class Functor1 f => Traversable1 f where
  -- | Run an effectful, index-preserving transformation on every @g b@ in the
  -- structure and collect the results in @m@.
  traverse1 :: Applicative m => (forall b. g b -> m (h b)) -> f g a -> m (f h a)

-- | Expression in parametric higher-order abstract syntax form
--
-- Type parameters: @typ@ is the representation of object-language types that
-- annotates every node, @v@ is the representation of variables, @f@ is the
-- functor describing one layer of operations, and @t@ is the type of the
-- expression.
data PHOASExpr typ v f t where
  -- | An operation: one layer of @f@ whose argument positions hold further
  -- expressions, together with the type of the result.
  PHOASOp :: typ t -> f (PHOASExpr typ v f) t -> PHOASExpr typ v f t
  -- | A lambda abstraction whose body is a Haskell function of the bound
  -- variable. Carries the type of the whole function and of its argument.
  PHOASLam :: typ (a -> b) -> typ a -> (v a -> PHOASExpr typ v f b) -> PHOASExpr typ v f (a -> b)
  -- | A reference to a variable, together with its type.
  PHOASVar :: typ t -> v t -> PHOASExpr typ v f t

-- | Identifier of a lambda-bound variable, used to instantiate the variable
-- parameter of 'PHOASExpr'. The index @t@ is a phantom: only the number is
-- stored.
newtype Tag t = Tag Natural
  deriving (Show, Eq)
  deriving (Hashable) via Natural

-- | The stable name of a 'PHOASExpr' node (with variables instantiated to
-- 'Tag'), indexed by the type @t@ of that node. Equality and hashing are those
-- of the underlying 'StableName'.
newtype NameFor typ f t = NameFor (StableName (PHOASExpr typ Tag f t))
  deriving (Eq)
  deriving (Hashable) via (StableName (PHOASExpr typ Tag f t))

-- | Two names are reported to have equal type indices exactly when their
-- underlying stable names are equal according to 'eqStableName'.
--
-- NOTE(review): the proof is fabricated with 'unsafeCoerce'. This presumably
-- relies on equal stable names having been made from the same heap object,
-- which can only have one type -- confirm that all 'NameFor' values are
-- created that way.
instance TestEquality (NameFor typ f) where
  testEquality (NameFor n1) (NameFor n2)
    | eqStableName n1 n2 = Just unsafeCoerceRefl
    | otherwise = Nothing
    where
      unsafeCoerceRefl :: a :~: b  -- restricted version of unsafeCoerce that only allows punting proofs
      unsafeCoerceRefl = unsafeCoerce Refl

-- | Pruned expression.
--
-- This is the first-order result of 'pruneExpr': lambdas have been applied to
-- a concrete 'Tag', every non-variable node carries its 'NameFor', and a node
-- that had already been visited is replaced by a 'PStub'.
--
-- Note that variables do not, and will never, have a name: we don't bother
-- detecting sharing for variable references, because that would only introduce
-- a redundant variable indirection.
data PExpr typ f t where
  -- | Placeholder for a node that was already visited earlier in the
  -- traversal; only its name and type are kept.
  PStub :: NameFor typ f t -> typ t -> PExpr typ f t
  -- | First visit of an operation node, with pruned arguments.
  POp :: NameFor typ f t -> typ t -> f (PExpr typ f) t -> PExpr typ f t
  -- | First visit of a lambda node: function type, argument type, the tag
  -- given to the bound variable, and the pruned body.
  PLam :: NameFor typ f (a -> b) -> typ (a -> b) -> typ a -> Tag a -> PExpr typ f b -> PExpr typ f (a -> b)
  -- | Variable reference (nameless, see above).
  PVar :: typ a -> Tag a -> PExpr typ f a

-- | A 'NameFor' with its type index hidden, so that names of nodes of
-- different types can be used as keys of a single map.
data SomeNameFor typ f = forall t. SomeNameFor {-# UNPACK #-} !(NameFor typ f t)

-- | Compares the underlying stable names with 'eqStableName', which (unlike
-- '==') accepts stable names of different types.
instance Eq (SomeNameFor typ f) where
  SomeNameFor (NameFor n1) == SomeNameFor (NameFor n2) = eqStableName n1 n2

-- | Hashes the wrapped 'NameFor', i.e. the underlying stable name.
instance Hashable (SomeNameFor typ f) where
  hashWithSalt salt (SomeNameFor name) = hashWithSalt salt name

-- | The number of times a particular name is visited in a preorder traversal
-- of the PHOAS expression, excluding children of nodes upon second or later
-- visit. That is to say: only the nodes that are visited in a preorder
-- traversal that skips repeated subtrees, are counted.
--
-- This map is produced by 'pruneExpr' and consumed by 'liftExpr'.
type OccMap typ f = HashMap (SomeNameFor typ f) Natural

-- | Convert a closed PHOAS expression to a 'PExpr', cutting off every subtree
-- that was already visited, and simultaneously count how often each node is
-- encountered. Variable tags are numbered starting from 0.
pruneExpr :: Traversable1 f => (forall v. PHOASExpr typ v f t) -> (OccMap typ f, PExpr typ f t)
pruneExpr term = (occurrences, pruned)
  where
    -- The first state component (the next free tag number) is of no use to
    -- the caller.
    (pruned, (_, occurrences)) = runState (pruneExpr' term) (0, mempty)

-- | Worker for 'pruneExpr'. The state is the pair of (next unused tag number,
-- occurrence counts collected so far).
--
-- Every operation and lambda node is identified by its stable name. On the
-- first visit the node is translated recursively; on every later visit it is
-- replaced by a 'PStub' and its children are not traversed again. Variables
-- are passed through unchanged and are not counted.
pruneExpr' :: Traversable1 f => PHOASExpr typ Tag f t -> State (Natural, OccMap typ f) (PExpr typ f t)
pruneExpr' = \case
  orig@(PHOASOp ty args) -> do
    let self = NameFor (makeStableName' orig)
    revisited <- recordVisit self
    if revisited
      then pure (PStub self ty)
      else POp self ty <$> traverse1 pruneExpr' args

  orig@(PHOASLam tyf tyarg mkBody) -> do
    let self = NameFor (makeStableName' orig)
    revisited <- recordVisit self
    if revisited
      then pure (PStub self tyf)
      else do
        -- Take the next tag number for the bound variable, then open up the
        -- higher-order body by applying it to that tag.
        fresh <- state (\(next, counts) -> (Tag next, (next + 1, counts)))
        PLam self tyf tyarg fresh <$> pruneExpr' (mkBody fresh)

  PHOASVar ty tag -> pure (PVar ty tag)
  where
    -- Bump the visit count of the given node and report whether it had been
    -- visited before this call.
    recordVisit self = do
      let key = SomeNameFor self
      earlier <- gets (HM.lookup key . snd)
      modify (second (HM.insert key (maybe 1 (+ 1) earlier)))
      pure (maybe False (const True) earlier)


-- | Lifted expression: a bunch of to-be let bound expressions on top of an
-- LExpr'. Because LExpr' is really just PExpr with the recursive positions
-- replaced by LExpr, LExpr should be seen as PExpr with a bunch of to-be let
-- bound expressions on top of every node.
--
-- 'lowerExpr'' binds the list front to back, so an entry may only refer (via
-- an 'LStub') to entries that come before it in the same list, or to bindings
-- made further out.
data LExpr typ f t = LExpr [Some (LExpr typ f)] (LExpr' typ f t)
-- | One node of a lifted expression; the constructors mirror those of 'PExpr'.
data LExpr' typ f t where  -- TODO: this could be an instantiation of (a generalisation of) PExpr
  -- | Reference to a shared node that is let-bound elsewhere.
  LStub :: NameFor typ f t -> typ t -> LExpr' typ f t
  -- | Operation node with lifted arguments.
  LOp :: NameFor typ f t -> typ t -> f (LExpr typ f) t -> LExpr' typ f t
  -- | Lambda node: function type, argument type, tag of the bound variable,
  -- and lifted body.
  LLam :: NameFor typ f (a -> b) -> typ (a -> b) -> typ a -> Tag a -> LExpr typ f b -> LExpr' typ f (a -> b)
  -- | Reference to a lambda-bound variable.
  LVar :: typ a -> Tag a -> LExpr' typ f a

-- | Decide, for every shared node of a pruned expression, at which node it
-- will be let-bound. The first argument is the occurrence map computed by
-- 'pruneExpr' for the same expression. Whatever is left in the bookkeeping map
-- of 'liftExpr'' at the root is discarded.
liftExpr :: Traversable1 f => OccMap typ f -> PExpr typ f t -> LExpr typ f t
liftExpr totals = snd . liftExpr' totals

-- | Bookkeeping of 'liftExpr'': for every shared name encountered so far in
-- the subtree under consideration, how many references to it were seen and,
-- if its defining occurrence was among them, the lifted term itself.
newtype FoundMap typ f = FoundMap
  (HashMap (SomeNameFor typ f) (Natural  -- how many times seen
                               ,Maybe (Some (LExpr typ f))))  -- the lifted subterm (once seen)

-- | Merging two maps adds up the reference counts per name and keeps the
-- leftmost term that is present.
instance Semigroup (FoundMap typ f) where
  FoundMap lhs <> FoundMap rhs = FoundMap (HM.unionWith combine lhs rhs)
    where
      combine (seenL, termL) (seenR, termR) = (seenL + seenR, termL <|> termR)

-- | The empty map: nothing found yet.
instance Monoid (FoundMap typ f) where
  mempty = FoundMap HM.empty

-- | Worker for 'liftExpr'. Returns the lifted term together with a 'FoundMap'
-- describing the shared names that were referenced inside the term but have
-- not been let-bound inside it yet.
--
-- * A stub counts as one reference to its name.
-- * A variable contributes nothing.
-- * For an operation or lambda node, the children are lifted first and their
--   maps are merged. Every name whose merged reference count equals its total
--   in the 'OccMap' is /saturated/: all of its references are below this node,
--   so it is let-bound here and removed from the map. If the node itself is
--   shared (total greater than 1), it is replaced by a stub and handed upwards
--   through the map instead.
--
-- The saturated terms are put in dependencies-first order (see
-- 'orderLifted'), because 'lowerExpr'' binds them in list order and the
-- iteration order of a 'HashMap' is unspecified.
--
-- Calls 'error' if a name is saturated without its term having been seen, or
-- if the node at hand is missing from the 'OccMap'; both indicate that the
-- 'OccMap' does not belong to this expression.
liftExpr' :: Traversable1 f => OccMap typ f -> PExpr typ f t -> (FoundMap typ f, LExpr typ f t)
liftExpr' _totals (PStub name ty) =
  (FoundMap $ HM.singleton (SomeNameFor name) (1, Nothing)  -- Just (Some (LExpr [] (LStub name)))
  ,LExpr [] (LStub name ty))

liftExpr' _totals (PVar ty tag) = (mempty, LExpr [] (LVar ty tag))

liftExpr' totals term =
  let (FoundMap foundmap, name, termty, term') = case term of
        POp n ty args ->
          -- The writer-like Applicative of pairs merges the children's maps.
          let (fm, args') = traverse1 (liftExpr' totals) args
          in (fm, n, ty, LOp n ty args')
        PLam n tyf tyarg tag body ->
          let (fm, body') = liftExpr' totals body
          in (fm, n, tyf, LLam n tyf tyarg tag body')

      -- TODO: perhaps this HM.toList together with the foldr HM.delete can be a single traversal of the HashMap
      saturated = [case mterm of
                     Just t -> (nm, t)
                     Nothing -> error "Name saturated but no term found"
                  | (nm, (count, mterm)) <- HM.toList foundmap
                  , count == HM.findWithDefault 0 nm totals]

      foundmap' = foldr HM.delete foundmap (map fst saturated)

      -- A saturated term may contain a stub for another term that saturates at
      -- this very node, so the bindings must be ordered before they are
      -- attached.
      lterm = LExpr (map snd (orderLifted saturated)) term'

  in case HM.findWithDefault 0 (SomeNameFor name) totals of
       1 -> (FoundMap foundmap', lterm)
       tot | tot > 1 -> (FoundMap (HM.insert (SomeNameFor name) (1, Just (Some lterm)) foundmap')
                        ,LExpr [] (LStub name termty))
           | otherwise -> error "Term does not exist, yet we have it in hand"

-- | Reorder named to-be-let-bound terms such that every term comes after all
-- terms from the same list that it refers to through an 'LStub'. Terms that do
-- not depend on each other keep their relative order. Lists of fewer than two
-- elements are returned as they are, without inspecting the terms.
orderLifted :: forall typ f. Traversable1 f
            => [(SomeNameFor typ f, Some (LExpr typ f))]
            -> [(SomeNameFor typ f, Some (LExpr typ f))]
orderLifted items@(_ : _ : _) = reverse (snd (visitAll (HM.empty, []) items))
  where
    candidates :: HashMap (SomeNameFor typ f) (Some (LExpr typ f))
    candidates = HM.fromList items

    -- Depth-first post-order walk over the dependency graph. The accumulator
    -- holds the names already handled and the output in reverse order. A name
    -- is marked before its dependencies are visited, so the walk terminates
    -- even on (unexpected) cyclic references.
    visitAll :: (HashMap (SomeNameFor typ f) (), [(SomeNameFor typ f, Some (LExpr typ f))])
             -> [(SomeNameFor typ f, Some (LExpr typ f))]
             -> (HashMap (SomeNameFor typ f) (), [(SomeNameFor typ f, Some (LExpr typ f))])
    visitAll (seen, out) [] = (seen, out)
    visitAll (seen, out) (item@(nm, Some lterm) : rest)
      | HM.member nm seen = visitAll (seen, out) rest
      | otherwise =
          let wanted = [ (dep, depTerm)
                       | dep <- liftedStubNames lterm
                       , Just depTerm <- [HM.lookup dep candidates] ]
              (seen', out') = visitAll (HM.insert nm () seen, out) wanted
          in visitAll (seen', item : out') rest
orderLifted items = items

-- | The names of all stubs occurring anywhere in a lifted expression,
-- including inside its to-be-let-bound terms. May contain duplicates.
liftedStubNames :: Traversable1 f => LExpr typ f t -> [SomeNameFor typ f]
liftedStubNames (LExpr lifted body) =
  concatMap (\(Some l) -> liftedStubNames l) lifted ++ liftedStubNames' body

-- | The part of 'liftedStubNames' that handles a single node.
liftedStubNames' :: Traversable1 f => LExpr' typ f t -> [SomeNameFor typ f]
liftedStubNames' (LStub n _) = [SomeNameFor n]
-- Traversing in the 'Const' applicative collects the names of all arguments.
liftedStubNames' (LOp _ _ args) = getConst (traverse1 (Const . liftedStubNames) args)
liftedStubNames' (LLam _ _ _ _ inner) = liftedStubNames inner
liftedStubNames' (LVar _ _) = []


-- | Untyped De Bruijn expression. No more names: there are lets now, and
-- variable references are De Bruijn indices. These indices are not type-safe
-- yet, though.
data UBExpr typ f t where
  -- | Operation node.
  UBOp :: typ t -> f (UBExpr typ f) t -> UBExpr typ f t
  -- | Lambda: function type, argument type, and body (which has one more
  -- variable in scope).
  UBLam :: typ (a -> b) -> typ a -> UBExpr typ f b -> UBExpr typ f (a -> b)
  -- | Let binding: type of the bound term, the bound term, and the body (which
  -- has one more variable in scope).
  UBLet :: typ a -> UBExpr typ f a -> UBExpr typ f b -> UBExpr typ f b
  -- | De Bruijn index
  UBVar :: typ t -> Int -> UBExpr typ f t

-- | Turn a lifted expression into one with explicit let bindings and De
-- Bruijn indices, starting with no names, no tags and no variables in scope.
lowerExpr :: Functor1 f => LExpr typ f t -> UBExpr typ f t
lowerExpr = lowerExpr' HM.empty HM.empty 0

-- | A 'Tag' with its type index hidden, usable as a key in a single map for
-- variables of all types.
data SomeTag = forall t. SomeTag (Tag t)

-- | Tags are equal when their numbers are equal, whatever their indices.
instance Eq SomeTag where
  SomeTag (Tag lhs) == SomeTag (Tag rhs) = lhs == rhs

-- | Hashes the wrapped 'Tag'.
instance Hashable SomeTag where
  hashWithSalt salt (SomeTag inner) = hashWithSalt salt inner

-- | Worker for 'lowerExpr'. The to-be-let-bound terms attached to the node
-- become a chain of 'UBLet's around it, bound front to back, and all stub and
-- variable references are replaced by De Bruijn indices.
--
-- Scoping is tracked with De Bruijn /levels/ (counted from the outermost
-- binder); a reference to level @lvl@ with @n@ variables in scope becomes
-- index @n - lvl - 1@.
--
-- Calls 'error' when a stub or variable refers to something that is not in
-- scope, or when a to-be-let-bound term is a bare stub or variable.
lowerExpr' :: forall typ f t. Functor1 f
           => HashMap (SomeNameFor typ f) Int  -- ^ node |-> De Bruijn level of defining binding
           -> HashMap SomeTag Int  -- ^ tag |-> De Bruijn level of defining binding
           -> Int  -- ^ Number of variables already in scope
           -> LExpr typ f t -> UBExpr typ f t
lowerExpr' namelvl taglvl curlvl (LExpr lifted ex) =
  let (namelvl', prefix) = buildPrefix namelvl curlvl lifted
      -- Every lifted term becomes one let binding, so this is the number of
      -- variables in scope underneath the prefix.
      curlvl' = curlvl + length lifted
  in prefix $
       case ex of
         -- The references below are placed under 'prefix', so their indices
         -- must be computed relative to curlvl', not curlvl. (The two only
         -- differ when 'lifted' is non-empty.)
         LStub name ty ->
           case HM.lookup (SomeNameFor name) namelvl' of
             Just lvl -> UBVar ty (curlvl' - lvl - 1)
             Nothing -> error "Name variable out of scope"
         LOp _ ty args ->
           UBOp ty (fmap1 (lowerExpr' namelvl' taglvl curlvl') args)
         LLam _ tyf tyarg tag body ->
           -- The lambda's variable takes the next level, curlvl'.
           UBLam tyf tyarg (lowerExpr' namelvl' (HM.insert (SomeTag tag) curlvl' taglvl) (curlvl' + 1) body)
         LVar ty tag ->
           case HM.lookup (SomeTag tag) taglvl of
             Just lvl -> UBVar ty (curlvl' - lvl - 1)
             Nothing -> error "Tag variable out of scope"
  where
    -- Build the chain of lets for the given terms. Arguments: the name scope
    -- so far, the level that the next binding will get, and the remaining
    -- terms. Returns the name scope extended with all bound names, and a
    -- function that wraps the lets around a body.
    buildPrefix :: forall b.
                   HashMap (SomeNameFor typ f) Int
                -> Int
                -> [Some (LExpr typ f)]
                -> (HashMap (SomeNameFor typ f) Int, UBExpr typ f b -> UBExpr typ f b)
    buildPrefix namelvl' _ [] = (namelvl', id)
    buildPrefix namelvl' lvl (Some rhs@(LExpr _ rhs') : rhss) =
      let name = case rhs' of
                   LStub n _ -> n
                   LOp n _ _ -> n
                   LLam n _ _ _ _ -> n
                   LVar _ _ -> error "Recovering sharing of a tag is useless"
          ty = case rhs' of
                 LStub{} -> error "Recovering sharing of a stub is useless"
                 LOp _ t _ -> t
                 LLam _ t _ _ _ -> t
                 LVar t _ -> t
          -- The right-hand side only sees the names bound before it.
          prefix = UBLet ty (lowerExpr' namelvl' taglvl lvl rhs)
      in (prefix .) <$> buildPrefix (HM.insert (SomeNameFor name) lvl namelvl') (lvl + 1) rhss


-- | A typed De Bruijn index.
--
-- A value of type @Idx env t@ points at an entry of type @t@ in the type-level
-- list @env@.
data Idx env t where
  -- | The head of the environment.
  IZ :: Idx (t : env) t
  -- | Skip the head of the environment and index into its tail.
  IS :: Idx env t -> Idx (s : env) t
deriving instance Show (Idx env t)

-- | An environment holding one value of type @f t@ for every type @t@ in the
-- type-level list @env@.
data Env env f where
  -- | The empty environment.
  ETop :: Env '[] f
  -- | Extend an environment with a new head entry.
  EPush :: Env env f -> f t -> Env (t : env) f

-- | Fetch the entry that a typed index points at. This cannot fail: an index
-- into an empty environment does not exist.
envLookup :: Idx env t -> Env env f -> f t
envLookup idx env = case (idx, env) of
  (IZ, EPush _ entry) -> entry
  (IS deeper, EPush rest _) -> envLookup deeper rest

-- | Untyped lookup in an 'Env'.
--
-- Index 0 is the head of the environment. On success, returns the entry
-- together with the typed index pointing at it; returns 'Nothing' when the
-- index does not fall inside the environment.
envLookupU :: Int -> Env env f -> Maybe (Some (Product f (Idx env)))
envLookupU = go id
  where
    -- The first argument turns an index into the part of the environment that
    -- is still to be searched into an index into the full environment.
    go :: (forall a. Idx env a -> Idx env' a) -> Int -> Env env f -> Maybe (Some (Product f (Idx env')))
    go !weaken !n = \case
      ETop -> Nothing
      EPush rest entry
        | n == 0 -> Just (Some (Pair entry (weaken IZ)))
        | otherwise -> go (weaken . IS) (n - 1) rest

-- | Typed De Bruijn expression. This is the final result of sharing recovery.
--
-- The @env@ parameter lists the types of the variables in scope, innermost
-- binder first.
data BExpr typ env f t where
  -- | Operation node.
  BOp :: typ t -> f (BExpr typ env f) t -> BExpr typ env f t
  -- | Lambda: the body has the argument as an additional variable in scope.
  BLam :: typ (a -> b) -> typ a -> BExpr typ (a : env) f b -> BExpr typ env f (a -> b)
  -- | Let binding: the body has the bound term as an additional variable in
  -- scope; the bound term itself does not.
  BLet :: typ a -> BExpr typ env f a -> BExpr typ (a : env) f b -> BExpr typ env f b
  -- | Variable reference through a typed index.
  BVar :: typ t -> Idx env t -> BExpr typ env f t
-- Showing requires that type representations can be shown at every index, and
-- that a layer of @f@ can be shown whenever its recursive positions can.
deriving instance (forall a. Show (typ a), forall a r. (forall b. Show (r b)) => Show (f r a))
               => Show (BExpr typ env f t)

-- | Pretty-print a closed typed De Bruijn expression. The first argument
-- renders one layer of operations; it is given a printer for subexpressions
-- (taking a precedence and a subexpression) and the precedence of the
-- context. Name generation starts at counter value 0.
prettyBExpr :: (forall m env' a. Monad m => (forall b. Int -> BExpr typ env' f b -> m ShowS)
                                         -> Int -> f (BExpr typ env' f) a -> m ShowS)
            -> BExpr typ '[] f t -> String
prettyBExpr prettyOp e = rendered ""
  where
    rendered = evalState (prettyBExpr' prettyOp ETop 0 e) 0

-- | Worker for 'prettyBExpr'. Takes the operation printer, the names of the
-- variables in scope, and the precedence of the context (lambdas and lets are
-- parenthesised when it is greater than 0). The state is the counter from
-- which fresh variable names are generated.
prettyBExpr' :: (forall m env' a. Monad m => (forall b. Int -> BExpr typ env' f b -> m ShowS)
                                          -> Int -> f (BExpr typ env' f) a -> m ShowS)
             -> Env env (Const String) -> Int -> BExpr typ env f t -> State Int ShowS
prettyBExpr' prettyOp env d = \case
  BOp _ args ->
    prettyOp (prettyBExpr' prettyOp env) d args
  BLam _ _ body -> do
    var <- genName
    shownBody <- prettyBExpr' prettyOp (EPush env (Const var)) 0 body
    pure . showParen (d > 0) $ showString ("λ" ++ var ++ ". ") . shownBody
  BLet _ rhs body -> do
    var <- genName
    -- The right-hand side is printed in the outer scope, the body in the
    -- scope extended with the new name.
    shownRhs <- prettyBExpr' prettyOp env 0 rhs
    shownBody <- prettyBExpr' prettyOp (EPush env (Const var)) 0 body
    pure . showParen (d > 0) $
      showString ("let " ++ var ++ " = ") . shownRhs . showString " in " . shownBody
  BVar _ idx ->
    pure (showString (getConst (envLookup idx env)))
  where
    -- Take the next counter value and turn it into a name.
    genName = nameOf <$> state (\i -> (i, i + 1))

    -- The first 26 names are the letters a to z; after that, x followed by
    -- the counter value.
    nameOf i
      | i < 26 = [chr (ord 'a' + i)]
      | otherwise = 'x' : show i

-- | Convert a closed untyped De Bruijn expression to a typed one, checking
-- every variable reference on the way (see 'retypeExpr'').
retypeExpr :: (Functor1 f, TestEquality typ) => UBExpr typ f t -> BExpr typ '[] f t
retypeExpr expr = retypeExpr' ETop expr

-- | Worker for 'retypeExpr'. The environment holds the type of every variable
-- in scope. Each untyped index is looked up in it, and the type found there is
-- compared with the type annotated on the variable reference.
--
-- Calls 'error' when an index is out of range or when the two types differ.
retypeExpr' :: (Functor1 f, TestEquality typ) => Env env typ -> UBExpr typ f t -> BExpr typ env f t
retypeExpr' env = \case
  UBOp ty args -> BOp ty (fmap1 (retypeExpr' env) args)
  UBLam tyf tyarg body -> BLam tyf tyarg (retypeExpr' (EPush env tyarg) body)
  UBLet ty rhs body -> BLet ty (retypeExpr' env rhs) (retypeExpr' (EPush env ty) body)
  UBVar ty idx
    | Just (Some (Pair defty tidx)) <- envLookupU idx env ->
        case testEquality ty defty of
          Just Refl -> BVar ty tidx
          Nothing -> error "Type mismatch in untyped De Bruijn expression"
    | otherwise -> error "Untyped De Bruijn index out of range"


-- | The complete pipeline: prune the PHOAS expression ('pruneExpr'), decide
-- where shared nodes get bound ('liftExpr'), introduce lets and De Bruijn
-- indices ('lowerExpr'), and finally make the indices well-typed
-- ('retypeExpr').
sharingRecovery :: (Traversable1 f, TestEquality typ) => (forall v. PHOASExpr typ v f t) -> BExpr typ '[] f t
sharingRecovery e = retypeExpr (lowerExpr (liftExpr occurrences pruned))
  where
    (occurrences, pruned) = pruneExpr e