aboutsummaryrefslogtreecommitdiff
path: root/Simplify.hs
blob: e95043da87dab038dcb5aa42ad3816c0eeb8b51c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Simplify (
    simplify,
    simplifyFix,
    simbeta, simpair, simindex, simifold1,
    simfix, SimList(..),
) where

import Data.Bifunctor
import Data.GADT.Compare
import qualified Data.Kind as Kind
import Data.List (find)
import Data.Type.Equality

import Debug.Trace

import AST
import Count
import Sink


-- | A bounding value that records whether the bound itself is included
-- ('Inclusive') or excluded ('Exclusive'). In this module it is used for the
-- upper bound stored in 'InfoInt'.
data Bound a = Inclusive a | Exclusive a
  deriving (Show)

-- | Transform the bounding value while keeping the inclusive/exclusive tag.
instance Functor Bound where
    fmap f = \case
        Inclusive x -> Inclusive (f x)
        Exclusive x -> Exclusive (f x)

-- | Statically known facts about a value of type @a@, phrased in terms of
-- expressions that live in environment @env@. Only the types listed below
-- have an instance.
data family Info (env :: [Kind.Type]) a
-- | Lower bound (inclusive), upper bound (inclusive/exclusive)
data instance Info env Int = InfoInt (Maybe (Exp env Int)) (Maybe (Bound (Exp env Int)))
-- | An expression for the shape of the array.
data instance Info env (Array sh t) = InfoArray (Exp env sh)
-- | The unit type carries no information.
data instance Info env () = InfoNil
-- | Optional information about each component of a pair.
data instance Info env (a, b) = InfoPair (Maybe (Info env a)) (Maybe (Info env b))

-- | Environment of optional 'Info' for the variables in scope. 'ICons' cells
-- mirror the type-level environment @env@ one variable at a time.
data IEnv env where
    -- | Nothing is recorded: 'iprj' on 'ITop' returns 'Nothing' for every index.
    ITop :: IEnv env
    -- | The type of the innermost variable, optional information about it
    -- (already expressed in the extended environment), and the environment of
    -- the enclosing scope.
    ICons :: Type a -> Maybe (Info (a ': env) a) -> IEnv env -> IEnv (a ': env)

-- | 'showsPrec'-style pretty printer for 'Info'. The 'Type' witness selects
-- which data-family instance is being printed; the first argument is the
-- surrounding precedence.
showsInfo :: Int -> Type a -> Info env a -> ShowS
showsInfo prec TInt (InfoInt lo hi) =
    showParen (prec > 10)
        (showString "InfoInt " . showsPrec 11 lo . showChar ' ' . showsPrec 11 hi)
showsInfo prec TArray{} (InfoArray sh) =
    showParen (prec > 10) (showString "InfoArray " . showsPrec 11 sh)
showsInfo _ TNil InfoNil = showString "InfoNil"
showsInfo prec (TPair ta tb) (InfoPair ia ib) =
    showParen (prec > 10)
        (showString "InfoPair " . showsInfo' 11 ta ia . showChar ' ' . showsInfo' 11 tb ib)
-- Types without an 'Info' instance end up here.
showsInfo _ _ _ = error "showsInfo: No definition"

-- | Like 'showsInfo', but for an optional 'Info'; renders the way a derived
-- 'Show' instance for 'Maybe' would.
showsInfo' :: Int -> Type a -> Maybe (Info env a) -> ShowS
showsInfo' prec ty =
    maybe (showString "Nothing")
          (\info -> showParen (prec > 10) (showString "Just " . showsInfo 11 ty info))

-- | Weaken an 'Info' so that it is valid under one additional binding, by
-- applying 'sinkExp1' to every expression it contains.
sinkInfo1 :: Type a -> Info env a -> Info (t ': env) a
sinkInfo1 TInt (InfoInt lo hi) = InfoInt (fmap sinkExp1 lo) (fmap (fmap sinkExp1) hi)
sinkInfo1 TArray{} (InfoArray sh) = InfoArray (sinkExp1 sh)
sinkInfo1 TNil InfoNil = InfoNil
sinkInfo1 (TPair ta tb) (InfoPair ia ib) =
    InfoPair (fmap (sinkInfo1 ta) ia) (fmap (sinkInfo1 tb) ib)
-- Types without an 'Info' instance end up here.
sinkInfo1 _ _ = error "Unknown info in sinkInfo1"

-- | Look up the type and recorded information for a variable. The result is
-- expressed in the full environment: information found deeper in the
-- environment is weakened once for every binding that is skipped.
iprj :: IEnv env -> Idx env a -> Maybe (Type a, Info env a)
iprj ITop _ = Nothing
iprj (ICons ty minfo _) Zero = fmap ((,) ty) minfo
iprj (ICons _ _ rest) (Succ ix) = do
    (ty, info) <- iprj rest ix
    Just (ty, sinkInfo1 ty info)

-- | Apply 'simplify' repeatedly until the expression no longer changes.
--
-- At most 'maxTimes' simplification passes are compared against their input;
-- if none of them is a fixpoint, this calls 'error'.
simplifyFix :: Exp env a -> Exp env a
simplifyFix = go maxTimes
  where
    maxTimes :: Int
    maxTimes = 4

    -- The counter is the number of simplification passes still allowed.
    go :: Int -> Exp env' a' -> Exp env' a'
    go 0 _ = error "Simplification doesn't converge!"
    go n cur =
        let next = simplify cur
        in case geq cur next of
             Just Refl -> cur
             _ -> go (n - 1) next

-- | Run a single simplification pass, starting without any information about
-- the variables in scope and discarding the information derived for the result.
simplify :: Exp env a -> Exp env a
simplify e = fst (simplify' ITop e)

-- | Simplify an expression bottom-up, given what is known about the variables
-- in scope.
--
-- Returns the simplified expression together with whatever 'Info' could be
-- derived about its value (integer bounds, array shape, pair components).
simplify' :: IEnv env -> Exp env a -> (Exp env a, Maybe (Info env a))
simplify' env expr = case expr of
    App fun arg ->
        let fun' = fst (simplify' env fun)
            arg' = fst (simplify' env arg)
        in (simplifyApp fun' arg', Nothing)
    Lam ty body ->
        -- Nothing is known about a lambda-bound variable.
        let body' = fst (simplify' (ICons ty Nothing env) body)
        in (Lam ty body', Nothing)
    Var ty ix -> (Var ty ix, fmap snd (iprj env ix))
    Let rhs body ->
        let (rhs', rhsInfo) = simplify' env rhs
            rhsTy = typeof rhs
            -- The info describes the newly bound variable, so it has to be
            -- valid inside the extended environment.
            bodyEnv = ICons rhsTy (fmap (sinkInfo1 rhsTy) rhsInfo) env
        in (simplifyLet rhs' (fst (simplify' bodyEnv body)), Nothing)
    Lit (LInt n) ->
        -- An integer literal is its own (inclusive) lower and upper bound.
        let lit = Lit (LInt n)
        in (lit, Just (InfoInt (Just lit) (Just (Inclusive lit))))
    Lit l -> (Lit l, Nothing)
    Cond c t e ->
        let c' = fst (simplify' env c)
            t' = fst (simplify' env t)
            e' = fst (simplify' env e)
        in (Cond c' t' e', Nothing)
    Const c -> (Const c, Nothing)
    Pair l r ->
        let (l', infoL) = simplify' env l
            (r', infoR) = simplify' env r
        in (simplifyPair l' r', Just (InfoPair infoL infoR))
    Fst pr ->
        let (pr', info) = simplify' env pr
        in (simplifyFst pr', info >>= \(InfoPair i _) -> i)
    Snd pr ->
        let (pr', info) = simplify' env pr
        in (simplifySnd pr', info >>= \(InfoPair _ i) -> i)
    Build sht she (Lam shty body) ->
        -- Inside the element function the index is known to lie within the
        -- shape of the array being built.
        let she' = fst (simplify' env she)
            bodyEnv = ICons shty (Just (shapeBoundInfo sht (sinkExp1 she'))) env
            body' = fst (simplify' bodyEnv body)
        in (Build sht she' (Lam shty body'), Just (InfoArray she'))
    Build sht she fun ->
        let she' = fst (simplify' env she)
            fun' = fst (simplify' env fun)
        in (Build sht she' fun', Just (InfoArray she'))
    Ifold sht fun e0 she ->
        let fun' = fst (simplify' env fun)
            e0' = fst (simplify' env e0)
            she' = fst (simplify' env she)
        in (simplifyIfold env sht fun' e0' she', Nothing)
    Index arr ix ->
        let arr' = fst (simplify' env arr)
            ix' = fst (simplify' env ix)
        in (simplifyIndex arr' ix', Nothing)
    Shape arr ->
        case simplify' env arr of
          -- The shape is statically known: use it instead of the array.
          (_, Just (InfoArray she)) -> (she, Nothing)
          (arr', _) -> (Shape arr', Nothing)
    Undef t -> (Undef t, Nothing)

-- | Information about an index known to lie within the given shape: every
-- component is at least 0 and strictly below the corresponding component of
-- the shape expression.
shapeBoundInfo :: ShapeType sh -> Exp env sh -> Info env sh
shapeBoundInfo STZ _ = InfoNil
shapeBoundInfo (STC rest) she = InfoPair (Just outer) (Just inner)
  where
    outer = shapeBoundInfo rest (Fst she)
    inner = InfoInt (Just (Lit (LInt 0))) (Just (Exclusive (Snd she)))

-- | Simplify the application of an (already simplified) function to an
-- (already simplified) argument.
--
-- In order of precedence this performs:
--
-- * constant folding and unit/zero laws for the built-in constants;
-- * beta reduction of a literal lambda, by substitution when the argument is
--   duplicable or used at most once, and via 'simplifyLet' otherwise;
-- * pushing the application into the branches of a conditional argument.
--
-- If nothing applies the application is rebuilt unchanged.
simplifyApp :: Exp env (a -> b) -> Exp env a -> Exp env b
simplifyApp (Const CAddI) (Pair (Lit (LInt a)) (Lit (LInt b))) = Lit (LInt (a + b))
simplifyApp (Const CAddI) (Pair a (Lit (LInt 0))) = a
simplifyApp (Const CAddI) (Pair (Lit (LInt 0)) a) = a
-- simplifyApp (Const CAddI) (Pair a b) | Just Refl <- geq a b =
--     simplifyApp (Const CMulI) (Pair (Lit (LInt 2)) a)
simplifyApp (Const CSubI) (Pair (Lit (LInt a)) (Lit (LInt b))) = Lit (LInt (a - b))
simplifyApp (Const CMulI) (Pair (Lit (LInt a)) (Lit (LInt b))) = Lit (LInt (a * b))
simplifyApp (Const CMulI) (Pair a (Lit (LInt 1))) = a
simplifyApp (Const CMulI) (Pair (Lit (LInt 1)) a) = a
simplifyApp (Const CMulI) (Pair _ (Lit (LInt 0))) = Lit (LInt 0)
simplifyApp (Const CMulI) (Pair (Lit (LInt 0)) _) = Lit (LInt 0)
-- Only fold integer division when the divisor is non-zero: `div` by zero
-- throws, which would crash the simplifier as soon as the folded literal is
-- inspected. With a zero divisor the application is left as it is (it falls
-- through to the final catch-all clause).
simplifyApp (Const CDivI) (Pair (Lit (LInt a)) (Lit (LInt b)))
  | b /= 0
  = Lit (LInt (a `div` b))
simplifyApp (Const CAddF) (Pair (Lit (LDouble a)) (Lit (LDouble b))) = Lit (LDouble (a + b))
simplifyApp (Const CAddF) (Pair a (Lit (LDouble 0))) = a
simplifyApp (Const CAddF) (Pair (Lit (LDouble 0)) a) = a
-- simplifyApp (Const CAddF) (Pair a b) | Just Refl <- geq a b =
--     simplifyApp (Const CMulF) (Pair (Lit (LDouble 2)) a)
simplifyApp (Const CSubF) (Pair (Lit (LDouble a)) (Lit (LDouble b))) = Lit (LDouble (a - b))
simplifyApp (Const CMulF) (Pair (Lit (LDouble a)) (Lit (LDouble b))) = Lit (LDouble (a * b))
simplifyApp (Const CMulF) (Pair a (Lit (LDouble 1))) = a
simplifyApp (Const CMulF) (Pair (Lit (LDouble 1)) a) = a
simplifyApp (Const CMulF) (Pair _ (Lit (LDouble 0))) = Lit (LDouble 0)
simplifyApp (Const CMulF) (Pair (Lit (LDouble 0)) _) = Lit (LDouble 0)
simplifyApp (Const CDivF) (Pair (Lit (LDouble a)) (Lit (LDouble b))) = Lit (LDouble (a / b))
simplifyApp (Const CLog) (Lit (LDouble a)) = Lit (LDouble (log a))
simplifyApp (Const CExp) (Lit (LDouble a)) = Lit (LDouble (exp a))
simplifyApp (Const CtoF) (Lit (LInt a)) = Lit (LDouble (fromIntegral a))
simplifyApp (Const CRound) (Lit (LDouble a)) = Lit (LInt (round a))
simplifyApp (Const CLtI) (Pair (Lit (LInt a)) (Lit (LInt b))) = Lit (LBool (a < b))
simplifyApp (Const CLtF) (Pair (Lit (LDouble a)) (Lit (LDouble b))) = Lit (LBool (a < b))
-- Syntactically identical operands compare equal.
simplifyApp (Const (CEq _)) (Pair a b)
  | Just Refl <- geq a b
  = Lit (LBool True)
simplifyApp (Const CAnd) (Pair (Lit (LBool a)) (Lit (LBool b))) = Lit (LBool (a && b))
simplifyApp (Const COr) (Pair (Lit (LBool a)) (Lit (LBool b))) = Lit (LBool (a || b))
simplifyApp (Const CNot) (Lit (LBool a)) = Lit (LBool (not a))

-- Beta reduction: substitute directly when that cannot duplicate the
-- argument's work, otherwise bind the argument with a let.
simplifyApp (Lam _ e) arg
  | isDuplicable arg || usesOf Zero e <= 1
  = simplify (subst arg e)
simplifyApp (Lam _ e) arg = simplifyLet arg e

-- Push the application into both branches of a conditional argument.
simplifyApp f (Cond c a b) = simplify $ Cond c (App f a) (App f b)

simplifyApp a b = App a b

-- | Simplify the binding of an (already simplified) right-hand side in a body.
--
-- * A right-hand side that is duplicable, or used at most once, is inlined.
-- * A pair is split into two nested bindings, one per component, and the body
--   is rewritten to rebuild the pair from the two new variables.
-- * A conditional right-hand side is hoisted out, duplicating the body into
--   both branches.
simplifyLet :: Exp env a -> Exp (a ': env) b -> Exp env b
simplifyLet rhs body
  | isDuplicable rhs || usesOf Zero body <= 1
  = simplify (subst rhs body)
simplifyLet (Pair l r) body =
    let -- The old variable 0 becomes a pair of the two new variables; every
        -- other variable moves out by one extra binding.
        body' = subst' (\ty ix -> case ix of
                                    Zero -> Pair (Var (typeof l) (Succ Zero))
                                                 (Var (typeof r) Zero)
                                    Succ ix' -> Var ty (Succ (Succ ix')))
                       body
    in simplifyLet l (simplifyLet (sinkExp1 r) body')
-- simplifyLet (Cond c a b) e
--   | isDuplicable a && isDuplicable b
--   = simplifyLet c $
--       (subst' (\t -> \case Zero -> Cond (Var TBool Zero) (sinkExp1 a) (sinkExp1 b)
--                            Succ i -> Var t (Succ i))
--               e)
simplifyLet (Cond c t e) body = simplify (Cond c (Let t body) (Let e body))
simplifyLet rhs body = Let rhs body

-- | Build a pair, hoisting a conditional out of either component (the first
-- component takes precedence) and re-simplifying the result.
simplifyPair :: Exp env a -> Exp env b -> Exp env (a, b)
simplifyPair l r = case (l, r) of
    (Cond c t e, _) -> simplify (Cond c (Pair t r) (Pair e r))
    (_, Cond c t e) -> simplify (Cond c (Pair l t) (Pair l e))
    _ -> Pair l r

-- | Take the first component of an (already simplified) pair expression:
-- project a literal pair directly, push the projection under a let, and
-- distribute it over the branches of a conditional.
simplifyFst :: Exp env (a, b) -> Exp env a
simplifyFst = \case
    Pair l _ -> l
    Let rhs body -> simplifyLet rhs (simplifyFst body)
    Cond c t e -> simplify (Cond c (Fst t) (Fst e))
    other -> Fst other

-- | Take the second component of an (already simplified) pair expression:
-- project a literal pair directly, push the projection under a let, and
-- distribute it over the branches of a conditional.
simplifySnd :: Exp env (a, b) -> Exp env b
simplifySnd = \case
    Pair _ r -> r
    Let rhs body -> simplifyLet rhs (simplifySnd body)
    Cond c t e -> simplify (Cond c (Snd t) (Snd e))
    other -> Snd other

-- | Simplify an index fold whose function, initial value and shape have
-- already been simplified.
--
-- First tries to split a fold over a pair into two independent folds
-- ('splitIfold'), re-simplifying the result in the given environment. Then
-- tries to eliminate a fold that only overwrites its accumulator at one
-- provably in-bounds index. Otherwise the fold is rebuilt unchanged.
simplifyIfold :: IEnv env -> ShapeType sh -> Exp env ((a, sh) -> a) -> Exp env a -> Exp env sh -> Exp env a
simplifyIfold env sht fe e0 she
  | Just res <- splitIfold sht fe e0 she
  = fst (simplify' env res)
-- Given the following:
--   ifold (\(a,i) -> if i == cmpref then val else a) e0 she
-- and given that we can prove that 0 <= cmpref < she and that 'cmpref' does
-- not refer to the lambda's argument, the whole fold can be replaced with
-- 'val', with the argument of the lambda bound to the pair (e0, cmpref).
--
-- The 'trace' calls below are debugging output; the guards are evaluated in
-- order, so they report how far the proof got.
-- NOTE(review): the trace labelled "si: cmpref = " prints 'val', not
-- 'cmpref' -- confirm which one was intended.
simplifyIfold env sht (Lam argty (Cond (App (Const (CEq _)) (Pair (Snd (Var _ Zero)) cmpref)) val (Fst (Var _ Zero)))) e0 she
  | let env' = ICons argty (Just (InfoPair Nothing (Just (shapeBoundInfo sht (sinkExp1 she))))) env
  , trace ("si: trying") True
  , proveShapeBound env' CLeI sht (zeroShapeExp sht) cmpref
  , trace ("si: prf1 = True") True
  , proveShapeBound env' CLtI sht cmpref (sinkExp1 she)
  , trace ("si: prf2 = True") True
  , trace ("si: cmpref = " ++ show val) True
  , usesOf Zero cmpref == 0
    -- Since 'cmpref' does not use variable 0, the substituted expression is
    -- never inspected; the substitution only moves 'cmpref' out of the
    -- lambda's scope.
  = simplifyLet (Pair e0 (subst (error "usesOf == 0 was wrong") cmpref)) val
simplifyIfold _ sht fe e0 she = Ifold sht fe e0 she

-- | The index expression of the given shape type in which every component is
-- the literal 0.
zeroShapeExp :: ShapeType sh -> Exp env sh
zeroShapeExp = \case
    STZ -> Lit LNil
    STC rest -> Pair (zeroShapeExp rest) (Lit (LInt 0))

-- | Try to prove that @e1 `cmpop` e2@ holds for every component of two
-- shape-typed expressions, using the bounds recorded in the environment.
--
-- A result of 'True' means the relation was proven; 'False' only means that
-- no proof was found. Only 'CLeI' (using an inclusive upper bound of @e1@)
-- and 'CLtI' (using an exclusive upper bound of @e1@) are handled; any other
-- combination emits debugging output via 'trace' and yields 'False'.
proveShapeBound :: IEnv env -> Constant ((Int, Int) -> Bool) -> ShapeType sh -> Exp env sh -> Exp env sh -> Bool
proveShapeBound _ _ STZ _ _ = True
proveShapeBound env cmpop (STC sht) e1 e2 =
    -- Simplify the innermost components only to obtain their bound info.
    let (_, info1) = simplify' env (Snd e1)
        (_, info2) = simplify' env (Snd e2)
        inclLo2 = case info2 of
                    Just (InfoInt (Just lo2) _) -> lo2
                    _ -> Snd e2  -- this is also an inclusive lower bound, after all
        -- The remaining (outer) components are checked recursively.
        restresult = proveShapeBound env cmpop sht (Fst e1) (Fst e2)
    in restresult && case (cmpop, info1) of
         -- e1 <= hi1 and hi1 <= lo2 <= e2 together give e1 <= e2.
         (CLeI, Just (InfoInt _ (Just (Inclusive hi1)))) ->
             proveLe hi1 inclLo2
         -- e1 < hi1 and hi1 <= lo2 <= e2 together give e1 < e2.
         (CLtI, Just (InfoInt _ (Just (Exclusive hi1)))) ->
             proveLe hi1 inclLo2
         _ -> trace ("proveShapeBound: " ++ show cmpop ++ " (" ++ show e1 ++ ") (" ++ show e2 ++ ")") $
              trace ("  e1 = " ++ show e1) $
              trace ("  e2 = " ++ show e2) $
              trace ("  info1 = " ++ showsInfo' 0 TInt info1 "") $
              trace ("  info2 = " ++ showsInfo' 0 TInt info2 "") $
              trace ("  inclLo2 = " ++ show inclLo2) $
              False

-- | Debugging wrapper around 'proveLe'': returns the same verdict, and
-- reports both operands and the verdict via 'trace'.
proveLe :: Exp env Int -> Exp env Int -> Bool
proveLe lhs rhs = trace message verdict
  where
    verdict = proveLe' lhs rhs
    message = "proveLe: '" ++ show lhs ++ "'  <=  '" ++ show rhs ++ "'  ->  " ++ show verdict

-- | Try to prove @e1 <= e2@. This succeeds when the two expressions are
-- syntactically equal, or when both are integer literals in the right order.
-- 'False' means "not proven", not "disproven".
proveLe' :: Exp env Int -> Exp env Int -> Bool
proveLe' lhs rhs
  | Just Refl <- geq lhs rhs = True
  | Lit (LInt a) <- lhs, Lit (LInt b) <- rhs = a <= b
  | otherwise = False

-- | Index into an (already simplified) array expression. Indexing a 'Build'
-- applies its element function to the index; anything else stays an 'Index'.
simplifyIndex :: Exp env (Array sh a) -> Exp env sh -> Exp env a
simplifyIndex arr ix = case arr of
    Build _ _ fun -> simplifyApp fun ix
    _ -> Index arr ix

-- | Whether an expression may be copied to several use sites. This holds for
-- variables, constants and the literals listed below, and for lambdas, lets,
-- pairs and projections built only from duplicable parts. Every other
-- expression form is considered not duplicable.
isDuplicable :: Exp env a -> Bool
isDuplicable = \case
    Lam _ body -> isDuplicable body
    Var _ _ -> True
    Let rhs body -> isDuplicable rhs && isDuplicable body
    Lit (LInt _) -> True
    Lit (LBool _) -> True
    Lit (LDouble _) -> True
    Lit (LShape _) -> True
    Lit LNil -> True
    Lit (LPair l r) -> isDuplicable (Lit l) && isDuplicable (Lit r)
    Const _ -> True
    Pair l r -> isDuplicable l && isDuplicable r
    Fst pr -> isDuplicable pr
    Snd pr -> isDuplicable pr
    _ -> False

-- | Substitute an expression for the innermost variable of another
-- expression; every other variable moves in by one binding.
subst :: Exp env t -> Exp (t ': env) a -> Exp env a
subst arg =
    subst' (\ty ix -> case ix of
                        Zero -> arg
                        Succ ix' -> Var ty ix')

-- | Simultaneous substitution: replace every variable of an expression by the
-- expression the given function assigns to it, which may live in a different
-- environment.
subst' :: (forall t. Type t -> Idx env t -> Exp env' t) -> Exp env a -> Exp env' a
subst' f = \case
    App fun arg -> App (subst' f fun) (subst' f arg)
    Lam ty body -> Lam ty (subst' (underBinder f) body)
    Var ty ix -> f ty ix
    Let rhs body -> Let (subst' f rhs) (subst' (underBinder f) body)
    Lit l -> Lit l
    Cond c t e -> Cond (subst' f c) (subst' f t) (subst' f e)
    Const c -> Const c
    Pair l r -> Pair (subst' f l) (subst' f r)
    Fst pr -> Fst (subst' f pr)
    Snd pr -> Snd (subst' f pr)
    Build sht she fun -> Build sht (subst' f she) (subst' f fun)
    Ifold sht fun e0 she -> Ifold sht (subst' f fun) (subst' f e0) (subst' f she)
    Index arr ix -> Index (subst' f arr) (subst' f ix)
    Shape arr -> Shape (subst' f arr)
    Undef ty -> Undef ty
  where
    -- Extend a substitution below one binder: the newly bound variable maps
    -- to itself, every other replacement is weakened past the binder.
    underBinder :: (forall t. Type t -> Idx e t -> Exp e' t)
                -> Type u -> Idx (b ': e) u -> Exp (b ': e') u
    underBinder g ty = \case
        Zero -> Var ty Zero
        Succ ix -> sinkExp1 (g ty ix)

-- | Split a fold whose accumulator is a pair into two independent folds, one
-- per component, when neither component's update reads the other component.
--
-- This applies only when the fold function is a literal lambda whose body is
-- a literal 'Pair'. The result binds the initial value and the shape with
-- 'Let's and pairs up the two folds; 'Nothing' is returned when the pattern
-- or the independence conditions do not hold.
splitIfold :: ShapeType sh -> Exp env ((s, sh) -> s) -> Exp env s -> Exp env sh -> Maybe (Exp env s)
splitIfold sht (Lam (TPair (TPair t1 t2) tidx) (Pair e1 e2)) e0 she
  | let uses1 = usesOf' PathStart Zero e1
        uses2 = usesOf' PathStart Zero e2
    -- e1 must not use the second accumulator component, and e2 must not use
    -- the first one.
  , lycontract (lysnd (lyfst uses1)) == 0
  , lycontract (lyfst (lyfst uses2)) == 0
    -- Substitute the argument in e1 and e2 to refer to just the used
    -- components of their argument. To do this we reconstruct the original,
    -- partially unused argument by putting an 'Undef' in the unused spot.
  , let e1' = subst' (\t -> \case Zero ->
                                    Pair (Pair (Fst (Var (TPair t1 tidx) Zero))
                                               (Undef t2))
                                         (Snd (Var (TPair t1 tidx) Zero))
                                  Succ i -> Var t (Succ i))
                     e1
        e2' = subst' (\t -> \case Zero ->
                                    Pair (Pair (Undef t1)
                                               (Fst (Var (TPair t2 tidx) Zero)))
                                         (Snd (Var (TPair t2 tidx) Zero))
                                  Succ i -> Var t (Succ i))
                     e2
  = Just $
      -- Variable 1 is the initial accumulator pair, variable 0 is the shape;
      -- the new fold functions are sunk past these two bindings.
      Let e0 $ Let (sinkExp1 she) $
        Pair (Ifold sht (sinkExp2 (Lam (TPair t1 tidx) e1'))
                        (Fst (Var (TPair t1 t2) (Succ Zero)))
                        (Var tidx Zero))
             (Ifold sht (sinkExp2 (Lam (TPair t2 tidx) e2'))
                        (Snd (Var (TPair t1 t2) (Succ Zero)))
                        (Var tidx Zero))
splitIfold _ _ _ _ = Nothing

-- | Beta-reduction pass: inline applied lambdas and let-bindings whose
-- argument is duplicable or used at most once. An applied lambda that cannot
-- be inlined becomes a 'Let'; everything else is traversed with 'simrecurse'.
simbeta :: Exp env a -> Exp env a
simbeta expr = case expr of
    App (Lam _ body) arg
      | canInline arg body -> simbeta (subst arg body)
      | otherwise -> Let (simbeta arg) (simbeta body)
    Let rhs body
      | canInline rhs body -> simbeta (subst rhs body)
    _ -> simrecurse simbeta expr
  where
    -- Substituting is safe when it cannot duplicate non-trivial work.
    canInline rhs body = isDuplicable rhs || usesOf Zero body <= 1

-- | Pair-projection pass: reduce projections of literal pairs and push
-- projections under let-bindings; everything else is traversed with
-- 'simrecurse'.
simpair :: Exp env a -> Exp env a
simpair expr = case expr of
    Fst (Pair l _) -> simpair l
    Snd (Pair _ r) -> simpair r
    Fst (Let rhs body) -> Let (simpair rhs) (simpair (Fst body))
    Snd (Let rhs body) -> Let (simpair rhs) (simpair (Snd body))
    _ -> simrecurse simpair expr

-- | Index pass: indexing directly into a 'Build' becomes an application of
-- the build's element function to the index; everything else is traversed
-- with 'simrecurse'.
simindex :: Exp env a -> Exp env a
simindex expr = case expr of
    Index (Build _ _ fun) ix -> App (simindex fun) (simindex ix)
    _ -> simrecurse simindex expr

-- | Fold-splitting pass: rewrite the sub-expressions of every 'Ifold' and
-- then try to split it with 'splitIfold'; everything else is traversed with
-- 'simrecurse'.
simifold1 :: Exp env a -> Exp env a
simifold1 = \case
    Ifold sht fe e0 she ->
        -- Rewrite the children once and share the results between the split
        -- attempt and the fallback, so a failed split does not traverse the
        -- children a second time.
        let fe' = simifold1 fe
            e0' = simifold1 e0
            she' = simifold1 she
        in case splitIfold sht fe' e0' she' of
             Just res -> res
             Nothing -> Ifold sht fe' e0' she'
    e -> simrecurse simifold1 e

-- | Apply a transformation to every direct sub-expression of an expression,
-- leaving the node itself as it is. Leaves (variables, literals, constants
-- and 'Undef') are returned unchanged.
simrecurse :: (forall env' a'. Exp env' a' -> Exp env' a') -> Exp env a -> Exp env a
simrecurse f expr = case expr of
    App fun arg -> App (f fun) (f arg)
    Lam ty body -> Lam ty (f body)
    Var ty ix -> Var ty ix
    Let rhs body -> Let (f rhs) (f body)
    Lit l -> Lit l
    Cond c t e -> Cond (f c) (f t) (f e)
    Const c -> Const c
    Pair l r -> Pair (f l) (f r)
    Fst pr -> Fst (f pr)
    Snd pr -> Snd (f pr)
    Build sht she fun -> Build sht (f she) (f fun)
    Ifold sht fun e0 she -> Ifold sht (f fun) (f e0) (f she)
    Index arr ix -> Index (f arr) (f ix)
    Shape arr -> Shape (f arr)
    Undef ty -> Undef ty

-- No precedence is given, so ':|' gets the default precedence (9).
infixr :|
-- | A list of simplification passes, each polymorphic in the environment and
-- the type of the expression it transforms. 'simfix' applies the passes from
-- left to right.
data SimList = (forall env' a'. Exp env' a' -> Exp env' a') :| SimList
             | SimEnd

-- | Apply all passes of the list, in order, over and over until a full round
-- leaves the expression unchanged. There is no iteration limit.
simfix :: SimList -> Exp env a -> Exp env a
simfix list e
  | Just Refl <- geq e next = next
  | otherwise = simfix list next
  where
    -- The result of one full round of passes.
    next = applyAll list e

    applyAll :: SimList -> Exp env' b -> Exp env' b
    applyAll SimEnd x = x
    applyAll (f :| rest) x = applyAll rest (f x)