aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sharing-recovery.cabal13
-rw-r--r--src/Data/Expr/SharingRecovery.hs3
-rw-r--r--src/Data/Expr/SharingRecovery/Internal.hs217
-rw-r--r--test-th/Arith.hs (renamed from test/Arith.hs)0
-rw-r--r--test-th/Arith/NonBase.hs (renamed from test/Arith/NonBase.hs)0
-rw-r--r--test-th/Main.hs51
-rw-r--r--test-th/NonBaseTH.hs (renamed from test/NonBaseTH.hs)0
-rw-r--r--test/Main.hs113
8 files changed, 267 insertions, 130 deletions
diff --git a/sharing-recovery.cabal b/sharing-recovery.cabal
index 9e34bab..162328d 100644
--- a/sharing-recovery.cabal
+++ b/sharing-recovery.cabal
@@ -25,10 +25,21 @@ test-suite test
type: exitcode-stdio-1.0
main-is: Main.hs
other-modules:
+ hs-source-dirs: test
+ build-depends:
+ sharing-recovery,
+ base
+ default-language: Haskell2010
+ ghc-options: -Wall
+
+test-suite test-th
+ type: exitcode-stdio-1.0
+ main-is: Main.hs
+ other-modules:
Arith
Arith.NonBase
NonBaseTH
- hs-source-dirs: test
+ hs-source-dirs: test-th
build-depends:
sharing-recovery,
base,
diff --git a/src/Data/Expr/SharingRecovery.hs b/src/Data/Expr/SharingRecovery.hs
index 02b3e3e..95524dd 100644
--- a/src/Data/Expr/SharingRecovery.hs
+++ b/src/Data/Expr/SharingRecovery.hs
@@ -1,10 +1,13 @@
module Data.Expr.SharingRecovery (
-- * Sharing recovery
sharingRecovery,
+ sharingRecoveryUnsafe,
-- * Expressions
PHOASExpr(..),
+ typeOfPHOAS,
BExpr(..),
+ typeOfBExpr,
Idx(..),
-- * Traversing indexed structures
diff --git a/src/Data/Expr/SharingRecovery/Internal.hs b/src/Data/Expr/SharingRecovery/Internal.hs
index 0089454..9d00355 100644
--- a/src/Data/Expr/SharingRecovery/Internal.hs
+++ b/src/Data/Expr/SharingRecovery/Internal.hs
@@ -78,7 +78,10 @@ class Functor1 f => Traversable1 f where
--
-- * @v@ is the type of variables in the expression. A PHOAS expression is
-- required to be parametric in the @v@ parameter; the only place you will
--- obtain a @v@ is inside a @PHOASLam@ function body.
+-- obtain a @v@ is inside a @PHOASLam@ function body. Even if you use
+-- 'sharingRecoveryUnsafe' which allows you to pass in a monomorphic
+-- expression with 'Tag', it is expressly disallowed to inspect or
+-- manipulate 'Tag' values.
--
-- * @f@ should be your type of operations for your language. It is indexed by
-- the type of subexpressions and the result type of the operation; thus, it
@@ -100,7 +103,7 @@ class Functor1 f => Traversable1 f where
--
-- Note furthermore that @Oper@ is /not/ a recursive type. Subexpressions
-- are again 'PHOASExpr's, and 'sharingRecovery' needs to be able to see
--- them. Hence, you should call back to back to @r@ instead of recursing
+-- them. Hence, you should call back to @r@ instead of recursing
-- manually.
--
-- * @t@ is the result type of this expression.
@@ -109,6 +112,12 @@ data PHOASExpr typ v f t where
PHOASLam :: typ (a -> b) -> typ a -> (v a -> PHOASExpr typ v f b) -> PHOASExpr typ v f (a -> b)
PHOASVar :: typ t -> v t -> PHOASExpr typ v f t
+-- | Tag values identify a variable, and are created in the sharing recovery
+-- process (in 'pruneExpr'). Applying the sharing recovery algorithm to a term
+-- with manually constructed Tag values, or inspecting Tag values in the
+-- expression and computing with them or branching based on them, results in
+-- "impossible terms" and behaviour of the sharing recovery algorithm is
+-- undefined. Consider Tag an opaque value.
newtype Tag t = Tag Natural
deriving (Show, Eq)
deriving (Hashable) via Natural
@@ -125,16 +134,26 @@ instance TestEquality (NameFor typ f) where
unsafeCoerceRefl :: a :~: b -- restricted version of unsafeCoerce that only allows punting proofs
unsafeCoerceRefl = unsafeCoerce Refl
+typeOfPHOAS :: PHOASExpr typ v f t -> typ t
+typeOfPHOAS (PHOASOp ty _) = ty
+typeOfPHOAS (PHOASLam ty _ _) = ty
+typeOfPHOAS (PHOASVar ty _) = ty
+
-- | Pruned expression.
--
-- Note that variables do not, and will never, have a name: we don't bother
-- detecting sharing for variable references, because that would only introduce
-- a redundant variable indirection.
-data PExpr typ f t where
- PStub :: NameFor typ f t -> typ t -> PExpr typ f t
- POp :: NameFor typ f t -> typ t -> f (PExpr typ f) t -> PExpr typ f t
- PLam :: NameFor typ f (a -> b) -> typ (a -> b) -> typ a -> Tag a -> PExpr typ f b -> PExpr typ f (a -> b)
- PVar :: typ a -> Tag a -> PExpr typ f a
+--
+-- This is defined as a base functor; @r@ is the recursive position.
+data PExpr r typ f t where
+ PStub :: NameFor typ f t -> typ t -> PExpr r typ f t
+ POp :: NameFor typ f t -> typ t -> f (r typ f) t -> PExpr r typ f t
+ PLam :: NameFor typ f (a -> b) -> typ (a -> b) -> typ a -> Tag a -> r typ f b -> PExpr r typ f (a -> b)
+ PVar :: typ a -> Tag a -> PExpr r typ f a
+
+-- | Fixpoint of 'PExpr'
+newtype PExpr0 typ f t = PExpr0 (PExpr PExpr0 typ f t)
data SomeNameFor typ f = forall t. SomeNameFor {-# UNPACK #-} !(NameFor typ f t)
@@ -144,19 +163,22 @@ instance Eq (SomeNameFor typ f) where
instance Hashable (SomeNameFor typ f) where
hashWithSalt salt (SomeNameFor name) = hashWithSalt salt name
-prettyPExpr :: Traversable1 f => Int -> PExpr typ f t -> ShowS
-prettyPExpr d = \case
+prettyPExpr0 :: Traversable1 f => Int -> PExpr0 typ f t -> ShowS
+prettyPExpr0 d (PExpr0 ex) = prettyPExpr prettyPExpr0 d ex
+
+prettyPExpr :: Traversable1 f => (forall a. Int -> r typ f a -> ShowS) -> Int -> PExpr r typ f t -> ShowS
+prettyPExpr recur d = \case
PStub (NameFor name) _ -> showString (showStableName name)
POp (NameFor name) _ args ->
let (argslist, _) = traverse1 (\arg -> ([Some arg], Const ())) args
- argslist' = map (\(Some arg) -> prettyPExpr 0 arg) argslist
+ argslist' = map (\(Some arg) -> recur 0 arg) argslist
in showParen (d > 10) $
showString ("<" ++ showStableName name ++ ">(")
. foldr (.) id (intersperse (showString ", ") argslist')
. showString ")"
PLam (NameFor name) _ _ (Tag tag) body ->
showParen (d > 0) $
- showString ("λ" ++ showStableName name ++ " x" ++ show tag ++ ". ") . prettyPExpr 0 body
+ showString ("λ" ++ showStableName name ++ " x" ++ show tag ++ ". ") . recur 0 body
PVar _ (Tag tag) -> showString ("x" ++ show tag)
-- | For each name:
@@ -170,13 +192,19 @@ prettyPExpr d = \case
-- Missing names have not been seen yet, and have unknown height.
type OccMap typ f = HashMap (SomeNameFor typ f) (Natural, Natural)
-pruneExpr :: Traversable1 f => (forall v. PHOASExpr typ v f t) -> (OccMap typ f, PExpr typ f t)
+pruneExpr :: Traversable1 f => (forall v. PHOASExpr typ v f t) -> (OccMap typ f, PExpr0 typ f t)
pruneExpr term =
let ((term', _), (_, mp)) = runState (pruneExpr' term) (0, mempty)
in (mp, term')
+pruneExprUnsafe :: Traversable1 f => PHOASExpr typ Tag f t -> (OccMap typ f, PExpr0 typ f t)
+pruneExprUnsafe term =
+ let ((term', _), (_, mp)) = runState (pruneExpr' term) (0, mempty)
+ in (mp, term')
+
-- | Returns pruned expression with its height.
-pruneExpr' :: Traversable1 f => PHOASExpr typ Tag f t -> State (Natural, OccMap typ f) (PExpr typ f t, Natural)
+-- State: (ID generator, occurrence map being accumulated)
+pruneExpr' :: Traversable1 f => PHOASExpr typ Tag f t -> State (Natural, OccMap typ f) (PExpr0 typ f t, Natural)
pruneExpr' = \case
orig@(PHOASOp ty args) -> do
let name = makeStableName' orig
@@ -185,7 +213,7 @@ pruneExpr' = \case
-- already visited
Just height -> do
modify (second (HM.adjust (first (+1)) (SomeNameFor (NameFor name))))
- pure (PStub (NameFor name) ty, height)
+ pure (PExpr0 (PStub (NameFor name) ty), height)
-- first visit
Nothing -> do
-- Traverse the arguments, collecting the maximum height in an
@@ -193,13 +221,14 @@ pruneExpr' = \case
(args', maxhei) <-
withMoreState 0 $
traverse1 (\arg -> do
+ -- drop the extra state for the recursive call
(arg', hei) <- withLessState id (,) (pruneExpr' arg)
- modify (second (hei `max`))
+ modify (second (hei `max`)) -- modify the extra state
return arg')
args
-- Record this node
modify (second (HM.insert (SomeNameFor (NameFor name)) (1, 1 + maxhei)))
- pure (POp (NameFor name) ty args', 1 + maxhei)
+ pure (PExpr0 (POp (NameFor name) ty args'), 1 + maxhei)
orig@(PHOASLam tyf tyarg f) -> do
let name = makeStableName' orig
@@ -208,7 +237,7 @@ pruneExpr' = \case
-- already visited
Just height -> do
modify (second (HM.adjust (first (+1)) (SomeNameFor (NameFor name))))
- pure (PStub (NameFor name) tyf, height)
+ pure (PExpr0 (PStub (NameFor name) tyf), height)
-- first visit
Nothing -> do
tag <- Tag <$> gets fst
@@ -216,45 +245,24 @@ pruneExpr' = \case
let body = f tag
(body', bodyhei) <- pruneExpr' body
modify (second (HM.insert (SomeNameFor (NameFor name)) (1, 1 + bodyhei)))
- pure (PLam (NameFor name) tyf tyarg tag body', 1 + bodyhei)
+ pure (PExpr0 (PLam (NameFor name) tyf tyarg tag body'), 1 + bodyhei)
- PHOASVar ty tag -> pure (PVar ty tag, 1)
+ PHOASVar ty tag -> pure (PExpr0 (PVar ty tag), 1)
--- | Floated expression: a bunch of to-be let bound expressions on top of an
--- LExpr'. Because LExpr' is really just PExpr with the recursive positions
--- replaced by LExpr, LExpr should be seen as PExpr with a bunch of to-be let
--- bound expressions on top of every node.
-data LExpr typ f t = LExpr [Some (LExpr typ f)] (LExpr' typ f t)
-data LExpr' typ f t where -- TODO: this could be an instantiation of (a generalisation of) PExpr
- LStub :: NameFor typ f t -> typ t -> LExpr' typ f t
- LOp :: NameFor typ f t -> typ t -> f (LExpr typ f) t -> LExpr' typ f t
- LLam :: NameFor typ f (a -> b) -> typ (a -> b) -> typ a -> Tag a -> LExpr typ f b -> LExpr' typ f (a -> b)
- LVar :: typ a -> Tag a -> LExpr' typ f a
+-- | Floated expression: again a 'PExpr' (it's a fixpoint over the same base
+-- functor), but now with a bunch of to-be let bound expressions on top of
+-- every node.
+data LExpr typ f t = LExpr [Some (LExpr typ f)] (PExpr LExpr typ f t)
prettyLExpr :: Traversable1 f => Int -> LExpr typ f t -> ShowS
-prettyLExpr d (LExpr [] e) = prettyLExpr' d e
+prettyLExpr d (LExpr [] e) = prettyPExpr prettyLExpr d e
prettyLExpr d (LExpr floated e) =
showString "["
. foldr (.) id (intersperse (showString ", ") (map (\(Some e') -> prettyLExpr 0 e') floated))
- . showString "] " . prettyLExpr' d e
-
-prettyLExpr' :: Traversable1 f => Int -> LExpr' typ f t -> ShowS
-prettyLExpr' d = \case
- LStub (NameFor name) _ -> showString (showStableName name)
- LOp (NameFor name) _ args ->
- let (argslist, _) = traverse1 (\arg -> ([Some arg], Const ())) args
- argslist' = map (\(Some arg) -> prettyLExpr 0 arg) argslist
- in showParen (d > 10) $
- showString ("<" ++ showStableName name ++ ">(")
- . foldr (.) id (intersperse (showString ", ") argslist')
- . showString ")"
- LLam (NameFor name) _ _ (Tag tag) body ->
- showParen (d > 0) $
- showString ("λ" ++ showStableName name ++ " x" ++ show tag ++ ". ") . prettyLExpr 0 body
- LVar _ (Tag tag) -> showString ("x" ++ show tag)
+ . showString "] " . prettyPExpr prettyLExpr d e
-floatExpr :: Traversable1 f => OccMap typ f -> PExpr typ f t -> LExpr typ f t
+floatExpr :: Traversable1 f => OccMap typ f -> PExpr0 typ f t -> LExpr typ f t
floatExpr totals term = snd (floatExpr' totals term)
newtype FoundMap typ f = FoundMap
@@ -269,46 +277,47 @@ instance Semigroup (FoundMap typ f) where
instance Monoid (FoundMap typ f) where
mempty = FoundMap HM.empty
-floatExpr' :: Traversable1 f => OccMap typ f -> PExpr typ f t -> (FoundMap typ f, LExpr typ f t)
-floatExpr' _totals (PStub name ty) =
- -- trace ("Found stub: " ++ (case name of NameFor n -> showStableName n)) $
- (FoundMap $ HM.singleton (SomeNameFor name) (1, Nothing)
- ,LExpr [] (LStub name ty))
+floatExpr' :: Traversable1 f => OccMap typ f -> PExpr0 typ f t -> (FoundMap typ f, LExpr typ f t)
+floatExpr' totals (PExpr0 term) = case term of
+ PStub name ty ->
+ -- trace ("Found stub: " ++ (case name of NameFor n -> showStableName n)) $
+ (FoundMap $ HM.singleton (SomeNameFor name) (1, Nothing)
+ ,LExpr [] (PStub name ty))
-floatExpr' _totals (PVar ty tag) =
- -- trace ("Found var: " ++ show tag) $
- (mempty, LExpr [] (LVar ty tag))
+ PVar ty tag ->
+ -- trace ("Found var: " ++ show tag) $
+ (mempty, LExpr [] (PVar ty tag))
-floatExpr' totals term =
- let (FoundMap foundmap, name, termty, term') = case term of
- POp n ty args ->
- let (fm, args') = traverse1 (floatExpr' totals) args
- in (fm, n, ty, LOp n ty args')
- PLam n tyf tyarg tag body ->
- let (fm, body') = floatExpr' totals body
- in (fm, n, tyf, LLam n tyf tyarg tag body')
+ _ ->
+ let (FoundMap foundmap, name, termty, term') = case term of
+ POp n ty args ->
+ let (fm, args') = traverse1 (floatExpr' totals) args
+ in (fm, n, ty, POp n ty args')
+ PLam n tyf tyarg tag body ->
+ let (fm, body') = floatExpr' totals body
+ in (fm, n, tyf, PLam n tyf tyarg tag body')
- -- TODO: perhaps this HM.toList together with the foldr HM.delete can be a single traversal of the HashMap
- saturated = [case mterm of
- Just t -> (nm, t)
- Nothing -> case nm of
- SomeNameFor (NameFor n) ->
- error $ "Name saturated (count=" ++ show count ++ ", totalcount=" ++ show totalcount ++ ") but no term found: " ++ showStableName n
- | (nm, (count, mterm)) <- HM.toList foundmap
- , let totalcount = fromMaybe 0 (fst <$> HM.lookup nm totals)
- , count == totalcount]
+ -- TODO: perhaps this HM.toList together with the foldr HM.delete can be a single traversal of the HashMap
+ saturated = [case mterm of
+ Just t -> (nm, t)
+ Nothing -> case nm of
+ SomeNameFor (NameFor n) ->
+ error $ "Name saturated (count=" ++ show count ++ ", totalcount=" ++ show totalcount ++ ") but no term found: " ++ showStableName n
+ | (nm, (count, mterm)) <- HM.toList foundmap
+ , let totalcount = fromMaybe 0 (fst <$> HM.lookup nm totals)
+ , count == totalcount]
- foundmap' = foldr HM.delete foundmap (map fst saturated)
+ foundmap' = foldr HM.delete foundmap (map fst saturated)
- lterm = LExpr (map fst (sortBy (comparing snd) (map snd saturated))) term'
+ lterm = LExpr (map fst (sortBy (comparing snd) (map snd saturated))) term'
- in case HM.findWithDefault (0, undefined) (SomeNameFor name) totals of
- (1, _) -> (FoundMap foundmap', lterm)
- (tot, height)
- | tot > 1 -> -- trace ("Inserting " ++ (case name of NameFor n -> showStableName n) ++ " into foundmap") $
- (FoundMap (HM.insert (SomeNameFor name) (1, Just (Some lterm, height)) foundmap')
- ,LExpr [] (LStub name termty))
- | otherwise -> error "Term does not exist, yet we have it in hand"
+ in case HM.findWithDefault (0, undefined) (SomeNameFor name) totals of
+ (1, _) -> (FoundMap foundmap', lterm)
+ (tot, height)
+ | tot > 1 -> -- trace ("Inserting " ++ (case name of NameFor n -> showStableName n) ++ " into foundmap") $
+ (FoundMap (HM.insert (SomeNameFor name) (1, Just (Some lterm, height)) foundmap')
+ ,LExpr [] (PStub name termty))
+ | otherwise -> error "Term does not exist, yet we have it in hand"
-- | Untyped De Bruijn expression. No more names: there are lets now, and
@@ -342,15 +351,15 @@ lowerExpr' namelvl taglvl curlvl (LExpr floated ex) =
curlvl' = curlvl + length floated
in prefix $
case ex of
- LStub name ty ->
+ PStub name ty ->
case HM.lookup (SomeNameFor name) namelvl' of
Just lvl -> UBVar ty (curlvl - lvl - 1)
Nothing -> error "Name variable out of scope"
- LOp _ ty args ->
+ POp _ ty args ->
UBOp ty (fmap1 (lowerExpr' namelvl' taglvl curlvl') args)
- LLam _ tyf tyarg tag body ->
+ PLam _ tyf tyarg tag body ->
UBLam tyf tyarg (lowerExpr' namelvl' (HM.insert (SomeTag tag) curlvl' taglvl) (curlvl' + 1) body)
- LVar ty tag ->
+ PVar ty tag ->
case HM.lookup (SomeTag tag) taglvl of
Just lvl -> UBVar ty (curlvl - lvl - 1)
Nothing -> error "Tag variable out of scope"
@@ -363,17 +372,17 @@ lowerExpr' namelvl taglvl curlvl (LExpr floated ex) =
buildPrefix namelvl' _ [] = (namelvl', id)
buildPrefix namelvl' lvl (Some rhs@(LExpr _ rhs') : rhss) =
let name = case rhs' of
- LStub n _ -> n
- LOp n _ _ -> n
- LLam n _ _ _ _ -> n
- LVar _ _ -> error "Recovering sharing of a tag is useless"
+ PStub n _ -> n
+ POp n _ _ -> n
+ PLam n _ _ _ _ -> n
+ PVar _ _ -> error "Recovering sharing of a tag is useless"
ty = case rhs' of
- LStub{} -> error "Recovering sharing of a stub is useless"
- LOp _ t _ -> t
- LLam _ t _ _ _ -> t
- LVar t _ -> t
+ PStub{} -> error "Recovering sharing of a stub is useless"
+ POp _ t _ -> t
+ PLam _ t _ _ _ -> t
+ PVar t _ -> t
prefix = UBLet ty (lowerExpr' namelvl' taglvl lvl rhs)
- in (prefix .) <$> buildPrefix (HM.insert (SomeNameFor name) lvl namelvl') (lvl + 1) rhss
+ in second (prefix .) $ buildPrefix (HM.insert (SomeNameFor name) lvl namelvl') (lvl + 1) rhss
-- | A typed De Bruijn index.
@@ -399,7 +408,7 @@ envLookupU = go id
go f 0 (EPush _ t) = Just (Some (Pair t (f IZ)))
go f i (EPush e _) = go (f . IS) (i - 1) e
--- | Typed De Bruijn expression. This is the resu,t of sharing recovery. It is
+-- | Typed De Bruijn expression. This is the result of sharing recovery. It is
-- not higher-order any more, and furthermore has explicit let-bindings ('BLet')
-- that denote the sharing inside the term. This is a normal AST.
--
@@ -418,6 +427,12 @@ data BExpr typ env f t where
deriving instance (forall a. Show (typ a), forall a r. (forall b. Show (r b)) => Show (f r a))
=> Show (BExpr typ env f t)
+typeOfBExpr :: BExpr typ v f t -> typ t
+typeOfBExpr (BOp ty _) = ty
+typeOfBExpr (BLam ty _ _) = ty
+typeOfBExpr (BLet _ _ e) = typeOfBExpr e
+typeOfBExpr (BVar ty _) = ty
+
prettyBExpr :: (forall m env' a. Monad m => (forall b. Int -> BExpr typ env' f b -> m ShowS)
-> Int -> f (BExpr typ env' f) a -> m ShowS)
-> BExpr typ '[] f t -> String
@@ -466,8 +481,16 @@ retypeExpr' env (UBVar ty idx) =
-- abstract syntax form to a well-typed well-scoped De Bruijn expression with
-- explicit let-bindings.
sharingRecovery :: (Traversable1 f, TestEquality typ) => (forall v. PHOASExpr typ v f t) -> BExpr typ '[] f t
-sharingRecovery e =
- let (occmap, pexpr) = pruneExpr e
+sharingRecovery = sharingRecoveryUnsafe
+
+-- | The 'sharingRecovery' function instantiates the @v@ parameter to 'Tag',
+-- and this function is provided as a convenience in case constructing a
+-- polymorphic expression value is difficult. However, it is /disallowed/ to
+-- inspect or manipulate 'Tag' values obtained by lambda abstraction inside the
+-- expression. Violating this rule results in undefined behaviour.
+sharingRecoveryUnsafe :: (Traversable1 f, TestEquality typ) => PHOASExpr typ Tag f t -> BExpr typ '[] f t
+sharingRecoveryUnsafe e =
+ let (occmap, pexpr) = pruneExprUnsafe e
lexpr = floatExpr occmap pexpr
ubexpr = lowerExpr lexpr
in -- trace ("PExpr: " ++ prettyPExpr 0 pexpr "") $
diff --git a/test/Arith.hs b/test-th/Arith.hs
index c34baa8..c34baa8 100644
--- a/test/Arith.hs
+++ b/test-th/Arith.hs
diff --git a/test/Arith/NonBase.hs b/test-th/Arith/NonBase.hs
index f5d458e..f5d458e 100644
--- a/test/Arith/NonBase.hs
+++ b/test-th/Arith/NonBase.hs
diff --git a/test-th/Main.hs b/test-th/Main.hs
new file mode 100644
index 0000000..5a4d335
--- /dev/null
+++ b/test-th/Main.hs
@@ -0,0 +1,51 @@
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE QuantifiedConstraints #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE StandaloneDeriving #-}
+module Main where
+
+import Data.Expr.SharingRecovery
+import Data.Expr.SharingRecovery.Internal
+
+import Arith
+
+
+-- TODO: test cyclic expressions
+
+
+a_bin :: (KnownType a, KnownType b, KnownType c)
+ => PrimOp (a, b) c
+ -> PHOASExpr Typ v ArithF a
+ -> PHOASExpr Typ v ArithF b
+ -> PHOASExpr Typ v ArithF c
+a_bin op a b = PHOASOp τ (A_Prim op (PHOASOp τ (A_Pair a b)))
+
+lam :: (KnownType a, KnownType b)
+ => (PHOASExpr Typ v f a -> PHOASExpr Typ v f b) -> PHOASExpr Typ v f (a -> b)
+lam f = PHOASLam τ τ $ \arg -> f (PHOASVar τ arg)
+
+(+!) :: PHOASExpr Typ v ArithF Int -> PHOASExpr Typ v ArithF Int -> PHOASExpr Typ v ArithF Int
+(+!) = a_bin POAddI
+
+(*!) :: PHOASExpr Typ v ArithF Int -> PHOASExpr Typ v ArithF Int -> PHOASExpr Typ v ArithF Int
+(*!) = a_bin POMulI
+
+-- λx. x + x
+ea_1 :: PHOASExpr Typ v ArithF (Int -> Int)
+ea_1 = lam $ \arg -> arg +! arg
+
+-- λx. let y = x + x in y * y
+ea_2 :: PHOASExpr Typ v ArithF (Int -> Int)
+ea_2 = lam $ \arg -> let y = arg +! arg
+ in y *! y
+
+ea_3 :: PHOASExpr Typ v ArithF (Int -> Int)
+ea_3 = lam $ \arg ->
+ let y = arg +! arg
+ x = y *! arg
+ -- in (y +! x) +! (x +! y)
+ in (x +! y) +! (y +! x)
+
+main :: IO ()
+main = putStrLn $ prettyBExpr prettyArithF (sharingRecovery ea_3)
diff --git a/test/NonBaseTH.hs b/test-th/NonBaseTH.hs
index 4741ea0..4741ea0 100644
--- a/test/NonBaseTH.hs
+++ b/test-th/NonBaseTH.hs
diff --git a/test/Main.hs b/test/Main.hs
index 5a4d335..208d3d6 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -1,51 +1,100 @@
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
-{-# LANGUAGE QuantifiedConstraints #-}
-{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE QuantifiedConstraints #-}
module Main where
import Data.Expr.SharingRecovery
-import Data.Expr.SharingRecovery.Internal
+import Data.Type.Equality
+
+
+data Ty a where
+ TInt :: Ty Int
+ TFloat :: Ty Float
+ TBool :: Ty Bool
+deriving instance Show (Ty a)
-import Arith
+instance TestEquality Ty where
+ testEquality TInt TInt = Just Refl
+ testEquality TInt _ = Nothing
+ testEquality TFloat TFloat = Just Refl
+ testEquality TFloat _ = Nothing
+ testEquality TBool TBool = Just Refl
+ testEquality TBool _ = Nothing
+type family IsOrdTy a where
+ IsOrdTy Int = True
+ IsOrdTy Float = True
+ IsOrdTy _ = False
--- TODO: test cyclic expressions
+data Unop a b where
+ UONeg :: Ty a -> Unop a a
+ UONot :: Unop Bool Bool
+deriving instance Show (Unop a b)
+data Binop a b c where
+ BOAdd :: Ty a -> Binop a a a
+ BOSub :: Ty a -> Binop a a a
+ BOMul :: Ty a -> Binop a a a
+ BOAnd :: Binop Bool Bool Bool
+ BOOr :: Binop Bool Bool Bool
+ BOLt :: IsOrdTy a ~ True => Ty a -> Binop a a Bool
+ BOLeq :: IsOrdTy a ~ True => Ty a -> Binop a a Bool
+ BOEq :: IsOrdTy a ~ True => Ty a -> Binop a a Bool
+ BONeq :: IsOrdTy a ~ True => Ty a -> Binop a a Bool
+deriving instance Show (Binop a b c)
-a_bin :: (KnownType a, KnownType b, KnownType c)
- => PrimOp (a, b) c
- -> PHOASExpr Typ v ArithF a
- -> PHOASExpr Typ v ArithF b
- -> PHOASExpr Typ v ArithF c
-a_bin op a b = PHOASOp τ (A_Prim op (PHOASOp τ (A_Pair a b)))
+data Lang r a where
+ Un :: Unop a b -> r a -> Lang r b
+ Bin :: Binop a b c -> r a -> r b -> Lang r c
+ Cond :: r Bool -> r a -> r a -> Lang r a
+ Cnst :: Show a => a -> Lang r a -- there's a type in the BExpr in the end, no need for one here
+deriving instance (forall b. Show (r b)) => Show (Lang r a)
-lam :: (KnownType a, KnownType b)
- => (PHOASExpr Typ v f a -> PHOASExpr Typ v f b) -> PHOASExpr Typ v f (a -> b)
-lam f = PHOASLam τ τ $ \arg -> f (PHOASVar τ arg)
+instance Functor1 Lang
+instance Traversable1 Lang where
+ traverse1 f = \case
+ Un op x -> Un op <$> f x
+ Bin op x y -> Bin op <$> f x <*> f y
+ Cond x y z -> Cond <$> f x <*> f y <*> f z
+ Cnst v -> pure (Cnst v)
-(+!) :: PHOASExpr Typ v ArithF Int -> PHOASExpr Typ v ArithF Int -> PHOASExpr Typ v ArithF Int
-(+!) = a_bin POAddI
+class KnownTy a where knownTy :: Ty a
+instance KnownTy Int where knownTy = TInt
+instance KnownTy Float where knownTy = TFloat
+instance KnownTy Bool where knownTy = TBool
-(*!) :: PHOASExpr Typ v ArithF Int -> PHOASExpr Typ v ArithF Int -> PHOASExpr Typ v ArithF Int
-(*!) = a_bin POMulI
+type Expr v = PHOASExpr Ty v Lang
--- λx. x + x
-ea_1 :: PHOASExpr Typ v ArithF (Int -> Int)
-ea_1 = lam $ \arg -> arg +! arg
+cond :: KnownTy a => Expr v Bool -> Expr v a -> Expr v a -> Expr v a
+cond a b c = PHOASOp knownTy (Cond a b c)
--- λx. let y = x + x in y * y
-ea_2 :: PHOASExpr Typ v ArithF (Int -> Int)
-ea_2 = lam $ \arg -> let y = arg +! arg
- in y *! y
+(.<), (.<=), (.>), (.>=) :: (KnownTy a, IsOrdTy a ~ True) => Expr v a -> Expr v a -> Expr v Bool
+a .< b = PHOASOp TBool (Bin (BOLt knownTy) a b)
+a .<= b = PHOASOp TBool (Bin (BOLeq knownTy) a b)
+(.>) = flip (.<)
+(.>=) = flip (.<=)
+infix 4 .<
+infix 4 .<=
+infix 4 .>
+infix 4 .>=
-ea_3 :: PHOASExpr Typ v ArithF (Int -> Int)
-ea_3 = lam $ \arg ->
- let y = arg +! arg
- x = y *! arg
- -- in (y +! x) +! (x +! y)
- in (x +! y) +! (y +! x)
+instance (KnownTy a, IsOrdTy a ~ True, Num a, Show a) => Num (Expr v a) where
+ a + b = PHOASOp knownTy (Bin (BOAdd knownTy) a b)
+ a - b = PHOASOp knownTy (Bin (BOSub knownTy) a b)
+ a * b = PHOASOp knownTy (Bin (BOMul knownTy) a b)
+ negate a = PHOASOp knownTy (Un (UONeg knownTy) a)
+ abs a = cond (a .< 0) (-a) a
+ signum a = cond (a .< 0) (-1) (cond (a .> 0) 1 0)
+ fromInteger n = PHOASOp knownTy (Cnst (fromInteger n))
main :: IO ()
-main = putStrLn $ prettyBExpr prettyArithF (sharingRecovery ea_3)
+main = do
+ print $ sharingRecovery @Lang @_ $
+ let a = 2 ; b = 3 :: Expr v Int
+ in a + b .< b + a