aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Expr/SharingRecovery.hs
blob: 118df1c431f860514933ff785dda94e481a26955 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeOperators #-}
module Data.Expr.SharingRecovery where

import Control.Applicative ((<|>))
import Control.Monad.Trans.State.Strict
import Data.Bifunctor (second)
import Data.GADT.Compare
import Data.Hashable
import Data.HashMap.Strict (HashMap)
import qualified Data.HashMap.Strict as HM
import Data.Some
import Data.Type.Equality
import GHC.StableName
import Numeric.Natural
import Unsafe.Coerce (unsafeCoerce)

import Data.StableName.Extra


-- | Higher-order analogue of 'Functor': @f@ takes a type constructor (@g@ or
-- @h@ below) as its first argument and an index @a@ as its second, and
-- 'fmap1' changes that type constructor.
class Functor1 f where
  -- | Given a transformation @g b -> h b@ that works at every index @b@,
  -- convert an @f g a@ into an @f h a@. The outer index @a@ is unchanged.
  fmap1 :: (forall b. g b -> h b) -> f g a -> f h a

-- | Higher-order analogue of 'Traversable', with 'Functor1' as superclass.
class Functor1 f => Traversable1 f where
  -- | Like 'fmap1', but the transformation runs in an 'Applicative' @m@ and
  -- the converted @f h a@ is returned inside @m@.
  traverse1 :: Applicative m => (forall b. g b -> m (h b)) -> f g a -> m (f h a)

-- | Expression in parametric higher-order abstract syntax form.
--
-- @v@ is the representation of variables, @f@ describes the operations of the
-- object language (its first argument is the type of subexpressions), and @t@
-- is the type of the expression.
data PHOASExpr v f t where
  -- | An operation from @f@ whose subexpressions are again 'PHOASExpr's.
  PHOASOp :: f (PHOASExpr v f) t -> PHOASExpr v f t
  -- | A lambda abstraction, represented by a Haskell function on expressions.
  PHOASLam :: (PHOASExpr v f a -> PHOASExpr v f b) -> PHOASExpr v f (a -> b)
  -- | A variable, carrying a value of the variable representation @v@.
  PHOASVar :: v t -> PHOASExpr v f t

-- | A numeric identifier with a phantom type index @t@ (the index does not
-- occur in the representation). In this module, 'pruneExpr'' generates 'Tag's
-- from a counter and uses them as the variable representation when it
-- instantiates a 'PHOASLam'.
newtype Tag t = Tag Natural
  deriving (Show, Eq)

-- | The identity of an expression node of type @t@: the 'StableName' of the
-- 'PHOASExpr' (with variables instantiated at 'Tag') that the node came from.
newtype NameFor f t = NameFor (StableName (PHOASExpr Tag f t))
  deriving (Eq)
  -- Derive via exactly the wrapped representation. The previous via type,
  -- @StableName (f (PHOASExpr Tag f) t)@, named a different type than the one
  -- actually stored and was only accepted because the type parameter of
  -- 'StableName' does not matter for coercions.
  deriving (Hashable) via (StableName (PHOASExpr Tag f t))

-- | Two names that wrap equal stable names are taken to have the same type
-- index. The compiler cannot see this, so the equality proof is conjured with
-- 'unsafeCoerce'.
instance GEq (NameFor f) where
  geq (NameFor lhs) (NameFor rhs) =
    if eqStableName lhs rhs
      then Just assumedRefl
      else Nothing
    where
      -- Restricted use of 'unsafeCoerce': it is only ever applied to 'Refl',
      -- i.e. it can fabricate type-equality proofs and nothing else.
      assumedRefl :: a :~: b
      assumedRefl = unsafeCoerce Refl

-- | Pruned expression: the first-order result of 'pruneExpr'. Lambdas are no
-- longer Haskell functions but carry an explicit 'Tag' for their argument,
-- and nodes carry the 'NameFor' they were identified by.
data PExpr f t where
  -- | Placeholder that carries only a name; 'pruneExpr'' emits it for a node
  -- whose name it had already recorded.
  PStub :: NameFor f t -> PExpr f t
  -- | A named operation with pruned subexpressions.
  POp :: NameFor f t -> f (PExpr f) t -> PExpr f t
  -- | A named lambda: the tag of its argument and its pruned body.
  PLam :: NameFor f (a -> b) -> Tag a -> PExpr f b -> PExpr f (a -> b)
  -- | A variable reference, identified by its tag.
  PVar :: Tag a -> PExpr f a

-- | A 'NameFor' with its type index @t@ hidden, so that names of nodes of
-- different types can be used as keys of one map (see 'OccMap').
data SomeNameFor f = forall t. SomeNameFor {-# UNPACK #-} !(NameFor f t)

-- | Wrapped names are equal exactly when the underlying stable names are.
-- 'eqStableName' is used because the two hidden type indices may differ.
instance Eq (SomeNameFor f) where
  (==) (SomeNameFor (NameFor lhs)) (SomeNameFor (NameFor rhs)) =
    lhs `eqStableName` rhs

-- | Hashing unwraps the existential and defers to the 'Hashable' instance of
-- the contained 'NameFor'.
instance Hashable (SomeNameFor f) where
  hashWithSalt s (SomeNameFor wrapped) = s `hashWithSalt` wrapped

-- | Occurrence map: for each recorded node name, how many times 'pruneExpr''
-- came across it. 'liftExpr'' reads it as the total against which it compares
-- the number of occurrences found in a subtree.
type OccMap f = HashMap (SomeNameFor f) Natural

-- | Run the pruning pass on a PHOAS term. The term must be polymorphic in
-- its variable type @v@; it is instantiated at 'Tag' here. Returns the
-- occurrence map built by 'pruneExpr'' together with the pruned term. The
-- counter that is threaded through the state next to the map is discarded.
pruneExpr :: Traversable1 f => (forall v. PHOASExpr v f t) -> (OccMap f, PExpr f t)
pruneExpr term = (occurrences, pruned)
  where
    -- Initial state: counter 0 (from which 'pruneExpr'' draws lambda tags)
    -- and an empty occurrence map.
    (pruned, (_, occurrences)) = runState (pruneExpr' term) (0, mempty)

-- | Worker for 'pruneExpr'. Converts a 'PHOASExpr' (with 'Tag's as variables)
-- into a 'PExpr', threading a state consisting of
--
-- * the next unused tag number, and
-- * the occurrence map, counting how often each named node has been reached.
--
-- Both operations and lambdas are named (via @makeStableName'@ from
-- "Data.StableName.Extra" -- presumably a pure wrapper around
-- 'makeStableName'; confirm there) and recorded in the occurrence map. The
-- first time a name is reached the node is expanded; every later time it is
-- replaced by a 'PStub' carrying the same name.
--
-- Lambdas must be recorded just like operations: 'liftExpr'' looks up the
-- name of every 'PLam' in the occurrence map and calls 'error' if it is
-- absent. Recording them also means that a lambda reached more than once is
-- stubbed instead of being re-instantiated with a fresh tag each time.
pruneExpr' :: Traversable1 f => PHOASExpr Tag f t -> State (Natural, OccMap f) (PExpr f t)
pruneExpr' = \case
  orig@(PHOASOp args) -> do
    let name = NameFor (makeStableName' orig)
    seenBefore <- recordOccurrence name
    if seenBefore
      then pure (PStub name)
      else POp name <$> traverse1 pruneExpr' args

  orig@(PHOASLam f) -> do
    let name = NameFor (makeStableName' orig)
    seenBefore <- recordOccurrence name
    if seenBefore
      then pure (PStub name)
      else do
        -- Allocate a fresh tag and instantiate the body with a variable
        -- carrying it, turning the Haskell function into a first-order body.
        tag <- state (\(i, mp) -> (Tag i, (i + 1, mp)))
        PLam name tag <$> pruneExpr' (f (PHOASVar tag))

  PHOASVar tag -> pure (PVar tag)

-- | Bump the occurrence count of a name in the state's occurrence map and
-- report whether the name was already present before this call. The tag
-- counter is left untouched.
recordOccurrence :: NameFor f t -> State (Natural, OccMap f) Bool
recordOccurrence name = state $ \(next, occmap) ->
  let (seenBefore, occmap') =
        HM.alterF (\case Nothing -> (False, Just 1)
                         Just n -> (True, Just (n + 1)))
                  (SomeNameFor name)
                  occmap
  in (seenBefore, (next, occmap'))


-- | Lifted expression: a list of to-be-let-bound expressions (of arbitrary
-- types, hence 'Some') on top of an 'LExpr''.
data LExpr f t = LExpr [Some (LExpr f)] (LExpr' f t)

-- | The node underneath the bindings of an 'LExpr'. The constructors mirror
-- those of 'PExpr', except that subexpressions are 'LExpr's.
data LExpr' f t where  -- TODO: this could be an instantiation of (a generalisation of) PExpr
  -- | Reference, by name, to an expression that is bound elsewhere.
  LStub :: NameFor f t -> LExpr' f t
  -- | A named operation with lifted subexpressions.
  LOp :: NameFor f t -> f (LExpr f) t -> LExpr' f t
  -- | A named lambda: the tag of its argument and its lifted body.
  LLam :: NameFor f (a -> b) -> Tag a -> LExpr f b -> LExpr' f (a -> b)
  -- | A variable reference, identified by its tag.
  LVar :: Tag a -> LExpr' f a

-- | Lift a pruned expression, given the occurrence totals computed by
-- 'pruneExpr'. This is 'liftExpr'' with the bookkeeping map thrown away: only
-- the lifted expression is returned.
liftExpr :: Traversable1 f => OccMap f -> PExpr f t -> LExpr f t
liftExpr totals term = snd (liftExpr' totals term)

-- | Bookkeeping produced by 'liftExpr'' for a subtree: for every name that
-- occurs in the subtree and has not been bound yet, the number of occurrences
-- found so far and, if available, the lifted term that belongs to the name.
newtype FoundMap f = FoundMap
  (HashMap (SomeNameFor f) (Natural  -- how many times seen
                           ,Maybe (Some (LExpr f))))  -- the lifted subterm (once seen)

-- | Combining two found-maps takes the union of their keys. For a name that
-- is present in both, the occurrence counts are added and the first available
-- lifted term (left operand first) is kept.
instance Semigroup (FoundMap f) where
  FoundMap lhs <> FoundMap rhs = FoundMap (HM.unionWith merge lhs rhs)
    where
      merge (countL, termL) (countR, termR) = (countL + countR, termL <|> termR)

-- | The neutral element is the found-map without any entries.
instance Monoid (FoundMap f) where
  mempty = FoundMap mempty

-- | Worker for 'liftExpr'. Given the occurrence totals, converts a pruned
-- expression into a lifted one and additionally returns a 'FoundMap'
-- describing the names that occur in the subtree but are not bound in it yet.
--
-- Precondition: the name of every 'POp' and 'PLam' in the term must be present
-- in @totals@ with a count of at least 1; otherwise 'error' is called.
liftExpr' :: Traversable1 f => OccMap f -> PExpr f t -> (FoundMap f, LExpr f t)
-- A stub is one occurrence of its name, but it is /not/ the term that belongs
-- to that name, so it contributes 'Nothing'. (Storing the stub itself here
-- would make the result depend on the stub never being merged to the left of
-- the real term, and would make the "no term found" check in 'bindSaturated'
-- unreachable.)
liftExpr' _totals (PStub name) =
  (FoundMap $ HM.singleton (SomeNameFor name) (1, Nothing)
  ,LExpr [] (LStub name))

liftExpr' _totals (PVar tag) = (mempty, LExpr [] (LVar tag))

-- The found-maps of the arguments are combined by the 'Monoid' instance of
-- 'FoundMap', through the 'Applicative' instance of pairs.
liftExpr' totals (POp name args) =
  let (found, args') = traverse1 (liftExpr' totals) args
  in bindSaturated totals name found (LOp name args')

liftExpr' totals (PLam name tag body) =
  let (found, body') = liftExpr' totals body
  in bindSaturated totals name found (LLam name tag body')

-- | Finish lifting a named node whose children have already been lifted.
--
-- Every name in the children's found-map whose count equals its total is
-- "saturated": all of its occurrences lie below this node, so its term is
-- attached to this node as a to-be-let-bound expression and the name is
-- removed from the map.
--
-- The node itself is then returned as is if its own total is 1. If its total
-- is larger, the node is recorded in the map under its own name (as one
-- occurrence, together with its lifted term) and an 'LStub' is returned in
-- its place.
bindSaturated :: OccMap f -> NameFor f t -> FoundMap f -> LExpr' f t -> (FoundMap f, LExpr f t)
bindSaturated totals name (FoundMap foundmap) term' =
  let saturated = [case mterm of
                     Just t -> (nm, t)
                     Nothing -> error "Name saturated but no term found"
                  | (nm, (count, mterm)) <- HM.toList foundmap
                  , count == HM.findWithDefault 0 nm totals]

      foundmap' = foldr HM.delete foundmap (map fst saturated)

      lterm = LExpr (map snd saturated) term'

  in case HM.findWithDefault 0 (SomeNameFor name) totals of
       1 -> (FoundMap foundmap', lterm)
       tot | tot > 1 -> (FoundMap (HM.insert (SomeNameFor name) (1, Just (Some lterm)) foundmap')
                        ,LExpr [] (LStub name))
           | otherwise -> error "Term does not exist, yet we have it in hand"


-- TODO: lower LExpr into a normal expression with let bindings. Every LStub
-- should correspond to some let-bound expression higher up in the tree (if it
-- does not, that's a bug), and should become a De Bruijn variable reference to
-- said let-bound expression. Lambdas should also get proper De Bruijn indices
-- instead of tags, and LVar is also a normal variable (referring to a
-- lambda-abstracted argument).