diff options
| -rw-r--r-- | sharing-recovery.cabal | 13 | ||||
| -rw-r--r-- | src/Data/Expr/SharingRecovery.hs | 3 | ||||
| -rw-r--r-- | src/Data/Expr/SharingRecovery/Internal.hs | 217 | ||||
| -rw-r--r-- | test-th/Arith.hs (renamed from test/Arith.hs) | 0 | ||||
| -rw-r--r-- | test-th/Arith/NonBase.hs (renamed from test/Arith/NonBase.hs) | 0 | ||||
| -rw-r--r-- | test-th/Main.hs | 51 | ||||
| -rw-r--r-- | test-th/NonBaseTH.hs (renamed from test/NonBaseTH.hs) | 0 | ||||
| -rw-r--r-- | test/Main.hs | 113 |
8 files changed, 267 insertions, 130 deletions
diff --git a/sharing-recovery.cabal b/sharing-recovery.cabal index 9e34bab..162328d 100644 --- a/sharing-recovery.cabal +++ b/sharing-recovery.cabal @@ -25,10 +25,21 @@ test-suite test type: exitcode-stdio-1.0 main-is: Main.hs other-modules: + hs-source-dirs: test + build-depends: + sharing-recovery, + base + default-language: Haskell2010 + ghc-options: -Wall + +test-suite test-th + type: exitcode-stdio-1.0 + main-is: Main.hs + other-modules: Arith Arith.NonBase NonBaseTH - hs-source-dirs: test + hs-source-dirs: test-th build-depends: sharing-recovery, base, diff --git a/src/Data/Expr/SharingRecovery.hs b/src/Data/Expr/SharingRecovery.hs index 02b3e3e..95524dd 100644 --- a/src/Data/Expr/SharingRecovery.hs +++ b/src/Data/Expr/SharingRecovery.hs @@ -1,10 +1,13 @@ module Data.Expr.SharingRecovery ( -- * Sharing recovery sharingRecovery, + sharingRecoveryUnsafe, -- * Expressions PHOASExpr(..), + typeOfPHOAS, BExpr(..), + typeOfBExpr, Idx(..), -- * Traversing indexed structures diff --git a/src/Data/Expr/SharingRecovery/Internal.hs b/src/Data/Expr/SharingRecovery/Internal.hs index 0089454..9d00355 100644 --- a/src/Data/Expr/SharingRecovery/Internal.hs +++ b/src/Data/Expr/SharingRecovery/Internal.hs @@ -78,7 +78,10 @@ class Functor1 f => Traversable1 f where -- -- * @v@ is the type of variables in the expression. A PHOAS expression is -- required to be parametric in the @v@ parameter; the only place you will --- obtain a @v@ is inside a @PHOASLam@ function body. +-- obtain a @v@ is inside a @PHOASLam@ function body. Even if you use +-- 'sharingRecoveryUnsafe' which allows you to pass in a monomorphic +-- expression with 'Tag', it is expressly disallowed to inspect or +-- manipulate 'Tag' values. -- -- * @f@ should be your type of operations for your language. It is indexed by -- the type of subexpressions and the result type of the operation; thus, it @@ -100,7 +103,7 @@ class Functor1 f => Traversable1 f where -- -- Note furthermore that @Oper@ is /not/ a recursive type. Subexpressions -- are again 'PHOASExpr's, and 'sharingRecovery' needs to be able to see --- them. Hence, you should call back to back to @r@ instead of recursing +-- them. Hence, you should call back to @r@ instead of recursing -- manually. -- -- * @t@ is the result type of this expression. @@ -109,6 +112,12 @@ data PHOASExpr typ v f t where PHOASLam :: typ (a -> b) -> typ a -> (v a -> PHOASExpr typ v f b) -> PHOASExpr typ v f (a -> b) PHOASVar :: typ t -> v t -> PHOASExpr typ v f t +-- | Tag values identify a variable, and are created in the sharing recovery +-- process (in 'pruneExpr'). Applying the sharing recovery algorithm to a term +-- with manually constructed Tag values, or inspecting Tag values in the +-- expression and computing with them or branching based on them, results in +-- "impossible terms" and behaviour of the sharing recovery algorithm is +-- undefined. Consider Tag an opaque value. newtype Tag t = Tag Natural deriving (Show, Eq) deriving (Hashable) via Natural @@ -125,16 +134,26 @@ instance TestEquality (NameFor typ f) where unsafeCoerceRefl :: a :~: b -- restricted version of unsafeCoerce that only allows punting proofs unsafeCoerceRefl = unsafeCoerce Refl +typeOfPHOAS :: PHOASExpr typ v f t -> typ t +typeOfPHOAS (PHOASOp ty _) = ty +typeOfPHOAS (PHOASLam ty _ _) = ty +typeOfPHOAS (PHOASVar ty _) = ty + -- | Pruned expression. -- -- Note that variables do not, and will never, have a name: we don't bother -- detecting sharing for variable references, because that would only introduce -- a redundant variable indirection. -data PExpr typ f t where - PStub :: NameFor typ f t -> typ t -> PExpr typ f t - POp :: NameFor typ f t -> typ t -> f (PExpr typ f) t -> PExpr typ f t - PLam :: NameFor typ f (a -> b) -> typ (a -> b) -> typ a -> Tag a -> PExpr typ f b -> PExpr typ f (a -> b) - PVar :: typ a -> Tag a -> PExpr typ f a +-- +-- This is defined as a base functor; @r@ is the recursive position. +data PExpr r typ f t where + PStub :: NameFor typ f t -> typ t -> PExpr r typ f t + POp :: NameFor typ f t -> typ t -> f (r typ f) t -> PExpr r typ f t + PLam :: NameFor typ f (a -> b) -> typ (a -> b) -> typ a -> Tag a -> r typ f b -> PExpr r typ f (a -> b) + PVar :: typ a -> Tag a -> PExpr r typ f a + +-- | Fixpoint of 'PExpr' +newtype PExpr0 typ f t = PExpr0 (PExpr PExpr0 typ f t) data SomeNameFor typ f = forall t. SomeNameFor {-# UNPACK #-} !(NameFor typ f t) @@ -144,19 +163,22 @@ instance Eq (SomeNameFor typ f) where instance Hashable (SomeNameFor typ f) where hashWithSalt salt (SomeNameFor name) = hashWithSalt salt name -prettyPExpr :: Traversable1 f => Int -> PExpr typ f t -> ShowS -prettyPExpr d = \case +prettyPExpr0 :: Traversable1 f => Int -> PExpr0 typ f t -> ShowS +prettyPExpr0 d (PExpr0 ex) = prettyPExpr prettyPExpr0 d ex + +prettyPExpr :: Traversable1 f => (forall a. Int -> r typ f a -> ShowS) -> Int -> PExpr r typ f t -> ShowS +prettyPExpr recur d = \case PStub (NameFor name) _ -> showString (showStableName name) POp (NameFor name) _ args -> let (argslist, _) = traverse1 (\arg -> ([Some arg], Const ())) args - argslist' = map (\(Some arg) -> prettyPExpr 0 arg) argslist + argslist' = map (\(Some arg) -> recur 0 arg) argslist in showParen (d > 10) $ showString ("<" ++ showStableName name ++ ">(") . foldr (.) id (intersperse (showString ", ") argslist') . showString ")" PLam (NameFor name) _ _ (Tag tag) body -> showParen (d > 0) $ - showString ("λ" ++ showStableName name ++ " x" ++ show tag ++ ". ") . prettyPExpr 0 body + showString ("λ" ++ showStableName name ++ " x" ++ show tag ++ ". ") . recur 0 body PVar _ (Tag tag) -> showString ("x" ++ show tag) -- | For each name: @@ -170,13 +192,19 @@ prettyPExpr d = \case -- Missing names have not been seen yet, and have unknown height. type OccMap typ f = HashMap (SomeNameFor typ f) (Natural, Natural) -pruneExpr :: Traversable1 f => (forall v. PHOASExpr typ v f t) -> (OccMap typ f, PExpr typ f t) +pruneExpr :: Traversable1 f => (forall v. PHOASExpr typ v f t) -> (OccMap typ f, PExpr0 typ f t) pruneExpr term = let ((term', _), (_, mp)) = runState (pruneExpr' term) (0, mempty) in (mp, term') +pruneExprUnsafe :: Traversable1 f => PHOASExpr typ Tag f t -> (OccMap typ f, PExpr0 typ f t) +pruneExprUnsafe term = + let ((term', _), (_, mp)) = runState (pruneExpr' term) (0, mempty) + in (mp, term') + -- | Returns pruned expression with its height. -pruneExpr' :: Traversable1 f => PHOASExpr typ Tag f t -> State (Natural, OccMap typ f) (PExpr typ f t, Natural) +-- State: (ID generator, occurrence map being accumulated) +pruneExpr' :: Traversable1 f => PHOASExpr typ Tag f t -> State (Natural, OccMap typ f) (PExpr0 typ f t, Natural) pruneExpr' = \case orig@(PHOASOp ty args) -> do let name = makeStableName' orig @@ -185,7 +213,7 @@ pruneExpr' = \case -- already visited Just height -> do modify (second (HM.adjust (first (+1)) (SomeNameFor (NameFor name)))) - pure (PStub (NameFor name) ty, height) + pure (PExpr0 (PStub (NameFor name) ty), height) -- first visit Nothing -> do -- Traverse the arguments, collecting the maximum height in an @@ -193,13 +221,14 @@ pruneExpr' = \case (args', maxhei) <- withMoreState 0 $ traverse1 (\arg -> do + -- drop the extra state for the recursive call (arg', hei) <- withLessState id (,) (pruneExpr' arg) - modify (second (hei `max`)) + modify (second (hei `max`)) -- modify the extra state return arg') args -- Record this node modify (second (HM.insert (SomeNameFor (NameFor name)) (1, 1 + maxhei))) - pure (POp (NameFor name) ty args', 1 + maxhei) + pure (PExpr0 (POp (NameFor name) ty args'), 1 + maxhei) orig@(PHOASLam tyf tyarg f) -> do let name = makeStableName' orig @@ -208,7 +237,7 @@ pruneExpr' = \case -- already visited Just height -> do modify (second (HM.adjust (first (+1)) (SomeNameFor (NameFor name)))) - pure (PStub (NameFor name) tyf, height) + pure (PExpr0 (PStub (NameFor name) tyf), height) -- first visit Nothing -> do tag <- Tag <$> gets fst @@ -216,45 +245,24 @@ pruneExpr' = \case let body = f tag (body', bodyhei) <- pruneExpr' body modify (second (HM.insert (SomeNameFor (NameFor name)) (1, 1 + bodyhei))) - pure (PLam (NameFor name) tyf tyarg tag body', 1 + bodyhei) + pure (PExpr0 (PLam (NameFor name) tyf tyarg tag body'), 1 + bodyhei) - PHOASVar ty tag -> pure (PVar ty tag, 1) + PHOASVar ty tag -> pure (PExpr0 (PVar ty tag), 1) --- | Floated expression: a bunch of to-be let bound expressions on top of an --- LExpr'. Because LExpr' is really just PExpr with the recursive positions --- replaced by LExpr, LExpr should be seen as PExpr with a bunch of to-be let --- bound expressions on top of every node. -data LExpr typ f t = LExpr [Some (LExpr typ f)] (LExpr' typ f t) -data LExpr' typ f t where -- TODO: this could be an instantiation of (a generalisation of) PExpr - LStub :: NameFor typ f t -> typ t -> LExpr' typ f t - LOp :: NameFor typ f t -> typ t -> f (LExpr typ f) t -> LExpr' typ f t - LLam :: NameFor typ f (a -> b) -> typ (a -> b) -> typ a -> Tag a -> LExpr typ f b -> LExpr' typ f (a -> b) - LVar :: typ a -> Tag a -> LExpr' typ f a +-- | Floated expression: again a 'PExpr' (it's a fixpoint over the same base +-- functor), but now with a bunch of to-be let bound expressions on top of +-- every node. +data LExpr typ f t = LExpr [Some (LExpr typ f)] (PExpr LExpr typ f t) prettyLExpr :: Traversable1 f => Int -> LExpr typ f t -> ShowS -prettyLExpr d (LExpr [] e) = prettyLExpr' d e +prettyLExpr d (LExpr [] e) = prettyPExpr prettyLExpr d e prettyLExpr d (LExpr floated e) = showString "[" . foldr (.) id (intersperse (showString ", ") (map (\(Some e') -> prettyLExpr 0 e') floated)) - . showString "] " . prettyLExpr' d e - -prettyLExpr' :: Traversable1 f => Int -> LExpr' typ f t -> ShowS -prettyLExpr' d = \case - LStub (NameFor name) _ -> showString (showStableName name) - LOp (NameFor name) _ args -> - let (argslist, _) = traverse1 (\arg -> ([Some arg], Const ())) args - argslist' = map (\(Some arg) -> prettyLExpr 0 arg) argslist - in showParen (d > 10) $ - showString ("<" ++ showStableName name ++ ">(") - . foldr (.) id (intersperse (showString ", ") argslist') - . showString ")" - LLam (NameFor name) _ _ (Tag tag) body -> - showParen (d > 0) $ - showString ("λ" ++ showStableName name ++ " x" ++ show tag ++ ". ") . prettyLExpr 0 body - LVar _ (Tag tag) -> showString ("x" ++ show tag) + . showString "] " . prettyPExpr prettyLExpr d e -floatExpr :: Traversable1 f => OccMap typ f -> PExpr typ f t -> LExpr typ f t +floatExpr :: Traversable1 f => OccMap typ f -> PExpr0 typ f t -> LExpr typ f t floatExpr totals term = snd (floatExpr' totals term) newtype FoundMap typ f = FoundMap @@ -269,46 +277,47 @@ instance Semigroup (FoundMap typ f) where instance Monoid (FoundMap typ f) where mempty = FoundMap HM.empty -floatExpr' :: Traversable1 f => OccMap typ f -> PExpr typ f t -> (FoundMap typ f, LExpr typ f t) -floatExpr' _totals (PStub name ty) = - -- trace ("Found stub: " ++ (case name of NameFor n -> showStableName n)) $ - (FoundMap $ HM.singleton (SomeNameFor name) (1, Nothing) - ,LExpr [] (LStub name ty)) +floatExpr' :: Traversable1 f => OccMap typ f -> PExpr0 typ f t -> (FoundMap typ f, LExpr typ f t) +floatExpr' totals (PExpr0 term) = case term of + PStub name ty -> + -- trace ("Found stub: " ++ (case name of NameFor n -> showStableName n)) $ + (FoundMap $ HM.singleton (SomeNameFor name) (1, Nothing) + ,LExpr [] (PStub name ty)) -floatExpr' _totals (PVar ty tag) = - -- trace ("Found var: " ++ show tag) $ - (mempty, LExpr [] (LVar ty tag)) + PVar ty tag -> + -- trace ("Found var: " ++ show tag) $ + (mempty, LExpr [] (PVar ty tag)) -floatExpr' totals term = - let (FoundMap foundmap, name, termty, term') = case term of - POp n ty args -> - let (fm, args') = traverse1 (floatExpr' totals) args - in (fm, n, ty, LOp n ty args') - PLam n tyf tyarg tag body -> - let (fm, body') = floatExpr' totals body - in (fm, n, tyf, LLam n tyf tyarg tag body') + _ -> + let (FoundMap foundmap, name, termty, term') = case term of + POp n ty args -> + let (fm, args') = traverse1 (floatExpr' totals) args + in (fm, n, ty, POp n ty args') + PLam n tyf tyarg tag body -> + let (fm, body') = floatExpr' totals body + in (fm, n, tyf, PLam n tyf tyarg tag body') - -- TODO: perhaps this HM.toList together with the foldr HM.delete can be a single traversal of the HashMap - saturated = [case mterm of - Just t -> (nm, t) - Nothing -> case nm of - SomeNameFor (NameFor n) -> - error $ "Name saturated (count=" ++ show count ++ ", totalcount=" ++ show totalcount ++ ") but no term found: " ++ showStableName n - | (nm, (count, mterm)) <- HM.toList foundmap - , let totalcount = fromMaybe 0 (fst <$> HM.lookup nm totals) - , count == totalcount] + -- TODO: perhaps this HM.toList together with the foldr HM.delete can be a single traversal of the HashMap + saturated = [case mterm of + Just t -> (nm, t) + Nothing -> case nm of + SomeNameFor (NameFor n) -> + error $ "Name saturated (count=" ++ show count ++ ", totalcount=" ++ show totalcount ++ ") but no term found: " ++ showStableName n + | (nm, (count, mterm)) <- HM.toList foundmap + , let totalcount = fromMaybe 0 (fst <$> HM.lookup nm totals) + , count == totalcount] - foundmap' = foldr HM.delete foundmap (map fst saturated) + foundmap' = foldr HM.delete foundmap (map fst saturated) - lterm = LExpr (map fst (sortBy (comparing snd) (map snd saturated))) term' + lterm = LExpr (map fst (sortBy (comparing snd) (map snd saturated))) term' - in case HM.findWithDefault (0, undefined) (SomeNameFor name) totals of - (1, _) -> (FoundMap foundmap', lterm) - (tot, height) - | tot > 1 -> -- trace ("Inserting " ++ (case name of NameFor n -> showStableName n) ++ " into foundmap") $ - (FoundMap (HM.insert (SomeNameFor name) (1, Just (Some lterm, height)) foundmap') - ,LExpr [] (LStub name termty)) - | otherwise -> error "Term does not exist, yet we have it in hand" + in case HM.findWithDefault (0, undefined) (SomeNameFor name) totals of + (1, _) -> (FoundMap foundmap', lterm) + (tot, height) + | tot > 1 -> -- trace ("Inserting " ++ (case name of NameFor n -> showStableName n) ++ " into foundmap") $ + (FoundMap (HM.insert (SomeNameFor name) (1, Just (Some lterm, height)) foundmap') + ,LExpr [] (PStub name termty)) + | otherwise -> error "Term does not exist, yet we have it in hand" -- | Untyped De Bruijn expression. No more names: there are lets now, and @@ -342,15 +351,15 @@ lowerExpr' namelvl taglvl curlvl (LExpr floated ex) = curlvl' = curlvl + length floated in prefix $ case ex of - LStub name ty -> + PStub name ty -> case HM.lookup (SomeNameFor name) namelvl' of Just lvl -> UBVar ty (curlvl - lvl - 1) Nothing -> error "Name variable out of scope" - LOp _ ty args -> + POp _ ty args -> UBOp ty (fmap1 (lowerExpr' namelvl' taglvl curlvl') args) - LLam _ tyf tyarg tag body -> + PLam _ tyf tyarg tag body -> UBLam tyf tyarg (lowerExpr' namelvl' (HM.insert (SomeTag tag) curlvl' taglvl) (curlvl' + 1) body) - LVar ty tag -> + PVar ty tag -> case HM.lookup (SomeTag tag) taglvl of Just lvl -> UBVar ty (curlvl - lvl - 1) Nothing -> error "Tag variable out of scope" @@ -363,17 +372,17 @@ lowerExpr' namelvl taglvl curlvl (LExpr floated ex) = buildPrefix namelvl' _ [] = (namelvl', id) buildPrefix namelvl' lvl (Some rhs@(LExpr _ rhs') : rhss) = let name = case rhs' of - LStub n _ -> n - LOp n _ _ -> n - LLam n _ _ _ _ -> n - LVar _ _ -> error "Recovering sharing of a tag is useless" + PStub n _ -> n + POp n _ _ -> n + PLam n _ _ _ _ -> n + PVar _ _ -> error "Recovering sharing of a tag is useless" ty = case rhs' of - LStub{} -> error "Recovering sharing of a stub is useless" - LOp _ t _ -> t - LLam _ t _ _ _ -> t - LVar t _ -> t + PStub{} -> error "Recovering sharing of a stub is useless" + POp _ t _ -> t + PLam _ t _ _ _ -> t + PVar t _ -> t prefix = UBLet ty (lowerExpr' namelvl' taglvl lvl rhs) - in (prefix .) <$> buildPrefix (HM.insert (SomeNameFor name) lvl namelvl') (lvl + 1) rhss + in second (prefix .) $ buildPrefix (HM.insert (SomeNameFor name) lvl namelvl') (lvl + 1) rhss -- | A typed De Bruijn index. @@ -399,7 +408,7 @@ envLookupU = go id go f 0 (EPush _ t) = Just (Some (Pair t (f IZ))) go f i (EPush e _) = go (f . IS) (i - 1) e --- | Typed De Bruijn expression. This is the resu,t of sharing recovery. It is +-- | Typed De Bruijn expression. This is the result of sharing recovery. It is -- not higher-order any more, and furthermore has explicit let-bindings ('BLet') -- that denote the sharing inside the term. This is a normal AST. -- @@ -418,6 +427,12 @@ data BExpr typ env f t where deriving instance (forall a. Show (typ a), forall a r. (forall b. Show (r b)) => Show (f r a)) => Show (BExpr typ env f t) +typeOfBExpr :: BExpr typ v f t -> typ t +typeOfBExpr (BOp ty _) = ty +typeOfBExpr (BLam ty _ _) = ty +typeOfBExpr (BLet _ _ e) = typeOfBExpr e +typeOfBExpr (BVar ty _) = ty + prettyBExpr :: (forall m env' a. Monad m => (forall b. Int -> BExpr typ env' f b -> m ShowS) -> Int -> f (BExpr typ env' f) a -> m ShowS) -> BExpr typ '[] f t -> String @@ -466,8 +481,16 @@ retypeExpr' env (UBVar ty idx) = -- abstract syntax form to a well-typed well-scoped De Bruijn expression with -- explicit let-bindings. sharingRecovery :: (Traversable1 f, TestEquality typ) => (forall v. PHOASExpr typ v f t) -> BExpr typ '[] f t -sharingRecovery e = - let (occmap, pexpr) = pruneExpr e +sharingRecovery = sharingRecoveryUnsafe + +-- | The 'sharingRecovery' function instantiates the @v@ parameter to 'Tag', +-- and this function is provided as a convenience in case constructing a +-- polymorphic expression value is difficult. However, it is /disallowed/ to +-- inspect or manipulate 'Tag' values obtained by lambda abstraction inside the +-- expression. Violating this rule results in undefined behaviour. +sharingRecoveryUnsafe :: (Traversable1 f, TestEquality typ) => PHOASExpr typ Tag f t -> BExpr typ '[] f t +sharingRecoveryUnsafe e = + let (occmap, pexpr) = pruneExprUnsafe e lexpr = floatExpr occmap pexpr ubexpr = lowerExpr lexpr in -- trace ("PExpr: " ++ prettyPExpr 0 pexpr "") $ diff --git a/test/Arith.hs b/test-th/Arith.hs index c34baa8..c34baa8 100644 --- a/test/Arith.hs +++ b/test-th/Arith.hs diff --git a/test/Arith/NonBase.hs b/test-th/Arith/NonBase.hs index f5d458e..f5d458e 100644 --- a/test/Arith/NonBase.hs +++ b/test-th/Arith/NonBase.hs diff --git a/test-th/Main.hs b/test-th/Main.hs new file mode 100644 index 0000000..5a4d335 --- /dev/null +++ b/test-th/Main.hs @@ -0,0 +1,51 @@ +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE StandaloneDeriving #-} +module Main where + +import Data.Expr.SharingRecovery +import Data.Expr.SharingRecovery.Internal + +import Arith + + +-- TODO: test cyclic expressions + + +a_bin :: (KnownType a, KnownType b, KnownType c) + => PrimOp (a, b) c + -> PHOASExpr Typ v ArithF a + -> PHOASExpr Typ v ArithF b + -> PHOASExpr Typ v ArithF c +a_bin op a b = PHOASOp τ (A_Prim op (PHOASOp τ (A_Pair a b))) + +lam :: (KnownType a, KnownType b) + => (PHOASExpr Typ v f a -> PHOASExpr Typ v f b) -> PHOASExpr Typ v f (a -> b) +lam f = PHOASLam τ τ $ \arg -> f (PHOASVar τ arg) + +(+!) :: PHOASExpr Typ v ArithF Int -> PHOASExpr Typ v ArithF Int -> PHOASExpr Typ v ArithF Int +(+!) = a_bin POAddI + +(*!) :: PHOASExpr Typ v ArithF Int -> PHOASExpr Typ v ArithF Int -> PHOASExpr Typ v ArithF Int +(*!) = a_bin POMulI + +-- λx. x + x +ea_1 :: PHOASExpr Typ v ArithF (Int -> Int) +ea_1 = lam $ \arg -> arg +! arg + +-- λx. let y = x + x in y * y +ea_2 :: PHOASExpr Typ v ArithF (Int -> Int) +ea_2 = lam $ \arg -> let y = arg +! arg + in y *! y + +ea_3 :: PHOASExpr Typ v ArithF (Int -> Int) +ea_3 = lam $ \arg -> + let y = arg +! arg + x = y *! arg + -- in (y +! x) +! (x +! y) + in (x +! y) +! (y +! x) + +main :: IO () +main = putStrLn $ prettyBExpr prettyArithF (sharingRecovery ea_3) diff --git a/test/NonBaseTH.hs b/test-th/NonBaseTH.hs index 4741ea0..4741ea0 100644 --- a/test/NonBaseTH.hs +++ b/test-th/NonBaseTH.hs diff --git a/test/Main.hs b/test/Main.hs index 5a4d335..208d3d6 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -1,51 +1,100 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} -{-# LANGUAGE QuantifiedConstraints #-} -{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeApplications #-} {-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE QuantifiedConstraints #-} module Main where import Data.Expr.SharingRecovery -import Data.Expr.SharingRecovery.Internal +import Data.Type.Equality + + +data Ty a where + TInt :: Ty Int + TFloat :: Ty Float + TBool :: Ty Bool +deriving instance Show (Ty a) -import Arith +instance TestEquality Ty where + testEquality TInt TInt = Just Refl + testEquality TInt _ = Nothing + testEquality TFloat TFloat = Just Refl + testEquality TFloat _ = Nothing + testEquality TBool TBool = Just Refl + testEquality TBool _ = Nothing +type family IsOrdTy a where + IsOrdTy Int = True + IsOrdTy Float = True + IsOrdTy _ = False --- TODO: test cyclic expressions +data Unop a b where + UONeg :: Ty a -> Unop a a + UONot :: Unop Bool Bool +deriving instance Show (Unop a b) +data Binop a b c where + BOAdd :: Ty a -> Binop a a a + BOSub :: Ty a -> Binop a a a + BOMul :: Ty a -> Binop a a a + BOAnd :: Binop Bool Bool Bool + BOOr :: Binop Bool Bool Bool + BOLt :: IsOrdTy a ~ True => Ty a -> Binop a a Bool + BOLeq :: IsOrdTy a ~ True => Ty a -> Binop a a Bool + BOEq :: IsOrdTy a ~ True => Ty a -> Binop a a Bool + BONeq :: IsOrdTy a ~ True => Ty a -> Binop a a Bool +deriving instance Show (Binop a b c) -a_bin :: (KnownType a, KnownType b, KnownType c) - => PrimOp (a, b) c - -> PHOASExpr Typ v ArithF a - -> PHOASExpr Typ v ArithF b - -> PHOASExpr Typ v ArithF c -a_bin op a b = PHOASOp τ (A_Prim op (PHOASOp τ (A_Pair a b))) +data Lang r a where + Un :: Unop a b -> r a -> Lang r b + Bin :: Binop a b c -> r a -> r b -> Lang r c + Cond :: r Bool -> r a -> r a -> Lang r a + Cnst :: Show a => a -> Lang r a -- there's a type in the BExpr in the end, no need for one here +deriving instance (forall b. Show (r b)) => Show (Lang r a) -lam :: (KnownType a, KnownType b) - => (PHOASExpr Typ v f a -> PHOASExpr Typ v f b) -> PHOASExpr Typ v f (a -> b) -lam f = PHOASLam τ τ $ \arg -> f (PHOASVar τ arg) +instance Functor1 Lang +instance Traversable1 Lang where + traverse1 f = \case + Un op x -> Un op <$> f x + Bin op x y -> Bin op <$> f x <*> f y + Cond x y z -> Cond <$> f x <*> f y <*> f z + Cnst v -> pure (Cnst v) -(+!) :: PHOASExpr Typ v ArithF Int -> PHOASExpr Typ v ArithF Int -> PHOASExpr Typ v ArithF Int -(+!) = a_bin POAddI +class KnownTy a where knownTy :: Ty a +instance KnownTy Int where knownTy = TInt +instance KnownTy Float where knownTy = TFloat +instance KnownTy Bool where knownTy = TBool -(*!) :: PHOASExpr Typ v ArithF Int -> PHOASExpr Typ v ArithF Int -> PHOASExpr Typ v ArithF Int -(*!) = a_bin POMulI +type Expr v = PHOASExpr Ty v Lang --- λx. x + x -ea_1 :: PHOASExpr Typ v ArithF (Int -> Int) -ea_1 = lam $ \arg -> arg +! arg +cond :: KnownTy a => Expr v Bool -> Expr v a -> Expr v a -> Expr v a +cond a b c = PHOASOp knownTy (Cond a b c) --- λx. let y = x + x in y * y -ea_2 :: PHOASExpr Typ v ArithF (Int -> Int) -ea_2 = lam $ \arg -> let y = arg +! arg - in y *! y +(.<), (.<=), (.>), (.>=) :: (KnownTy a, IsOrdTy a ~ True) => Expr v a -> Expr v a -> Expr v Bool +a .< b = PHOASOp TBool (Bin (BOLt knownTy) a b) +a .<= b = PHOASOp TBool (Bin (BOLeq knownTy) a b) +(.>) = flip (.<) +(.>=) = flip (.<=) +infix 4 .< +infix 4 .<= +infix 4 .> +infix 4 .>= -ea_3 :: PHOASExpr Typ v ArithF (Int -> Int) -ea_3 = lam $ \arg -> - let y = arg +! arg - x = y *! arg - -- in (y +! x) +! (x +! y) - in (x +! y) +! (y +! x) +instance (KnownTy a, IsOrdTy a ~ True, Num a, Show a) => Num (Expr v a) where + a + b = PHOASOp knownTy (Bin (BOAdd knownTy) a b) + a - b = PHOASOp knownTy (Bin (BOSub knownTy) a b) + a * b = PHOASOp knownTy (Bin (BOMul knownTy) a b) + negate a = PHOASOp knownTy (Un (UONeg knownTy) a) + abs a = cond (a .< 0) (-a) a + signum a = cond (a .< 0) (-1) (cond (a .> 0) 1 0) + fromInteger n = PHOASOp knownTy (Cnst (fromInteger n)) main :: IO () -main = putStrLn $ prettyBExpr prettyArithF (sharingRecovery ea_3) +main = do + print $ sharingRecovery @Lang @_ $ + let a = 2 ; b = 3 :: Expr v Int + in a + b .< b + a |
