aboutsummaryrefslogtreecommitdiff
path: root/test/Main.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-10-03 23:05:46 +0200
committerTom Smeding <tom@tomsmeding.com>2025-10-03 23:05:46 +0200
commit8bcf09ed30595c7ffcdc41211a53046fefa35a44 (patch)
treeb3ad3b25bf198083d5ebd7a94fbbb0373edc3511 /test/Main.hs
parent4772025626d78127536c341c38052d23ca953ae3 (diff)
Add a test
Old code
Diffstat (limited to 'test/Main.hs')
-rw-r--r--test/Main.hs100
1 files changed, 100 insertions, 0 deletions
diff --git a/test/Main.hs b/test/Main.hs
new file mode 100644
index 0000000..208d3d6
--- /dev/null
+++ b/test/Main.hs
@@ -0,0 +1,100 @@
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE QuantifiedConstraints #-}
+module Main where
+
+import Data.Expr.SharingRecovery
+import Data.Type.Equality
+
+
+data Ty a where
+ TInt :: Ty Int
+ TFloat :: Ty Float
+ TBool :: Ty Bool
+deriving instance Show (Ty a)
+
+instance TestEquality Ty where
+ testEquality TInt TInt = Just Refl
+ testEquality TInt _ = Nothing
+ testEquality TFloat TFloat = Just Refl
+ testEquality TFloat _ = Nothing
+ testEquality TBool TBool = Just Refl
+ testEquality TBool _ = Nothing
+
+type family IsOrdTy a where
+ IsOrdTy Int = True
+ IsOrdTy Float = True
+ IsOrdTy _ = False
+
+data Unop a b where
+ UONeg :: Ty a -> Unop a a
+ UONot :: Unop Bool Bool
+deriving instance Show (Unop a b)
+
+data Binop a b c where
+ BOAdd :: Ty a -> Binop a a a
+ BOSub :: Ty a -> Binop a a a
+ BOMul :: Ty a -> Binop a a a
+ BOAnd :: Binop Bool Bool Bool
+ BOOr :: Binop Bool Bool Bool
+ BOLt :: IsOrdTy a ~ True => Ty a -> Binop a a Bool
+ BOLeq :: IsOrdTy a ~ True => Ty a -> Binop a a Bool
+ BOEq :: IsOrdTy a ~ True => Ty a -> Binop a a Bool
+ BONeq :: IsOrdTy a ~ True => Ty a -> Binop a a Bool
+deriving instance Show (Binop a b c)
+
+data Lang r a where
+ Un :: Unop a b -> r a -> Lang r b
+ Bin :: Binop a b c -> r a -> r b -> Lang r c
+ Cond :: r Bool -> r a -> r a -> Lang r a
+ Cnst :: Show a => a -> Lang r a -- there's a type in the BExpr in the end, no need for one here
+deriving instance (forall b. Show (r b)) => Show (Lang r a)
+
+instance Functor1 Lang
+instance Traversable1 Lang where
+ traverse1 f = \case
+ Un op x -> Un op <$> f x
+ Bin op x y -> Bin op <$> f x <*> f y
+ Cond x y z -> Cond <$> f x <*> f y <*> f z
+ Cnst v -> pure (Cnst v)
+
+class KnownTy a where knownTy :: Ty a
+instance KnownTy Int where knownTy = TInt
+instance KnownTy Float where knownTy = TFloat
+instance KnownTy Bool where knownTy = TBool
+
+type Expr v = PHOASExpr Ty v Lang
+
+cond :: KnownTy a => Expr v Bool -> Expr v a -> Expr v a -> Expr v a
+cond a b c = PHOASOp knownTy (Cond a b c)
+
+(.<), (.<=), (.>), (.>=) :: (KnownTy a, IsOrdTy a ~ True) => Expr v a -> Expr v a -> Expr v Bool
+a .< b = PHOASOp TBool (Bin (BOLt knownTy) a b)
+a .<= b = PHOASOp TBool (Bin (BOLeq knownTy) a b)
+(.>) = flip (.<)
+(.>=) = flip (.<=)
+infix 4 .<
+infix 4 .<=
+infix 4 .>
+infix 4 .>=
+
+instance (KnownTy a, IsOrdTy a ~ True, Num a, Show a) => Num (Expr v a) where
+ a + b = PHOASOp knownTy (Bin (BOAdd knownTy) a b)
+ a - b = PHOASOp knownTy (Bin (BOSub knownTy) a b)
+ a * b = PHOASOp knownTy (Bin (BOMul knownTy) a b)
+ negate a = PHOASOp knownTy (Un (UONeg knownTy) a)
+ abs a = cond (a .< 0) (-a) a
+ signum a = cond (a .< 0) (-1) (cond (a .> 0) 1 0)
+ fromInteger n = PHOASOp knownTy (Cnst (fromInteger n))
+
+main :: IO ()
+main = do
+ print $ sharingRecovery @Lang @_ $
+ let a = 2 ; b = 3 :: Expr v Int
+ in a + b .< b + a