aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.gitignore1
-rw-r--r--sharing-recovery.cabal11
-rw-r--r--src/Data/Expr/SharingRecovery.hs193
-rw-r--r--test/Main.hs114
4 files changed, 273 insertions, 46 deletions
diff --git a/.gitignore b/.gitignore
index c33954f..a3ac1fc 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,2 @@
dist-newstyle/
+cabal.project.local
diff --git a/sharing-recovery.cabal b/sharing-recovery.cabal
index a994e03..df63c42 100644
--- a/sharing-recovery.cabal
+++ b/sharing-recovery.cabal
@@ -8,6 +8,7 @@ build-type: Simple
library
exposed-modules:
Data.Expr.SharingRecovery
+ other-modules:
Data.StableName.Extra
build-depends:
base >=4.16,
@@ -18,3 +19,13 @@ library
hs-source-dirs: src
default-language: Haskell2010
ghc-options: -Wall
+
+test-suite test
+ type: exitcode-stdio-1.0
+ main-is: Main.hs
+ hs-source-dirs: test
+ build-depends:
+ sharing-recovery,
+ base,
+ default-language: Haskell2010
+ ghc-options: -Wall
diff --git a/src/Data/Expr/SharingRecovery.hs b/src/Data/Expr/SharingRecovery.hs
index e386f4e..cdb64eb 100644
--- a/src/Data/Expr/SharingRecovery.hs
+++ b/src/Data/Expr/SharingRecovery.hs
@@ -1,17 +1,27 @@
+{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DerivingVia #-}
+{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
module Data.Expr.SharingRecovery where
import Control.Applicative ((<|>))
import Control.Monad.Trans.State.Strict
import Data.Bifunctor (second)
-import Data.GADT.Compare
+import Data.Char (chr, ord)
+import Data.Functor.Const
+import Data.Functor.Identity
+import Data.Functor.Product
import Data.Hashable
import Data.HashMap.Strict (HashMap)
import qualified Data.HashMap.Strict as HM
@@ -24,8 +34,6 @@ import Unsafe.Coerce (unsafeCoerce)
import Data.StableName.Extra
--- TODO: This is not yet done, see the bottom of this file
-
-- TODO: This implementation needs extensive documentation. 1. It is written
-- quite generically, meaning that the actual algorithm is easily obscured to
-- all but the most willing readers; 2. the original paper leaves something to
@@ -36,24 +44,28 @@ import Data.StableName.Extra
class Functor1 f where
fmap1 :: (forall b. g b -> h b) -> f g a -> f h a
+ default fmap1 :: Traversable1 f => (forall b. g b -> h b) -> f g a -> f h a
+ fmap1 f x = runIdentity (traverse1 (Identity . f) x)
+
class Functor1 f => Traversable1 f where
traverse1 :: Applicative m => (forall b. g b -> m (h b)) -> f g a -> m (f h a)
-- | Expression in parametric higher-order abstract syntax form
data PHOASExpr typ v f t where
PHOASOp :: typ t -> f (PHOASExpr typ v f) t -> PHOASExpr typ v f t
- PHOASLam :: typ (a -> b) -> typ a -> (PHOASExpr typ v f a -> PHOASExpr typ v f b) -> PHOASExpr typ v f (a -> b)
+ PHOASLam :: typ (a -> b) -> typ a -> (v a -> PHOASExpr typ v f b) -> PHOASExpr typ v f (a -> b)
PHOASVar :: typ t -> v t -> PHOASExpr typ v f t
newtype Tag t = Tag Natural
deriving (Show, Eq)
+ deriving (Hashable) via Natural
newtype NameFor typ f t = NameFor (StableName (PHOASExpr typ Tag f t))
deriving (Eq)
deriving (Hashable) via (StableName (PHOASExpr typ Tag f t))
-instance GEq (NameFor typ f) where
- geq (NameFor n1) (NameFor n2)
+instance TestEquality (NameFor typ f) where
+ testEquality (NameFor n1) (NameFor n2)
| eqStableName n1 n2 = Just unsafeCoerceRefl
| otherwise = Nothing
where
@@ -106,7 +118,7 @@ pruneExpr' = \case
then pure $ PStub (NameFor name) tyf
else do
tag <- state (\(i, mp) -> (Tag i, (i + 1, mp)))
- let body = f (PHOASVar tyarg tag)
+ let body = f tag
PLam (NameFor name) tyf tyarg tag <$> pruneExpr' body
PHOASVar ty tag -> pure $ PVar ty tag
@@ -181,56 +193,145 @@ liftExpr' totals term =
| otherwise -> error "Term does not exist, yet we have it in hand"
--- | Errors on stubs.
-lexprTypeOf :: LExpr typ f t -> typ t
-lexprTypeOf (LExpr _ e) = case e of
- LStub{} -> error "lexprTypeOf: got a stub"
- LOp _ t _ -> t
- LLam _ t _ _ _ -> t
- LVar t _ -> t
-
-
--- TODO: lower LExpr into a normal expression with let bindings. Every LStub
--- should correspond to some let-bound expression higher up in the tree (if it
--- does not, that's a bug), and should become a De Bruijn variable reference to
--- said let-bound expression. Lambdas should also get proper De Bruijn indices
--- instead of tags, and LVar is also a normal variable (referring to a
--- lambda-abstracted argument).
-
-- | Untyped De Bruijn expression. No more names: there are lets now, and
-- variable references are De Bruijn indices. These indices are not type-safe
-- yet, though.
data UBExpr typ f t where
UBOp :: typ t -> f (UBExpr typ f) t -> UBExpr typ f t
- UBLam :: typ a -> UBExpr typ f b -> UBExpr typ f (a -> b)
+ UBLam :: typ (a -> b) -> typ a -> UBExpr typ f b -> UBExpr typ f (a -> b)
UBLet :: typ a -> UBExpr typ f a -> UBExpr typ f b -> UBExpr typ f b
- -- De Bruijn index
+ -- | De Bruijn index
UBVar :: typ t -> Int -> UBExpr typ f t
-lowerExpr :: LExpr typ f t -> UBExpr typ f t
-lowerExpr = lowerExpr' mempty 0
+lowerExpr :: Functor1 f => LExpr typ f t -> UBExpr typ f t
+lowerExpr = lowerExpr' mempty mempty 0
--- 1. name |-> De Bruijn level of the variable defining that name
--- 2. Number of variables already in scope
-lowerExpr' :: forall typ f t. Traversable1 f => HashMap (SomeNameFor typ f) Int -> Int -> LExpr typ f t -> UBExpr typ f t
-lowerExpr' namelvl curlvl (LExpr lifted ex) =
- let prefix = buildPrefix curlvl lifted
- curlvl' = curlvl + length lifted
- in case ex of
- LStub name ty ->
- case HM.lookup (SomeNameFor name) namelvl of
- Just lvl -> UBVar ty (curlvl - lvl - 1)
- Nothing -> error "Variable out of scope"
- LOp name ty args ->
- UBOp ty (_ $ traverse1 _ args)
- where
- buildPrefix :: forall b. Int -> [Some (LExpr typ f)] -> UBExpr typ f b -> UBExpr typ f b
- buildPrefix _ [] = id
- buildPrefix lvl (Some rhs : rhss) =
- UBLet (lexprTypeOf rhs) (lowerExpr' namelvl lvl rhs)
- . buildPrefix (lvl + 1) rhss
+data SomeTag = forall t. SomeTag (Tag t)
+
+instance Eq SomeTag where
+ SomeTag (Tag n) == SomeTag (Tag m) = n == m
+instance Hashable SomeTag where
+ hashWithSalt salt (SomeTag tag) = hashWithSalt salt tag
+lowerExpr' :: forall typ f t. Functor1 f
+ => HashMap (SomeNameFor typ f) Int -- ^ node |-> De Bruijn level of defining binding
+ -> HashMap SomeTag Int -- ^ tag |-> De Bruijn level of defining binding
+ -> Int -- ^ Number of variables already in scope
+ -> LExpr typ f t -> UBExpr typ f t
+lowerExpr' namelvl taglvl curlvl (LExpr lifted ex) =
+ let (namelvl', prefix) = buildPrefix namelvl curlvl lifted
+ curlvl' = curlvl + length lifted
+ in prefix $
+ case ex of
+ LStub name ty ->
+ case HM.lookup (SomeNameFor name) namelvl' of
+ Just lvl -> UBVar ty (curlvl - lvl - 1)
+ Nothing -> error "Name variable out of scope"
+ LOp _ ty args ->
+ UBOp ty (fmap1 (lowerExpr' namelvl' taglvl curlvl') args)
+ LLam _ tyf tyarg tag body ->
+ UBLam tyf tyarg (lowerExpr' namelvl' (HM.insert (SomeTag tag) curlvl' taglvl) (curlvl' + 1) body)
+ LVar ty tag ->
+ case HM.lookup (SomeTag tag) taglvl of
+ Just lvl -> UBVar ty (curlvl - lvl - 1)
+ Nothing -> error "Tag variable out of scope"
+ where
+ buildPrefix :: forall b.
+ HashMap (SomeNameFor typ f) Int
+ -> Int
+ -> [Some (LExpr typ f)]
+ -> (HashMap (SomeNameFor typ f) Int, UBExpr typ f b -> UBExpr typ f b)
+ buildPrefix namelvl' _ [] = (namelvl', id)
+ buildPrefix namelvl' lvl (Some rhs@(LExpr _ rhs') : rhss) =
+ let name = case rhs' of
+ LStub n _ -> n
+ LOp n _ _ -> n
+ LLam n _ _ _ _ -> n
+ LVar _ _ -> error "Recovering sharing of a tag is useless"
+ ty = case rhs' of
+ LStub{} -> error "Recovering sharing of a stub is useless"
+ LOp _ t _ -> t
+ LLam _ t _ _ _ -> t
+ LVar t _ -> t
+ prefix = UBLet ty (lowerExpr' namelvl' taglvl lvl rhs)
+ in (prefix .) <$> buildPrefix (HM.insert (SomeNameFor name) lvl namelvl') (lvl + 1) rhss
+
+
+-- | A typed De Bruijn index.
data Idx env t where
IZ :: Idx (t : env) t
IS :: Idx env t -> Idx (s : env) t
+deriving instance Show (Idx env t)
+
+data Env env f where
+ ETop :: Env '[] f
+ EPush :: Env env f -> f t -> Env (t : env) f
+
+envLookup :: Idx env t -> Env env f -> f t
+envLookup IZ (EPush _ x) = x
+envLookup (IS i) (EPush e _) = envLookup i e
+
+-- | Untyped lookup in an 'Env'.
+envLookupU :: Int -> Env env f -> Maybe (Some (Product f (Idx env)))
+envLookupU = go id
+ where
+ go :: (forall a. Idx env a -> Idx env' a) -> Int -> Env env f -> Maybe (Some (Product f (Idx env')))
+ go !_ !_ ETop = Nothing
+ go f 0 (EPush _ t) = Just (Some (Pair t (f IZ)))
+ go f i (EPush e _) = go (f . IS) (i - 1) e
+
+-- | Typed De Bruijn expression. This is the final result of sharing recovery.
+data BExpr typ env f t where
+ BOp :: typ t -> f (BExpr typ env f) t -> BExpr typ env f t
+ BLam :: typ (a -> b) -> typ a -> BExpr typ (a : env) f b -> BExpr typ env f (a -> b)
+ BLet :: typ a -> BExpr typ env f a -> BExpr typ (a : env) f b -> BExpr typ env f b
+ BVar :: typ t -> Idx env t -> BExpr typ env f t
+deriving instance (forall a. Show (typ a), forall a r. (forall b. Show (r b)) => Show (f r a))
+ => Show (BExpr typ env f t)
+
+prettyBExpr :: (forall m env' a. Monad m => (forall b. Int -> BExpr typ env' f b -> m ShowS)
+ -> Int -> f (BExpr typ env' f) a -> m ShowS)
+ -> BExpr typ '[] f t -> String
+prettyBExpr prettyOp e = evalState (prettyBExpr' prettyOp ETop 0 e) 0 ""
+
+prettyBExpr' :: (forall m env' a. Monad m => (forall b. Int -> BExpr typ env' f b -> m ShowS)
+ -> Int -> f (BExpr typ env' f) a -> m ShowS)
+ -> Env env (Const String) -> Int -> BExpr typ env f t -> State Int ShowS
+prettyBExpr' prettyOp env d = \case
+ BOp _ args ->
+ prettyOp (prettyBExpr' prettyOp env) d args
+ BLam _ _ body -> do
+ name <- genName
+ body' <- prettyBExpr' prettyOp (EPush env (Const name)) 0 body
+ return $ showParen (d > 0) $ showString ("λ" ++ name ++ ". ") . body'
+ BLet _ rhs body -> do
+ name <- genName
+ rhs' <- prettyBExpr' prettyOp env 0 rhs
+ body' <- prettyBExpr' prettyOp (EPush env (Const name)) 0 body
+ return $ showParen (d > 0) $ showString ("let " ++ name ++ " = ") . rhs' . showString " in " . body'
+ BVar _ idx ->
+ return $ showString (getConst (envLookup idx env))
+ where
+ genName = do
+ i <- state (\i -> (i, i + 1))
+ return $ if i < 26 then [chr (ord 'a' + i)] else 'x' : show i
+
+retypeExpr :: (Functor1 f, TestEquality typ) => UBExpr typ f t -> BExpr typ '[] f t
+retypeExpr = retypeExpr' ETop
+
+retypeExpr' :: (Functor1 f, TestEquality typ) => Env env typ -> UBExpr typ f t -> BExpr typ env f t
+retypeExpr' env (UBOp ty args) = BOp ty (fmap1 (retypeExpr' env) args)
+retypeExpr' env (UBLam tyf tyarg body) = BLam tyf tyarg (retypeExpr' (EPush env tyarg) body)
+retypeExpr' env (UBLet ty rhs body) = BLet ty (retypeExpr' env rhs) (retypeExpr' (EPush env ty) body)
+retypeExpr' env (UBVar ty idx) =
+ case envLookupU idx env of
+ Just (Some (Pair defty tidx)) ->
+ case testEquality ty defty of
+ Just Refl -> BVar ty tidx
+ Nothing -> error "Type mismatch in untyped De Bruijn expression"
+ Nothing -> error "Untyped De Bruijn index out of range"
+
+
+sharingRecovery :: (Traversable1 f, TestEquality typ) => (forall v. PHOASExpr typ v f t) -> BExpr typ '[] f t
+sharingRecovery e = retypeExpr $ lowerExpr $ uncurry liftExpr $ pruneExpr e
diff --git a/test/Main.hs b/test/Main.hs
new file mode 100644
index 0000000..e7b303b
--- /dev/null
+++ b/test/Main.hs
@@ -0,0 +1,114 @@
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE QuantifiedConstraints #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE StandaloneDeriving #-}
+module Main where
+
+import Data.Type.Equality
+
+import Data.Expr.SharingRecovery
+
+
+data Typ t where
+ TInt :: Typ Int
+ TBool :: Typ Bool
+ TPair :: Typ a -> Typ b -> Typ (a, b)
+ TFun :: Typ a -> Typ b -> Typ (a -> b)
+deriving instance Show (Typ t)
+
+instance TestEquality Typ where
+ testEquality TInt TInt = Just Refl
+ testEquality TBool TBool = Just Refl
+ testEquality (TPair a b) (TPair a' b')
+ | Just Refl <- testEquality a a'
+ , Just Refl <- testEquality b b'
+ = Just Refl
+ testEquality (TFun a b) (TFun a' b')
+ | Just Refl <- testEquality a a'
+ , Just Refl <- testEquality b b'
+ = Just Refl
+ testEquality _ _ = Nothing
+
+class KnownType t where τ :: Typ t
+instance KnownType Int where τ = TInt
+instance KnownType Bool where τ = TBool
+instance (KnownType a, KnownType b) => KnownType (a, b) where τ = TPair τ τ
+instance (KnownType a, KnownType b) => KnownType (a -> b) where τ = TFun τ τ
+
+data PrimOp a b where
+ POAddI :: PrimOp (Int, Int) Int
+ POMulI :: PrimOp (Int, Int) Int
+ POEqI :: PrimOp (Int, Int) Bool
+deriving instance Show (PrimOp a b)
+
+data Fixity = Infix | Prefix
+ deriving (Show)
+
+primOpPrec :: PrimOp a b -> (Int, (Int, Int))
+primOpPrec POAddI = (6, (6, 7))
+primOpPrec POMulI = (7, (7, 8))
+primOpPrec POEqI = (4, (5, 5))
+
+prettyPrimOp :: Fixity -> PrimOp a b -> ShowS
+prettyPrimOp fix op =
+ let s = case op of
+ POAddI -> "+"
+ POMulI -> "*"
+ POEqI -> "=="
+ in showString $ case fix of
+ Infix -> s
+ Prefix -> "(" ++ s ++ ")"
+
+data ArithF r t where
+ A_Prim :: PrimOp a b -> r a -> ArithF r b
+ A_Pair :: r a -> r b -> ArithF r (a, b)
+ A_If :: r Bool -> r a -> r a -> ArithF r a
+deriving instance (forall a. Show (r a)) => Show (ArithF r t)
+
+instance Functor1 ArithF
+instance Traversable1 ArithF where
+ traverse1 f (A_Prim op x) = A_Prim op <$> f x
+ traverse1 f (A_Pair x y) = A_Pair <$> f x <*> f y
+ traverse1 f (A_If x y z) = A_If <$> f x <*> f y <*> f z
+
+prettyArithF :: Monad m
+ => (forall a. Int -> BExpr typ env ArithF a -> m ShowS)
+ -> Int -> ArithF (BExpr typ env ArithF) t -> m ShowS
+prettyArithF pr d = \case
+ A_Prim op (BOp _ (A_Pair a b)) -> do
+ let (dop, (dopL, dopR)) = primOpPrec op
+ a' <- pr dopL a
+ b' <- pr dopR b
+ return $ showParen (d > dop) $ a' . showString " " . prettyPrimOp Infix op . showString " " . b'
+ A_Prim op arg -> do
+ arg' <- pr 11 arg
+ return $ showParen (d > 10) $ prettyPrimOp Prefix op . showString " " . arg'
+ A_Pair a b -> do
+ a' <- pr 0 a
+ b' <- pr 0 b
+ return $ showString "(" . a' . showString ", " . b' . showString ")"
+ A_If a b c -> do
+ a' <- pr 0 a
+ b' <- pr 0 b
+ c' <- pr 0 c
+ return $ showParen (d > 0) $ showString "if " . a' . showString " then " . b' . showString " else " . c'
+
+-- λx. x + x
+ea_1 :: PHOASExpr Typ v ArithF (Int -> Int)
+ea_1 =
+ PHOASLam τ τ $ \arg ->
+ PHOASOp τ (A_Prim POAddI
+ (PHOASOp τ (A_Pair (PHOASVar τ arg) (PHOASVar τ arg))))
+
+-- λx. let y = x + x in y * y
+ea_2 :: PHOASExpr Typ v ArithF (Int -> Int)
+ea_2 =
+ PHOASLam τ τ $ \arg ->
+ let y = PHOASOp τ (A_Prim POAddI
+ (PHOASOp τ (A_Pair (PHOASVar τ arg) (PHOASVar τ arg))))
+ in PHOASOp τ (A_Prim POMulI
+ (PHOASOp τ (A_Pair y y)))
+
+main :: IO ()
+main = putStrLn $ prettyBExpr prettyArithF (sharingRecovery ea_2)