aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <t.j.smeding@uu.nl>2024-08-28 16:10:58 +0200
committerTom Smeding <t.j.smeding@uu.nl>2024-08-28 16:10:58 +0200
commit5a0ce21e12e765125ad8068e919cf97b70df8257 (patch)
treeed38cf21945c6a2b0434c23a35b3136935dbaf0e
parent869be329dd05eede1dd1adb3c3b6ce2340074818 (diff)
Implement sorting of floated expressions
-rw-r--r--sharing-recovery.cabal2
-rw-r--r--src/Data/Expr/SharingRecovery.hs172
-rw-r--r--src/Data/StableName/Extra.hs27
-rw-r--r--test/Arith.hs103
-rw-r--r--test/Main.hs116
5 files changed, 288 insertions, 132 deletions
diff --git a/sharing-recovery.cabal b/sharing-recovery.cabal
index df63c42..b74283b 100644
--- a/sharing-recovery.cabal
+++ b/sharing-recovery.cabal
@@ -23,6 +23,8 @@ library
test-suite test
type: exitcode-stdio-1.0
main-is: Main.hs
+ other-modules:
+ Arith
hs-source-dirs: test
build-depends:
sharing-recovery,
diff --git a/src/Data/Expr/SharingRecovery.hs b/src/Data/Expr/SharingRecovery.hs
index cdb64eb..f9d27e6 100644
--- a/src/Data/Expr/SharingRecovery.hs
+++ b/src/Data/Expr/SharingRecovery.hs
@@ -17,7 +17,7 @@ module Data.Expr.SharingRecovery where
import Control.Applicative ((<|>))
import Control.Monad.Trans.State.Strict
-import Data.Bifunctor (second)
+import Data.Bifunctor (first, second)
import Data.Char (chr, ord)
import Data.Functor.Const
import Data.Functor.Identity
@@ -25,12 +25,17 @@ import Data.Functor.Product
import Data.Hashable
import Data.HashMap.Strict (HashMap)
import qualified Data.HashMap.Strict as HM
+import Data.List (sortBy, intersperse)
+import Data.Maybe (fromMaybe)
+import Data.Ord (comparing)
import Data.Some
import Data.Type.Equality
import GHC.StableName
import Numeric.Natural
import Unsafe.Coerce (unsafeCoerce)
+-- import Debug.Trace
+
import Data.StableName.Extra
@@ -41,6 +46,16 @@ import Data.StableName.Extra
-- is a good opportunity to try to do better.
+withMoreState :: Functor m => b -> StateT (s, b) m a -> StateT s m (a, b)
+withMoreState b0 (StateT f) =
+ StateT $ \s -> (\(x, (s2, b)) -> ((x, b), s2)) <$> f (s, b0)
+
+withLessState :: Functor m => (s -> (s', b)) -> (s' -> b -> s) -> StateT s' m a -> StateT s m a
+withLessState split restore (StateT f) =
+ StateT $ \s -> let (s', b) = split s
+ in second (flip restore b) <$> f s'
+
+
class Functor1 f where
fmap1 :: (forall b. g b -> h b) -> f g a -> f h a
@@ -91,49 +106,85 @@ instance Eq (SomeNameFor typ f) where
instance Hashable (SomeNameFor typ f) where
hashWithSalt salt (SomeNameFor name) = hashWithSalt salt name
--- | The number of times a particular name is visited in a preorder traversal
--- of the PHOAS expression, excluding children of nodes upon second or later
--- visit. That is to say: only the nodes that are visited in a preorder
--- traversal that skips repeated subtrees, are counted.
-type OccMap typ f = HashMap (SomeNameFor typ f) Natural
+prettyPExpr :: Traversable1 f => Int -> PExpr typ f t -> ShowS
+prettyPExpr d = \case
+ PStub (NameFor name) _ -> showString (showStableName name)
+ POp (NameFor name) _ args ->
+ let (argslist, _) = traverse1 (\arg -> ([Some arg], Const ())) args
+ argslist' = map (\(Some arg) -> prettyPExpr 0 arg) argslist
+ in showParen (d > 10) $
+ showString ("<" ++ showStableName name ++ ">(")
+ . foldr (.) id (intersperse (showString ", ") argslist')
+ . showString ")"
+ PLam (NameFor name) _ _ (Tag tag) body ->
+ showParen (d > 0) $
+ showString ("λ" ++ showStableName name ++ " x" ++ show tag ++ ". ") . prettyPExpr 0 body
+ PVar _ (Tag tag) -> showString ("x" ++ show tag)
+
+-- | For each name:
+--
+-- 1. The number of times the name is visited in a preorder traversal of the
+-- PHOAS expression, excluding children of nodes upon second or later visit.
+-- That is to say: only the nodes that are visited in a preorder traversal
+-- that skips repeated subtrees, are counted.
+-- 2. The height of the expression indicated by the name.
+--
+-- Missing names have not been seen yet, and have unknown height.
+type OccMap typ f = HashMap (SomeNameFor typ f) (Natural, Natural)
pruneExpr :: Traversable1 f => (forall v. PHOASExpr typ v f t) -> (OccMap typ f, PExpr typ f t)
pruneExpr term =
- let (term', (_, mp)) = runState (pruneExpr' term) (0, mempty)
+ let ((term', _), (_, mp)) = runState (pruneExpr' term) (0, mempty)
in (mp, term')
-pruneExpr' :: Traversable1 f => PHOASExpr typ Tag f t -> State (Natural, OccMap typ f) (PExpr typ f t)
+-- | Returns pruned expression with its height.
+pruneExpr' :: Traversable1 f => PHOASExpr typ Tag f t -> State (Natural, OccMap typ f) (PExpr typ f t, Natural)
pruneExpr' = \case
orig@(PHOASOp ty args) -> do
let name = makeStableName' orig
- seenBefore <- checkVisited name
- if seenBefore
- then pure $ PStub (NameFor name) ty
- else POp (NameFor name) ty <$> traverse1 pruneExpr' args
+ mheight <- gets (fmap snd . HM.lookup (SomeNameFor (NameFor name)) . snd)
+ case mheight of
+ -- already visited
+ Just height -> do
+ modify (second (HM.adjust (first (+1)) (SomeNameFor (NameFor name))))
+ pure (PStub (NameFor name) ty, height)
+ -- first visit
+ Nothing -> do
+ -- Traverse the arguments, collecting the maximum height in an
+ -- additional piece of state.
+ (args', maxhei) <-
+ withMoreState 0 $
+ traverse1 (\arg -> do
+ (arg', hei) <- withLessState id (,) (pruneExpr' arg)
+ modify (second (hei `max`))
+ return arg')
+ args
+ -- Record this node
+ modify (second (HM.insert (SomeNameFor (NameFor name)) (1, 1 + maxhei)))
+ pure (POp (NameFor name) ty args', 1 + maxhei)
orig@(PHOASLam tyf tyarg f) -> do
let name = makeStableName' orig
- seenBefore <- checkVisited name
- if seenBefore
- then pure $ PStub (NameFor name) tyf
- else do
- tag <- state (\(i, mp) -> (Tag i, (i + 1, mp)))
+ mheight <- gets (fmap snd . HM.lookup (SomeNameFor (NameFor name)) . snd)
+ case mheight of
+ -- already visited
+ Just height -> do
+ modify (second (HM.adjust (first (+1)) (SomeNameFor (NameFor name))))
+ pure (PStub (NameFor name) tyf, height)
+ -- first visit
+ Nothing -> do
+ tag <- Tag <$> gets fst
+ modify (first (+1))
let body = f tag
- PLam (NameFor name) tyf tyarg tag <$> pruneExpr' body
+ (body', bodyhei) <- pruneExpr' body
+ modify (second (HM.insert (SomeNameFor (NameFor name)) (1, 1 + bodyhei)))
+ pure (PLam (NameFor name) tyf tyarg tag body', 1 + bodyhei)
- PHOASVar ty tag -> pure $ PVar ty tag
- where
- checkVisited name = do
- occmap <- gets snd
- let (seenBefore, occmap') =
- HM.alterF (\case Nothing -> (False, Just 1)
- Just n -> (True, Just (n + 1)))
- (SomeNameFor (NameFor name))
- occmap
- modify (second (const occmap'))
- return seenBefore
+ PHOASVar ty tag -> pure (PVar ty tag, 1)
+-- TODO: Replace "lift" with "float"
+
-- | Lifted expression: a bunch of to-be let bound expressions on top of an
-- LExpr'. Because LExpr' is really just PExpr with the recursive positions
-- replaced by LExpr, LExpr should be seen as PExpr with a bunch of to-be let
@@ -145,12 +196,35 @@ data LExpr' typ f t where -- TODO: this could be an instantiation of (a general
LLam :: NameFor typ f (a -> b) -> typ (a -> b) -> typ a -> Tag a -> LExpr typ f b -> LExpr' typ f (a -> b)
LVar :: typ a -> Tag a -> LExpr' typ f a
+prettyLExpr :: Traversable1 f => Int -> LExpr typ f t -> ShowS
+prettyLExpr d (LExpr [] e) = prettyLExpr' d e
+prettyLExpr d (LExpr lifted e) =
+ showString "["
+ . foldr (.) id (intersperse (showString ", ") (map (\(Some e') -> prettyLExpr 0 e') lifted))
+ . showString "] " . prettyLExpr' d e
+
+prettyLExpr' :: Traversable1 f => Int -> LExpr' typ f t -> ShowS
+prettyLExpr' d = \case
+ LStub (NameFor name) _ -> showString (showStableName name)
+ LOp (NameFor name) _ args ->
+ let (argslist, _) = traverse1 (\arg -> ([Some arg], Const ())) args
+ argslist' = map (\(Some arg) -> prettyLExpr 0 arg) argslist
+ in showParen (d > 10) $
+ showString ("<" ++ showStableName name ++ ">(")
+ . foldr (.) id (intersperse (showString ", ") argslist')
+ . showString ")"
+ LLam (NameFor name) _ _ (Tag tag) body ->
+ showParen (d > 0) $
+ showString ("λ" ++ showStableName name ++ " x" ++ show tag ++ ". ") . prettyLExpr 0 body
+ LVar _ (Tag tag) -> showString ("x" ++ show tag)
+
liftExpr :: Traversable1 f => OccMap typ f -> PExpr typ f t -> LExpr typ f t
liftExpr totals term = snd (liftExpr' totals term)
newtype FoundMap typ f = FoundMap
- (HashMap (SomeNameFor typ f) (Natural -- how many times seen
- ,Maybe (Some (LExpr typ f)))) -- the lifted subterm (once seen)
+ (HashMap (SomeNameFor typ f)
+ (Natural -- how many times seen
+ ,Maybe (Some (LExpr typ f), Natural))) -- the lifted subterm with its height (once seen)
instance Semigroup (FoundMap typ f) where
FoundMap m1 <> FoundMap m2 = FoundMap $
@@ -161,10 +235,13 @@ instance Monoid (FoundMap typ f) where
liftExpr' :: Traversable1 f => OccMap typ f -> PExpr typ f t -> (FoundMap typ f, LExpr typ f t)
liftExpr' _totals (PStub name ty) =
- (FoundMap $ HM.singleton (SomeNameFor name) (1, Nothing) -- Just (Some (LExpr [] (LStub name)))
+ -- trace ("Found stub: " ++ (case name of NameFor n -> showStableName n)) $
+ (FoundMap $ HM.singleton (SomeNameFor name) (1, Nothing)
,LExpr [] (LStub name ty))
-liftExpr' _totals (PVar ty tag) = (mempty, LExpr [] (LVar ty tag))
+liftExpr' _totals (PVar ty tag) =
+ -- trace ("Found var: " ++ show tag) $
+ (mempty, LExpr [] (LVar ty tag))
liftExpr' totals term =
let (FoundMap foundmap, name, termty, term') = case term of
@@ -178,19 +255,24 @@ liftExpr' totals term =
-- TODO: perhaps this HM.toList together with the foldr HM.delete can be a single traversal of the HashMap
saturated = [case mterm of
Just t -> (nm, t)
- Nothing -> error "Name saturated but no term found"
+ Nothing -> case nm of
+ SomeNameFor (NameFor n) ->
+ error $ "Name saturated (count=" ++ show count ++ ", totalcount=" ++ show totalcount ++ ") but no term found: " ++ showStableName n
| (nm, (count, mterm)) <- HM.toList foundmap
- , count == HM.findWithDefault 0 nm totals]
+ , let totalcount = fromMaybe 0 (fst <$> HM.lookup nm totals)
+ , count == totalcount]
foundmap' = foldr HM.delete foundmap (map fst saturated)
- lterm = LExpr (map snd saturated) term'
+ lterm = LExpr (map fst (sortBy (comparing snd) (map snd saturated))) term'
- in case HM.findWithDefault 0 (SomeNameFor name) totals of
- 1 -> (FoundMap foundmap', lterm)
- tot | tot > 1 -> (FoundMap (HM.insert (SomeNameFor name) (1, Just (Some lterm)) foundmap')
- ,LExpr [] (LStub name termty))
- | otherwise -> error "Term does not exist, yet we have it in hand"
+ in case HM.findWithDefault (0, undefined) (SomeNameFor name) totals of
+ (1, _) -> (FoundMap foundmap', lterm)
+ (tot, height)
+ | tot > 1 -> -- trace ("Inserting " ++ (case name of NameFor n -> showStableName n) ++ " into foundmap") $
+ (FoundMap (HM.insert (SomeNameFor name) (1, Just (Some lterm, height)) foundmap')
+ ,LExpr [] (LStub name termty))
+ | otherwise -> error "Term does not exist, yet we have it in hand"
-- | Untyped De Bruijn expression. No more names: there are lets now, and
@@ -334,4 +416,10 @@ retypeExpr' env (UBVar ty idx) =
sharingRecovery :: (Traversable1 f, TestEquality typ) => (forall v. PHOASExpr typ v f t) -> BExpr typ '[] f t
-sharingRecovery e = retypeExpr $ lowerExpr $ uncurry liftExpr $ pruneExpr e
+sharingRecovery e =
+ let (occmap, pexpr) = pruneExpr e
+ lexpr = liftExpr occmap pexpr
+ ubexpr = lowerExpr lexpr
+ in -- trace ("PExpr: " ++ prettyPExpr 0 pexpr "") $
+ -- trace ("LExpr: " ++ prettyLExpr 0 lexpr "") $
+ retypeExpr ubexpr
diff --git a/src/Data/StableName/Extra.hs b/src/Data/StableName/Extra.hs
index f568740..cf37cfe 100644
--- a/src/Data/StableName/Extra.hs
+++ b/src/Data/StableName/Extra.hs
@@ -1,10 +1,16 @@
{-# LANGUAGE BangPatterns #-}
+{-# LANGUAGE ExistentialQuantification #-}
{-# OPTIONS_GHC -fno-full-laziness -fno-cse #-}
module Data.StableName.Extra (
StableName,
makeStableName',
+ showStableName,
) where
+import Data.Hashable
+import Data.HashMap.Strict (HashMap)
+import qualified Data.HashMap.Strict as HM
+import Data.IORef
import GHC.StableName
import System.IO.Unsafe
@@ -15,3 +21,24 @@ import System.IO.Unsafe
{-# NOINLINE makeStableName' #-}
makeStableName' :: a -> StableName a
makeStableName' !x = unsafePerformIO (makeStableName x)
+
+
+data SomeStableName = forall a. SomeStableName (StableName a)
+
+instance Eq SomeStableName where
+ SomeStableName a == SomeStableName b = eqStableName a b
+
+instance Hashable SomeStableName where
+ hashWithSalt salt (SomeStableName name) = hashWithSalt salt name
+
+{-# NOINLINE showStableNameCache #-}
+showStableNameCache :: IORef (HashMap SomeStableName Int, Int)
+showStableNameCache = unsafePerformIO $ newIORef (mempty, 0)
+
+{-# NOINLINE showStableName #-}
+showStableName :: StableName a -> String
+showStableName name =
+ unsafePerformIO $ atomicModifyIORef' showStableNameCache $ \tup@(mp, nexti) ->
+ case HM.lookup (SomeStableName name) mp of
+ Just res -> (tup, '$' : show res)
+ Nothing -> ((HM.insert (SomeStableName name) nexti mp, nexti + 1), '$' : show nexti)
diff --git a/test/Arith.hs b/test/Arith.hs
new file mode 100644
index 0000000..c34baa8
--- /dev/null
+++ b/test/Arith.hs
@@ -0,0 +1,103 @@
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE QuantifiedConstraints #-}
+module Arith where
+
+import Data.Type.Equality
+
+import Data.Expr.SharingRecovery
+
+
+data Typ t where
+ TInt :: Typ Int
+ TBool :: Typ Bool
+ TPair :: Typ a -> Typ b -> Typ (a, b)
+ TFun :: Typ a -> Typ b -> Typ (a -> b)
+deriving instance Show (Typ t)
+
+instance TestEquality Typ where
+ testEquality TInt TInt = Just Refl
+ testEquality TBool TBool = Just Refl
+ testEquality (TPair a b) (TPair a' b')
+ | Just Refl <- testEquality a a'
+ , Just Refl <- testEquality b b'
+ = Just Refl
+ testEquality (TFun a b) (TFun a' b')
+ | Just Refl <- testEquality a a'
+ , Just Refl <- testEquality b b'
+ = Just Refl
+ testEquality _ _ = Nothing
+
+class KnownType t where τ :: Typ t
+instance KnownType Int where τ = TInt
+instance KnownType Bool where τ = TBool
+instance (KnownType a, KnownType b) => KnownType (a, b) where τ = TPair τ τ
+instance (KnownType a, KnownType b) => KnownType (a -> b) where τ = TFun τ τ
+
+data PrimOp a b where
+ POAddI :: PrimOp (Int, Int) Int
+ POMulI :: PrimOp (Int, Int) Int
+ POEqI :: PrimOp (Int, Int) Bool
+deriving instance Show (PrimOp a b)
+
+opType2 :: PrimOp a b -> Typ b
+opType2 = \case
+ POAddI -> TInt
+ POMulI -> TInt
+ POEqI -> TBool
+
+data Fixity = Infix | Prefix
+ deriving (Show)
+
+primOpPrec :: PrimOp a b -> (Int, (Int, Int))
+primOpPrec POAddI = (6, (6, 7))
+primOpPrec POMulI = (7, (7, 8))
+primOpPrec POEqI = (4, (5, 5))
+
+prettyPrimOp :: Fixity -> PrimOp a b -> ShowS
+prettyPrimOp fix op =
+ let s = case op of
+ POAddI -> "+"
+ POMulI -> "*"
+ POEqI -> "=="
+ in showString $ case fix of
+ Infix -> s
+ Prefix -> "(" ++ s ++ ")"
+
+data ArithF r t where
+ A_Prim :: PrimOp a b -> r a -> ArithF r b
+ A_Pair :: r a -> r b -> ArithF r (a, b)
+ A_If :: r Bool -> r a -> r a -> ArithF r a
+deriving instance (forall a. Show (r a)) => Show (ArithF r t)
+
+instance Functor1 ArithF
+instance Traversable1 ArithF where
+ traverse1 f (A_Prim op x) = A_Prim op <$> f x
+ traverse1 f (A_Pair x y) = A_Pair <$> f x <*> f y
+ traverse1 f (A_If x y z) = A_If <$> f x <*> f y <*> f z
+
+prettyArithF :: Monad m
+ => (forall a. Int -> BExpr Typ env ArithF a -> m ShowS)
+ -> Int -> ArithF (BExpr Typ env ArithF) t -> m ShowS
+prettyArithF pr d = \case
+ A_Prim op (BOp _ (A_Pair a b)) -> do
+ let (dop, (dopL, dopR)) = primOpPrec op
+ a' <- pr dopL a
+ b' <- pr dopR b
+ return $ showParen (d > dop) $ a' . showString " " . prettyPrimOp Infix op . showString " " . b'
+ A_Prim op (BLet ty rhs e) ->
+ pr d (BLet ty rhs (BOp (opType2 op) (A_Prim op e)))
+ A_Prim op arg -> do
+ arg' <- pr 11 arg
+ return $ showParen (d > 10) $ prettyPrimOp Prefix op . showString " " . arg'
+ A_Pair a b -> do
+ a' <- pr 0 a
+ b' <- pr 0 b
+ return $ showString "(" . a' . showString ", " . b' . showString ")"
+ A_If a b c -> do
+ a' <- pr 0 a
+ b' <- pr 0 b
+ c' <- pr 0 c
+ return $ showParen (d > 0) $ showString "if " . a' . showString " then " . b' . showString " else " . c'
diff --git a/test/Main.hs b/test/Main.hs
index e7b303b..1a8d8e1 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -5,110 +5,46 @@
{-# LANGUAGE StandaloneDeriving #-}
module Main where
-import Data.Type.Equality
-
import Data.Expr.SharingRecovery
+import Arith
-data Typ t where
- TInt :: Typ Int
- TBool :: Typ Bool
- TPair :: Typ a -> Typ b -> Typ (a, b)
- TFun :: Typ a -> Typ b -> Typ (a -> b)
-deriving instance Show (Typ t)
-
-instance TestEquality Typ where
- testEquality TInt TInt = Just Refl
- testEquality TBool TBool = Just Refl
- testEquality (TPair a b) (TPair a' b')
- | Just Refl <- testEquality a a'
- , Just Refl <- testEquality b b'
- = Just Refl
- testEquality (TFun a b) (TFun a' b')
- | Just Refl <- testEquality a a'
- , Just Refl <- testEquality b b'
- = Just Refl
- testEquality _ _ = Nothing
-
-class KnownType t where τ :: Typ t
-instance KnownType Int where τ = TInt
-instance KnownType Bool where τ = TBool
-instance (KnownType a, KnownType b) => KnownType (a, b) where τ = TPair τ τ
-instance (KnownType a, KnownType b) => KnownType (a -> b) where τ = TFun τ τ
-data PrimOp a b where
- POAddI :: PrimOp (Int, Int) Int
- POMulI :: PrimOp (Int, Int) Int
- POEqI :: PrimOp (Int, Int) Bool
-deriving instance Show (PrimOp a b)
+-- TODO: test cyclic expressions
-data Fixity = Infix | Prefix
- deriving (Show)
-primOpPrec :: PrimOp a b -> (Int, (Int, Int))
-primOpPrec POAddI = (6, (6, 7))
-primOpPrec POMulI = (7, (7, 8))
-primOpPrec POEqI = (4, (5, 5))
+a_bin :: (KnownType a, KnownType b, KnownType c)
+ => PrimOp (a, b) c
+ -> PHOASExpr Typ v ArithF a
+ -> PHOASExpr Typ v ArithF b
+ -> PHOASExpr Typ v ArithF c
+a_bin op a b = PHOASOp τ (A_Prim op (PHOASOp τ (A_Pair a b)))
-prettyPrimOp :: Fixity -> PrimOp a b -> ShowS
-prettyPrimOp fix op =
- let s = case op of
- POAddI -> "+"
- POMulI -> "*"
- POEqI -> "=="
- in showString $ case fix of
- Infix -> s
- Prefix -> "(" ++ s ++ ")"
+lam :: (KnownType a, KnownType b)
+ => (PHOASExpr Typ v f a -> PHOASExpr Typ v f b) -> PHOASExpr Typ v f (a -> b)
+lam f = PHOASLam τ τ $ \arg -> f (PHOASVar τ arg)
-data ArithF r t where
- A_Prim :: PrimOp a b -> r a -> ArithF r b
- A_Pair :: r a -> r b -> ArithF r (a, b)
- A_If :: r Bool -> r a -> r a -> ArithF r a
-deriving instance (forall a. Show (r a)) => Show (ArithF r t)
+(+!) :: PHOASExpr Typ v ArithF Int -> PHOASExpr Typ v ArithF Int -> PHOASExpr Typ v ArithF Int
+(+!) = a_bin POAddI
-instance Functor1 ArithF
-instance Traversable1 ArithF where
- traverse1 f (A_Prim op x) = A_Prim op <$> f x
- traverse1 f (A_Pair x y) = A_Pair <$> f x <*> f y
- traverse1 f (A_If x y z) = A_If <$> f x <*> f y <*> f z
-
-prettyArithF :: Monad m
- => (forall a. Int -> BExpr typ env ArithF a -> m ShowS)
- -> Int -> ArithF (BExpr typ env ArithF) t -> m ShowS
-prettyArithF pr d = \case
- A_Prim op (BOp _ (A_Pair a b)) -> do
- let (dop, (dopL, dopR)) = primOpPrec op
- a' <- pr dopL a
- b' <- pr dopR b
- return $ showParen (d > dop) $ a' . showString " " . prettyPrimOp Infix op . showString " " . b'
- A_Prim op arg -> do
- arg' <- pr 11 arg
- return $ showParen (d > 10) $ prettyPrimOp Prefix op . showString " " . arg'
- A_Pair a b -> do
- a' <- pr 0 a
- b' <- pr 0 b
- return $ showString "(" . a' . showString ", " . b' . showString ")"
- A_If a b c -> do
- a' <- pr 0 a
- b' <- pr 0 b
- c' <- pr 0 c
- return $ showParen (d > 0) $ showString "if " . a' . showString " then " . b' . showString " else " . c'
+(*!) :: PHOASExpr Typ v ArithF Int -> PHOASExpr Typ v ArithF Int -> PHOASExpr Typ v ArithF Int
+(*!) = a_bin POMulI
-- λx. x + x
ea_1 :: PHOASExpr Typ v ArithF (Int -> Int)
-ea_1 =
- PHOASLam τ τ $ \arg ->
- PHOASOp τ (A_Prim POAddI
- (PHOASOp τ (A_Pair (PHOASVar τ arg) (PHOASVar τ arg))))
+ea_1 = lam $ \arg -> arg +! arg
-- λx. let y = x + x in y * y
ea_2 :: PHOASExpr Typ v ArithF (Int -> Int)
-ea_2 =
- PHOASLam τ τ $ \arg ->
- let y = PHOASOp τ (A_Prim POAddI
- (PHOASOp τ (A_Pair (PHOASVar τ arg) (PHOASVar τ arg))))
- in PHOASOp τ (A_Prim POMulI
- (PHOASOp τ (A_Pair y y)))
+ea_2 = lam $ \arg -> let y = arg +! arg
+ in y *! y
+
+ea_3 :: PHOASExpr Typ v ArithF (Int -> Int)
+ea_3 = lam $ \arg ->
+ let y = arg +! arg
+ x = y *! arg
+ -- in (y +! x) +! (x +! y)
+ in (x +! y) +! (y +! x)
main :: IO ()
-main = putStrLn $ prettyBExpr prettyArithF (sharingRecovery ea_2)
+main = putStrLn $ prettyBExpr prettyArithF (sharingRecovery ea_3)