aboutsummaryrefslogtreecommitdiff
path: root/test/Main.hs
blob: e7b303b434fb83ecdb26cd7b0b17de45750bffef (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE StandaloneDeriving #-}
module Main where

import Data.Type.Equality

import Data.Expr.SharingRecovery


-- | Type witnesses for the object language used in these tests: a value of
-- type @Typ t@ is a runtime representation of the Haskell type @t@.
data Typ t where
  -- | Integers.
  TInt :: Typ Int
  -- | Booleans.
  TBool :: Typ Bool
  -- | Pair (product) types.
  TPair :: Typ a -> Typ b -> Typ (a, b)
  -- | Function types.
  TFun :: Typ a -> Typ b -> Typ (a -> b)
deriving instance Show (Typ t)

-- | Structural equality of type witnesses. Two witnesses are considered equal
-- exactly when they have the same shape all the way down; in that case their
-- type indices coincide and a 'Refl' proof is returned, otherwise 'Nothing'.
instance TestEquality Typ where
  testEquality TInt TInt = Just Refl
  testEquality TBool TBool = Just Refl
  testEquality (TPair l r) (TPair l' r') = do
    -- Each successful match brings one component equality into scope.
    Refl <- testEquality l l'
    Refl <- testEquality r r'
    pure Refl
  testEquality (TFun l r) (TFun l' r') = do
    Refl <- testEquality l l'
    Refl <- testEquality r r'
    pure Refl
  testEquality _ _ = Nothing

-- | Types that have a canonical 'Typ' witness, so example terms can simply
-- write 'τ' and let instance resolution build the witness.
class KnownType t where
  τ :: Typ t

instance KnownType Int where
  τ = TInt

instance KnownType Bool where
  τ = TBool

-- | Pair witnesses are assembled from the component witnesses.
instance (KnownType a, KnownType b) => KnownType (a, b) where
  τ = TPair τ τ

-- | Function witnesses are assembled from argument and result witnesses.
instance (KnownType a, KnownType b) => KnownType (a -> b) where
  τ = TFun τ τ

-- | Integer primitives of the object language. Every primitive takes its two
-- operands packed into a single pair argument, i.e. @PrimOp (Int, Int) r@.
data PrimOp a b where
  -- | Integer addition; printed as @+@.
  POAddI :: PrimOp (Int, Int) Int
  -- | Integer multiplication; printed as @*@.
  POMulI :: PrimOp (Int, Int) Int
  -- | Integer equality test; printed as @==@.
  POEqI :: PrimOp (Int, Int) Bool
deriving instance Show (PrimOp a b)

-- | How 'prettyPrimOp' renders an operator symbol: bare (e.g. @+@) when used
-- infix, or wrapped in parentheses (e.g. @(+)@) when used as a prefix function.
data Fixity = Infix | Prefix
  deriving (Show)

-- | Precedence data for printing a primitive infix, in the 'showsPrec' style:
-- @(operator precedence, (left operand context, right operand context))@.
--
-- @+@ and @*@ are left-associative (the right operand needs one level more),
-- while @==@ is non-associative (both operands need one level more).
primOpPrec :: PrimOp a b -> (Int, (Int, Int))
primOpPrec op = case op of
  POAddI -> leftAssoc 6
  POMulI -> leftAssoc 7
  POEqI -> nonAssoc 4
  where
    leftAssoc p = (p, (p, p + 1))
    nonAssoc p = (p, (p + 1, p + 1))

-- | Render the symbol of a primitive: bare for 'Infix' use, parenthesised
-- for 'Prefix' use (so it can be applied like an ordinary function).
prettyPrimOp :: Fixity -> PrimOp a b -> ShowS
prettyPrimOp fixity op =
  case fixity of
    Infix -> showString symbol
    Prefix -> showString ("(" ++ symbol ++ ")")
  where
    symbol = case op of
      POAddI -> "+"
      POMulI -> "*"
      POEqI -> "=="

-- | One layer of the arithmetic expression language. Child expressions have
-- type @r _@, left abstract so the same node type can be wrapped by different
-- recursive structures (e.g. 'BExpr', as in 'prettyArithF').
data ArithF r t where
  -- | Apply a primitive to its single argument expression (for the binary
  -- primitives in 'PrimOp', that argument is a pair).
  A_Prim :: PrimOp a b -> r a -> ArithF r b
  -- | Build a pair from two subexpressions.
  A_Pair :: r a -> r b -> ArithF r (a, b)
  -- | Conditional: condition, then-branch, else-branch.
  A_If :: r Bool -> r a -> r a -> ArithF r a
-- Showing a node requires showing children of every index type, hence the
-- quantified constraint.
deriving instance (forall a. Show (r a)) => Show (ArithF r t)

-- | The empty instance body means every 'Functor1' method comes from the
-- class defaults (presumably defined via 'Traversable1'; confirm against
-- the library).
instance Functor1 ArithF

-- | Visit the children of a node from left to right and rebuild the same
-- constructor around the results.
instance Traversable1 ArithF where
  traverse1 g = \case
    A_Prim op x -> A_Prim op <$> g x
    A_Pair x y -> A_Pair <$> g x <*> g y
    A_If c x y -> A_If <$> g c <*> g x <*> g y

-- | Pretty-print one 'ArithF' node whose children are 'BExpr's.
--
-- @pr@ renders a child at a given precedence context. @d@ is the precedence
-- of the context this node sits in, following the 'showsPrec' convention:
-- parenthesise when @d@ exceeds the node's own precedence.
--
-- * A primitive applied directly to a pair node is printed infix (@a + b@),
--   with operand contexts taken from 'primOpPrec'.
-- * Any other primitive application is printed prefix (@(+) p@) at
--   application precedence 10, with the argument at 11.
-- * Pairs print as @(a, b)@ and never need extra parentheses.
-- * Conditionals print as @if c then t else e@ and are parenthesised in any
--   non-zero context, because they extend as far right as possible.
prettyArithF :: Monad m
             => (forall a. Int -> BExpr typ env ArithF a -> m ShowS)
             -> Int -> ArithF (BExpr typ env ArithF) t -> m ShowS
prettyArithF pr d node = case node of
  A_Prim op (BOp _ (A_Pair lhs rhs)) ->
    let (opPrec, (lhsPrec, rhsPrec)) = primOpPrec op
        infixed l r = showParen (d > opPrec) $
          l . showString " " . prettyPrimOp Infix op . showString " " . r
    in infixed <$> pr lhsPrec lhs <*> pr rhsPrec rhs
  A_Prim op arg ->
    let prefixed s = showParen (d > 10) $
          prettyPrimOp Prefix op . showString " " . s
    in prefixed <$> pr 11 arg
  A_Pair x y ->
    let tuple l r = showString "(" . l . showString ", " . r . showString ")"
    in tuple <$> pr 0 x <*> pr 0 y
  A_If cond thenB elseB ->
    let ite c t e = showParen (d > 0) $
          showString "if " . c . showString " then " . t . showString " else " . e
    in ite <$> pr 0 cond <*> pr 0 thenB <*> pr 0 elseB

-- | λx. x + x
--
-- No Haskell-level sharing is written here: each operand of the addition is
-- its own 'PHOASVar' node. NOTE(review): keep the term shape as-is, since
-- sharing recovery observes how the term is built.
ea_1 :: PHOASExpr Typ v ArithF (Int -> Int)
ea_1 =
  PHOASLam τ τ $ \arg ->
    PHOASOp τ (A_Prim POAddI
      (PHOASOp τ (A_Pair (PHOASVar τ arg) (PHOASVar τ arg))))

-- | λx. let y = x + x in y * y
--
-- The Haskell @let@ binds @y@ once and uses it as both operands of the
-- multiplication, so the subterm is shared rather than duplicated.
-- NOTE(review): this sharing is presumably what 'sharingRecovery' is meant to
-- turn back into an explicit let; do not inline @y@ or the example loses its
-- point.
ea_2 :: PHOASExpr Typ v ArithF (Int -> Int)
ea_2 =
  PHOASLam τ τ $ \arg ->
    -- Bound once, referenced twice below.
    let y = PHOASOp τ (A_Prim POAddI
              (PHOASOp τ (A_Pair (PHOASVar τ arg) (PHOASVar τ arg))))
    in PHOASOp τ (A_Prim POMulI
         (PHOASOp τ (A_Pair y y)))

-- | Run sharing recovery on every example term and print the result, so both
-- the let-free example ('ea_1') and the let-sharing example ('ea_2') are
-- exercised. Previously 'ea_1' was defined but never run.
main :: IO ()
main = do
  putStrLn $ prettyBExpr prettyArithF (sharingRecovery ea_1)
  putStrLn $ prettyBExpr prettyArithF (sharingRecovery ea_2)