aboutsummaryrefslogtreecommitdiff
path: root/test/Main.hs
blob: 208d3d6caa061a9d55eae4a5bdc75fdbe982e58e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE QuantifiedConstraints #-}
module Main where

import Data.Expr.SharingRecovery
import Data.Type.Equality


-- | Singleton tags for the scalar types of the test language. Pattern
-- matching on a tag refines the type index to the Haskell type it names.
data Ty t where
  TInt   :: Ty Int
  TFloat :: Ty Float
  TBool  :: Ty Bool

deriving instance Show (Ty t)

-- | Two tags are equal exactly when they are the same constructor. In that
-- case the result carries a 'Refl' proof that the indexed types coincide.
-- Each constructor gets its own "anything else" arm instead of one global
-- catch-all. That keeps the match exhaustive per constructor, so a newly
-- added 'Ty' constructor is reported by GHC's incomplete-pattern check.
instance TestEquality Ty where
  testEquality lhs rhs = case (lhs, rhs) of
    (TInt, TInt)     -> Just Refl
    (TInt, _)        -> Nothing
    (TFloat, TFloat) -> Just Refl
    (TFloat, _)      -> Nothing
    (TBool, TBool)   -> Just Refl
    (TBool, _)       -> Nothing

-- | Type-level predicate: does the scalar type support ordering comparisons?
-- 'Int' and 'Float' do. Every other type, in particular 'Bool', does not.
type family IsOrdTy t where
  IsOrdTy Int   = 'True
  IsOrdTy Float = 'True
  IsOrdTy _     = 'False

-- | Unary operators, indexed by operand type and result type.
data Unop arg res where
  -- | Negation; carries the type tag of its operand (which is also the
  -- result type).
  UONeg :: Ty t -> Unop t t
  -- | Boolean negation.
  UONot :: Unop Bool Bool

deriving instance Show (Unop arg res)

-- | Binary operators, indexed by the two operand types and the result type.
-- Operators over a shared operand type carry that type's tag. The comparison
-- operators also require the operand type to be orderable ('IsOrdTy').
--
-- NOTE(review): 'BOEq' and 'BONeq' also require 'IsOrdTy', so equality on
-- 'Bool' cannot be expressed. Confirm this restriction is intended.
data Binop lhs rhs res where
  BOAdd :: Ty t -> Binop t t t
  BOSub :: Ty t -> Binop t t t
  BOMul :: Ty t -> Binop t t t
  BOAnd :: Binop Bool Bool Bool
  BOOr  :: Binop Bool Bool Bool
  BOLt  :: IsOrdTy t ~ 'True => Ty t -> Binop t t Bool
  BOLeq :: IsOrdTy t ~ 'True => Ty t -> Binop t t Bool
  BOEq  :: IsOrdTy t ~ 'True => Ty t -> Binop t t Bool
  BONeq :: IsOrdTy t ~ 'True => Ty t -> Binop t t Bool

deriving instance Show (Binop lhs rhs res)

-- | One node of the expression language. Sub-expressions have type @r _@.
-- The recursion is left open, and 'Expr' instantiates @r@ with the library's
-- expression type.
data Lang r t where
  -- | Apply a unary operator to one sub-expression.
  Un   :: Unop x t -> r x -> Lang r t
  -- | Apply a binary operator to two sub-expressions.
  Bin  :: Binop x y t -> r x -> r y -> Lang r t
  -- | Conditional: a boolean guard followed by two same-typed branches.
  Cond :: r Bool -> r t -> r t -> Lang r t
  -- | A literal constant. It needs no 'Ty' tag of its own: the type is
  -- already recorded on the enclosing expression node (the BExpr at the end).
  Cnst :: Show t => t -> Lang r t

deriving instance (forall x. Show (r x)) => Show (Lang r t)

-- | No methods are given here; 'Functor1' relies on its class defaults
-- (presumably defined via 'Traversable1'; confirm in
-- Data.Expr.SharingRecovery).
instance Functor1 Lang

-- | Apply the effectful visitor to each direct sub-expression, sequencing the
-- effects left to right, and rebuild the same node around the results.
-- Constants have no sub-expressions and are returned as-is.
instance Traversable1 Lang where
  traverse1 visit (Un op x)     = Un op <$> visit x
  traverse1 visit (Bin op x y)  = Bin op <$> visit x <*> visit y
  traverse1 visit (Cond c t e)  = Cond <$> visit c <*> visit t <*> visit e
  traverse1 _     (Cnst v)      = pure (Cnst v)

-- | Scalar types whose 'Ty' tag can be obtained implicitly from a constraint.
class KnownTy t where
  knownTy :: Ty t

instance KnownTy Int where
  knownTy = TInt

instance KnownTy Float where
  knownTy = TFloat

instance KnownTy Bool where
  knownTy = TBool

-- | Expressions of the test language: the library's 'PHOASExpr' with 'Ty'
-- type tags and 'Lang' nodes. @Expr var t@ is an expression of type @t@.
type Expr var = PHOASExpr Ty var Lang

-- | Build a conditional node from a boolean guard and two branches. The
-- node's type tag is taken from the 'KnownTy' instance of the branch type.
cond :: KnownTy a => Expr v Bool -> Expr v a -> Expr v a -> Expr v a
cond guard whenTrue whenFalse = PHOASOp knownTy $ Cond guard whenTrue whenFalse

-- | Ordering comparisons between expressions of an orderable scalar type.
-- 'Binop' only has "less than" and "less or equal" operators. '.>' and '.>='
-- are therefore encoded by swapping the operands of '.<' and '.<='.
(.<), (.<=), (.>), (.>=) :: (KnownTy a, IsOrdTy a ~ True) => Expr v a -> Expr v a -> Expr v Bool
lhs .< rhs  = PHOASOp TBool $ Bin (BOLt knownTy) lhs rhs
lhs .<= rhs = PHOASOp TBool $ Bin (BOLeq knownTy) lhs rhs
lhs .> rhs  = rhs .< lhs
lhs .>= rhs = rhs .<= lhs

infix 4 .<, .<=, .>, .>=

-- | Numeric literals and arithmetic build expression nodes instead of
-- computing values. 'abs' and 'signum' are expressed with 'cond' and the
-- comparison operators, which is why the type must satisfy 'IsOrdTy'.
instance (KnownTy a, IsOrdTy a ~ True, Num a, Show a) => Num (Expr v a) where
  x + y = PHOASOp knownTy $ Bin (BOAdd knownTy) x y
  x - y = PHOASOp knownTy $ Bin (BOSub knownTy) x y
  x * y = PHOASOp knownTy $ Bin (BOMul knownTy) x y
  negate x = PHOASOp knownTy $ Un (UONeg knownTy) x
  abs x = cond (x .< 0) (negate x) x
  signum x =
    cond (x .< 0) (-1) $
      cond (x .> 0) 1 0
  -- A literal becomes a constant node holding the converted value.
  fromInteger = PHOASOp knownTy . Cnst . fromInteger

-- | Run sharing recovery on a small comparison expression and print the
-- result.
main :: IO ()
main = print $ sharingRecovery @Lang @_ example
  where
    -- Each literal is used twice, on both sides of the comparison.
    example :: Expr v Bool
    example =
      let two = 2
          three = 3 :: Expr v Int
      in two + three .< three + two