From 5a0ce21e12e765125ad8068e919cf97b70df8257 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Wed, 28 Aug 2024 16:10:58 +0200 Subject: Implement sorting of floated expressions --- test/Arith.hs | 103 +++++++++++++++++++++++++++++++++++++++++++++++++++ test/Main.hs | 116 +++++++++++++--------------------------------------------- 2 files changed, 129 insertions(+), 90 deletions(-) create mode 100644 test/Arith.hs (limited to 'test') diff --git a/test/Arith.hs b/test/Arith.hs new file mode 100644 index 0000000..c34baa8 --- /dev/null +++ b/test/Arith.hs @@ -0,0 +1,103 @@ +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE QuantifiedConstraints #-} +module Arith where + +import Data.Type.Equality + +import Data.Expr.SharingRecovery + + +data Typ t where + TInt :: Typ Int + TBool :: Typ Bool + TPair :: Typ a -> Typ b -> Typ (a, b) + TFun :: Typ a -> Typ b -> Typ (a -> b) +deriving instance Show (Typ t) + +instance TestEquality Typ where + testEquality TInt TInt = Just Refl + testEquality TBool TBool = Just Refl + testEquality (TPair a b) (TPair a' b') + | Just Refl <- testEquality a a' + , Just Refl <- testEquality b b' + = Just Refl + testEquality (TFun a b) (TFun a' b') + | Just Refl <- testEquality a a' + , Just Refl <- testEquality b b' + = Just Refl + testEquality _ _ = Nothing + +class KnownType t where τ :: Typ t +instance KnownType Int where τ = TInt +instance KnownType Bool where τ = TBool +instance (KnownType a, KnownType b) => KnownType (a, b) where τ = TPair τ τ +instance (KnownType a, KnownType b) => KnownType (a -> b) where τ = TFun τ τ + +data PrimOp a b where + POAddI :: PrimOp (Int, Int) Int + POMulI :: PrimOp (Int, Int) Int + POEqI :: PrimOp (Int, Int) Bool +deriving instance Show (PrimOp a b) + +opType2 :: PrimOp a b -> Typ b +opType2 = \case + POAddI -> TInt + POMulI -> TInt + POEqI -> TBool + +data Fixity = Infix | Prefix + deriving (Show) + +primOpPrec :: PrimOp a b -> (Int, (Int, Int)) +primOpPrec POAddI = (6, (6, 7)) +primOpPrec POMulI = (7, (7, 8)) +primOpPrec POEqI = (4, (5, 5)) + +prettyPrimOp :: Fixity -> PrimOp a b -> ShowS +prettyPrimOp fix op = + let s = case op of + POAddI -> "+" + POMulI -> "*" + POEqI -> "==" + in showString $ case fix of + Infix -> s + Prefix -> "(" ++ s ++ ")" + +data ArithF r t where + A_Prim :: PrimOp a b -> r a -> ArithF r b + A_Pair :: r a -> r b -> ArithF r (a, b) + A_If :: r Bool -> r a -> r a -> ArithF r a +deriving instance (forall a. Show (r a)) => Show (ArithF r t) + +instance Functor1 ArithF +instance Traversable1 ArithF where + traverse1 f (A_Prim op x) = A_Prim op <$> f x + traverse1 f (A_Pair x y) = A_Pair <$> f x <*> f y + traverse1 f (A_If x y z) = A_If <$> f x <*> f y <*> f z + +prettyArithF :: Monad m + => (forall a. Int -> BExpr Typ env ArithF a -> m ShowS) + -> Int -> ArithF (BExpr Typ env ArithF) t -> m ShowS +prettyArithF pr d = \case + A_Prim op (BOp _ (A_Pair a b)) -> do + let (dop, (dopL, dopR)) = primOpPrec op + a' <- pr dopL a + b' <- pr dopR b + return $ showParen (d > dop) $ a' . showString " " . prettyPrimOp Infix op . showString " " . b' + A_Prim op (BLet ty rhs e) -> + pr d (BLet ty rhs (BOp (opType2 op) (A_Prim op e))) + A_Prim op arg -> do + arg' <- pr 11 arg + return $ showParen (d > 10) $ prettyPrimOp Prefix op . showString " " . arg' + A_Pair a b -> do + a' <- pr 0 a + b' <- pr 0 b + return $ showString "(" . a' . showString ", " . b' . showString ")" + A_If a b c -> do + a' <- pr 0 a + b' <- pr 0 b + c' <- pr 0 c + return $ showParen (d > 0) $ showString "if " . a' . showString " then " . b' . showString " else " . c' diff --git a/test/Main.hs b/test/Main.hs index e7b303b..1a8d8e1 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -5,110 +5,46 @@ {-# LANGUAGE StandaloneDeriving #-} module Main where -import Data.Type.Equality - import Data.Expr.SharingRecovery +import Arith -data Typ t where - TInt :: Typ Int - TBool :: Typ Bool - TPair :: Typ a -> Typ b -> Typ (a, b) - TFun :: Typ a -> Typ b -> Typ (a -> b) -deriving instance Show (Typ t) - -instance TestEquality Typ where - testEquality TInt TInt = Just Refl - testEquality TBool TBool = Just Refl - testEquality (TPair a b) (TPair a' b') - | Just Refl <- testEquality a a' - , Just Refl <- testEquality b b' - = Just Refl - testEquality (TFun a b) (TFun a' b') - | Just Refl <- testEquality a a' - , Just Refl <- testEquality b b' - = Just Refl - testEquality _ _ = Nothing - -class KnownType t where τ :: Typ t -instance KnownType Int where τ = TInt -instance KnownType Bool where τ = TBool -instance (KnownType a, KnownType b) => KnownType (a, b) where τ = TPair τ τ -instance (KnownType a, KnownType b) => KnownType (a -> b) where τ = TFun τ τ -data PrimOp a b where - POAddI :: PrimOp (Int, Int) Int - POMulI :: PrimOp (Int, Int) Int - POEqI :: PrimOp (Int, Int) Bool -deriving instance Show (PrimOp a b) +-- TODO: test cyclic expressions -data Fixity = Infix | Prefix - deriving (Show) -primOpPrec :: PrimOp a b -> (Int, (Int, Int)) -primOpPrec POAddI = (6, (6, 7)) -primOpPrec POMulI = (7, (7, 8)) -primOpPrec POEqI = (4, (5, 5)) +a_bin :: (KnownType a, KnownType b, KnownType c) + => PrimOp (a, b) c + -> PHOASExpr Typ v ArithF a + -> PHOASExpr Typ v ArithF b + -> PHOASExpr Typ v ArithF c +a_bin op a b = PHOASOp τ (A_Prim op (PHOASOp τ (A_Pair a b))) -prettyPrimOp :: Fixity -> PrimOp a b -> ShowS -prettyPrimOp fix op = - let s = case op of - POAddI -> "+" - POMulI -> "*" - POEqI -> "==" - in showString $ case fix of - Infix -> s - Prefix -> "(" ++ s ++ ")" +lam :: (KnownType a, KnownType b) + => (PHOASExpr Typ v f a -> PHOASExpr Typ v f b) -> PHOASExpr Typ v f (a -> b) +lam f = PHOASLam τ τ $ \arg -> f (PHOASVar τ arg) -data ArithF r t where - A_Prim :: PrimOp a b -> r a -> ArithF r b - A_Pair :: r a -> r b -> ArithF r (a, b) - A_If :: r Bool -> r a -> r a -> ArithF r a -deriving instance (forall a. Show (r a)) => Show (ArithF r t) +(+!) :: PHOASExpr Typ v ArithF Int -> PHOASExpr Typ v ArithF Int -> PHOASExpr Typ v ArithF Int +(+!) = a_bin POAddI -instance Functor1 ArithF -instance Traversable1 ArithF where - traverse1 f (A_Prim op x) = A_Prim op <$> f x - traverse1 f (A_Pair x y) = A_Pair <$> f x <*> f y - traverse1 f (A_If x y z) = A_If <$> f x <*> f y <*> f z - -prettyArithF :: Monad m - => (forall a. Int -> BExpr typ env ArithF a -> m ShowS) - -> Int -> ArithF (BExpr typ env ArithF) t -> m ShowS -prettyArithF pr d = \case - A_Prim op (BOp _ (A_Pair a b)) -> do - let (dop, (dopL, dopR)) = primOpPrec op - a' <- pr dopL a - b' <- pr dopR b - return $ showParen (d > dop) $ a' . showString " " . prettyPrimOp Infix op . showString " " . b' - A_Prim op arg -> do - arg' <- pr 11 arg - return $ showParen (d > 10) $ prettyPrimOp Prefix op . showString " " . arg' - A_Pair a b -> do - a' <- pr 0 a - b' <- pr 0 b - return $ showString "(" . a' . showString ", " . b' . showString ")" - A_If a b c -> do - a' <- pr 0 a - b' <- pr 0 b - c' <- pr 0 c - return $ showParen (d > 0) $ showString "if " . a' . showString " then " . b' . showString " else " . c' +(*!) :: PHOASExpr Typ v ArithF Int -> PHOASExpr Typ v ArithF Int -> PHOASExpr Typ v ArithF Int +(*!) = a_bin POMulI -- λx. x + x ea_1 :: PHOASExpr Typ v ArithF (Int -> Int) -ea_1 = - PHOASLam τ τ $ \arg -> - PHOASOp τ (A_Prim POAddI - (PHOASOp τ (A_Pair (PHOASVar τ arg) (PHOASVar τ arg)))) +ea_1 = lam $ \arg -> arg +! arg -- λx. let y = x + x in y * y ea_2 :: PHOASExpr Typ v ArithF (Int -> Int) -ea_2 = - PHOASLam τ τ $ \arg -> - let y = PHOASOp τ (A_Prim POAddI - (PHOASOp τ (A_Pair (PHOASVar τ arg) (PHOASVar τ arg)))) - in PHOASOp τ (A_Prim POMulI - (PHOASOp τ (A_Pair y y))) +ea_2 = lam $ \arg -> let y = arg +! arg + in y *! y + +ea_3 :: PHOASExpr Typ v ArithF (Int -> Int) +ea_3 = lam $ \arg -> + let y = arg +! arg + x = y *! arg + -- in (y +! x) +! (x +! y) + in (x +! y) +! (y +! x) main :: IO () -main = putStrLn $ prettyBExpr prettyArithF (sharingRecovery ea_2) +main = putStrLn $ prettyBExpr prettyArithF (sharingRecovery ea_3) -- cgit v1.2.3-70-g09d2