aboutsummaryrefslogtreecommitdiff
path: root/test/Main.hs
blob: 5a4d3355a188760cf05ad8c7c1c489dbd0c1b5c1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE StandaloneDeriving #-}
module Main where

import Data.Expr.SharingRecovery
import Data.Expr.SharingRecovery.Internal

import Arith


-- TODO: test cyclic expressions


-- | Apply a binary primitive to two operands.
--
-- The operands are first packed into a pair node ('A_Pair'), and that pair
-- is then fed to 'A_Prim' together with the primitive @op@. Both nodes are
-- 'PHOASOp' nodes annotated with the type reflected by 'τ'.
a_bin :: (KnownType a, KnownType b, KnownType c)
      => PrimOp (a, b) c
      -> PHOASExpr Typ v ArithF a
      -> PHOASExpr Typ v ArithF b
      -> PHOASExpr Typ v ArithF c
a_bin op lhs rhs = PHOASOp τ (A_Prim op operands)
  where
    -- the two arguments tupled into a single pair-typed expression
    operands = PHOASOp τ (A_Pair lhs rhs)

-- | Build a PHOAS lambda from a Haskell function over expressions.
--
-- The bound variable handed out by 'PHOASLam' is wrapped in 'PHOASVar'
-- before being passed to @body@, so @body@ works on ordinary expressions.
lam :: (KnownType a, KnownType b)
    => (PHOASExpr Typ v f a -> PHOASExpr Typ v f b) -> PHOASExpr Typ v f (a -> b)
lam body = PHOASLam τ τ (body . PHOASVar τ)

-- | Integer addition on arithmetic expressions ('POAddI' via 'a_bin').
--
-- Declared @infixl 6@, matching Prelude's '+'. Without a fixity declaration
-- Haskell assumes @infixl 9@, which would make this operator bind as tightly
-- as '!!' and give surprising parses when mixed with other operators.
infixl 6 +!
(+!) :: PHOASExpr Typ v ArithF Int -> PHOASExpr Typ v ArithF Int -> PHOASExpr Typ v ArithF Int
(+!) = a_bin POAddI

-- | Integer multiplication on arithmetic expressions ('POMulI' via 'a_bin').
--
-- Declared @infixl 7@, matching Prelude's '*'. Without a fixity declaration
-- Haskell assumes @infixl 9@, which would make this operator bind as tightly
-- as '!!' and give surprising parses when mixed with other operators.
infixl 7 *!
(*!) :: PHOASExpr Typ v ArithF Int -> PHOASExpr Typ v ArithF Int -> PHOASExpr Typ v ArithF Int
(*!) = a_bin POMulI

-- | Example 1: @λx. x + x@.
--
-- The argument variable is used twice, but there is no shared compound
-- subterm.
ea_1 :: PHOASExpr Typ v ArithF (Int -> Int)
ea_1 = lam (\x -> x +! x)

-- | Example 2: @λx. let y = x + x in y * y@.
--
-- The sum is bound once with a Haskell @let@ and referenced twice, so it is
-- a shared subterm that sharing recovery is expected to find.
ea_2 :: PHOASExpr Typ v ArithF (Int -> Int)
ea_2 = lam $ \x ->
  let s = x +! x
  in s *! s

-- | Example 3: @λx. let y = x + x; z = y * x in (z + y) + (y + z)@.
--
-- Two nested shared subterms: @y@ is used inside @z@ and twice in the body,
-- and @z@ is used twice in the body.
ea_3 :: PHOASExpr Typ v ArithF (Int -> Int)
ea_3 = lam $ \n ->
  let s = n +! n
      p = s *! n
      -- alternative operand order, kept for experimentation:
      -- in (s +! p) +! (p +! s)
  in (p +! s) +! (s +! p)

-- | Run sharing recovery on every example and pretty-print each result
-- under a label.
--
-- Previously only 'ea_3' was exercised, so 'ea_1' and 'ea_2' were dead
-- test cases. Each call is spelled out separately rather than mapped over
-- a list of examples. The examples are polymorphic in the variable type
-- @v@, and 'sharingRecovery' presumably needs that polymorphic argument,
-- which a plain list element could not carry (TODO confirm against its
-- signature).
main :: IO ()
main = do
  putStrLn "-- ea_1"
  putStrLn $ prettyBExpr prettyArithF (sharingRecovery ea_1)
  putStrLn "-- ea_2"
  putStrLn $ prettyBExpr prettyArithF (sharingRecovery ea_2)
  putStrLn "-- ea_3"
  putStrLn $ prettyBExpr prettyArithF (sharingRecovery ea_3)