diff options
Diffstat (limited to 'test/Main.hs')
| -rw-r--r-- | test/Main.hs | 113 |
1 files changed, 81 insertions, 32 deletions
diff --git a/test/Main.hs b/test/Main.hs index 5a4d335..208d3d6 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -1,51 +1,100 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} -{-# LANGUAGE QuantifiedConstraints #-} -{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeApplications #-} {-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE QuantifiedConstraints #-} module Main where import Data.Expr.SharingRecovery -import Data.Expr.SharingRecovery.Internal +import Data.Type.Equality + + +data Ty a where + TInt :: Ty Int + TFloat :: Ty Float + TBool :: Ty Bool +deriving instance Show (Ty a) -import Arith +instance TestEquality Ty where + testEquality TInt TInt = Just Refl + testEquality TInt _ = Nothing + testEquality TFloat TFloat = Just Refl + testEquality TFloat _ = Nothing + testEquality TBool TBool = Just Refl + testEquality TBool _ = Nothing +type family IsOrdTy a where + IsOrdTy Int = True + IsOrdTy Float = True + IsOrdTy _ = False --- TODO: test cyclic expressions +data Unop a b where + UONeg :: Ty a -> Unop a a + UONot :: Unop Bool Bool +deriving instance Show (Unop a b) +data Binop a b c where + BOAdd :: Ty a -> Binop a a a + BOSub :: Ty a -> Binop a a a + BOMul :: Ty a -> Binop a a a + BOAnd :: Binop Bool Bool Bool + BOOr :: Binop Bool Bool Bool + BOLt :: IsOrdTy a ~ True => Ty a -> Binop a a Bool + BOLeq :: IsOrdTy a ~ True => Ty a -> Binop a a Bool + BOEq :: IsOrdTy a ~ True => Ty a -> Binop a a Bool + BONeq :: IsOrdTy a ~ True => Ty a -> Binop a a Bool +deriving instance Show (Binop a b c) -a_bin :: (KnownType a, KnownType b, KnownType c) - => PrimOp (a, b) c - -> PHOASExpr Typ v ArithF a - -> PHOASExpr Typ v ArithF b - -> PHOASExpr Typ v ArithF c -a_bin op a b = PHOASOp τ (A_Prim op (PHOASOp τ (A_Pair a b))) +data Lang r a where + Un :: Unop a b -> r a -> Lang r b + Bin :: Binop a b c -> r a -> r b -> Lang r c + Cond :: r Bool -> r a -> r a -> Lang r a + Cnst :: Show a => a -> Lang r a -- there's a type in the BExpr in the end, no need for one here +deriving instance (forall b. Show (r b)) => Show (Lang r a) -lam :: (KnownType a, KnownType b) - => (PHOASExpr Typ v f a -> PHOASExpr Typ v f b) -> PHOASExpr Typ v f (a -> b) -lam f = PHOASLam τ τ $ \arg -> f (PHOASVar τ arg) +instance Functor1 Lang +instance Traversable1 Lang where + traverse1 f = \case + Un op x -> Un op <$> f x + Bin op x y -> Bin op <$> f x <*> f y + Cond x y z -> Cond <$> f x <*> f y <*> f z + Cnst v -> pure (Cnst v) -(+!) :: PHOASExpr Typ v ArithF Int -> PHOASExpr Typ v ArithF Int -> PHOASExpr Typ v ArithF Int -(+!) = a_bin POAddI +class KnownTy a where knownTy :: Ty a +instance KnownTy Int where knownTy = TInt +instance KnownTy Float where knownTy = TFloat +instance KnownTy Bool where knownTy = TBool -(*!) :: PHOASExpr Typ v ArithF Int -> PHOASExpr Typ v ArithF Int -> PHOASExpr Typ v ArithF Int -(*!) = a_bin POMulI +type Expr v = PHOASExpr Ty v Lang --- λx. x + x -ea_1 :: PHOASExpr Typ v ArithF (Int -> Int) -ea_1 = lam $ \arg -> arg +! arg +cond :: KnownTy a => Expr v Bool -> Expr v a -> Expr v a -> Expr v a +cond a b c = PHOASOp knownTy (Cond a b c) --- λx. let y = x + x in y * y -ea_2 :: PHOASExpr Typ v ArithF (Int -> Int) -ea_2 = lam $ \arg -> let y = arg +! arg - in y *! y +(.<), (.<=), (.>), (.>=) :: (KnownTy a, IsOrdTy a ~ True) => Expr v a -> Expr v a -> Expr v Bool +a .< b = PHOASOp TBool (Bin (BOLt knownTy) a b) +a .<= b = PHOASOp TBool (Bin (BOLeq knownTy) a b) +(.>) = flip (.<) +(.>=) = flip (.<=) +infix 4 .< +infix 4 .<= +infix 4 .> +infix 4 .>= -ea_3 :: PHOASExpr Typ v ArithF (Int -> Int) -ea_3 = lam $ \arg -> - let y = arg +! arg - x = y *! arg - -- in (y +! x) +! (x +! y) - in (x +! y) +! (y +! x) +instance (KnownTy a, IsOrdTy a ~ True, Num a, Show a) => Num (Expr v a) where + a + b = PHOASOp knownTy (Bin (BOAdd knownTy) a b) + a - b = PHOASOp knownTy (Bin (BOSub knownTy) a b) + a * b = PHOASOp knownTy (Bin (BOMul knownTy) a b) + negate a = PHOASOp knownTy (Un (UONeg knownTy) a) + abs a = cond (a .< 0) (-a) a + signum a = cond (a .< 0) (-1) (cond (a .> 0) 1 0) + fromInteger n = PHOASOp knownTy (Cnst (fromInteger n)) main :: IO () -main = putStrLn $ prettyBExpr prettyArithF (sharingRecovery ea_3) +main = do + print $ sharingRecovery @Lang @_ $ + let a = 2 ; b = 3 :: Expr v Int + in a + b .< b + a |
