diff options
| -rw-r--r-- | .gitignore | 1 | ||||
| -rw-r--r-- | sharing-recovery.cabal | 11 | ||||
| -rw-r--r-- | src/Data/Expr/SharingRecovery.hs | 189 | ||||
| -rw-r--r-- | test/Main.hs | 114 | 
4 files changed, 271 insertions, 44 deletions
@@ -1 +1,2 @@  dist-newstyle/ +cabal.project.local diff --git a/sharing-recovery.cabal b/sharing-recovery.cabal index a994e03..df63c42 100644 --- a/sharing-recovery.cabal +++ b/sharing-recovery.cabal @@ -8,6 +8,7 @@ build-type:      Simple  library    exposed-modules:      Data.Expr.SharingRecovery +  other-modules:      Data.StableName.Extra    build-depends:      base >=4.16, @@ -18,3 +19,13 @@ library    hs-source-dirs: src    default-language: Haskell2010    ghc-options: -Wall + +test-suite test +  type: exitcode-stdio-1.0 +  main-is: Main.hs +  hs-source-dirs: test +  build-depends: +    sharing-recovery, +    base, +  default-language: Haskell2010 +  ghc-options: -Wall diff --git a/src/Data/Expr/SharingRecovery.hs b/src/Data/Expr/SharingRecovery.hs index e386f4e..cdb64eb 100644 --- a/src/Data/Expr/SharingRecovery.hs +++ b/src/Data/Expr/SharingRecovery.hs @@ -1,17 +1,27 @@ +{-# LANGUAGE BangPatterns #-}  {-# LANGUAGE DataKinds #-} +{-# LANGUAGE DefaultSignatures #-}  {-# LANGUAGE DeriveFunctor #-}  {-# LANGUAGE DerivingVia #-} +{-# LANGUAGE FlexibleInstances #-}  {-# LANGUAGE GADTs #-}  {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE QuantifiedConstraints #-}  {-# LANGUAGE RankNTypes #-}  {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-}  {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-}  module Data.Expr.SharingRecovery where  import Control.Applicative ((<|>))  import Control.Monad.Trans.State.Strict  import Data.Bifunctor (second) -import Data.GADT.Compare +import Data.Char (chr, ord) +import Data.Functor.Const +import Data.Functor.Identity +import Data.Functor.Product  import Data.Hashable  import Data.HashMap.Strict (HashMap)  import qualified Data.HashMap.Strict as HM @@ -24,8 +34,6 @@ import Unsafe.Coerce (unsafeCoerce)  import Data.StableName.Extra --- TODO: This is not yet done, see the bottom of this file -  -- TODO: This implementation needs extensive documentation. 1. 
It is written  -- quite generically, meaning that the actual algorithm is easily obscured to  -- all but the most willing readers; 2. the original paper leaves something to @@ -36,24 +44,28 @@ import Data.StableName.Extra  class Functor1 f where    fmap1 :: (forall b. g b -> h b) -> f g a -> f h a +  default fmap1 :: Traversable1 f => (forall b. g b -> h b) -> f g a -> f h a +  fmap1 f x = runIdentity (traverse1 (Identity . f) x) +  class Functor1 f => Traversable1 f where    traverse1 :: Applicative m => (forall b. g b -> m (h b)) -> f g a -> m (f h a)  -- | Expression in parametric higher-order abstract syntax form  data PHOASExpr typ v f t where    PHOASOp :: typ t -> f (PHOASExpr typ v f) t -> PHOASExpr typ v f t -  PHOASLam :: typ (a -> b) -> typ a -> (PHOASExpr typ v f a -> PHOASExpr typ v f b) -> PHOASExpr typ v f (a -> b) +  PHOASLam :: typ (a -> b) -> typ a -> (v a -> PHOASExpr typ v f b) -> PHOASExpr typ v f (a -> b)    PHOASVar :: typ t -> v t -> PHOASExpr typ v f t  newtype Tag t = Tag Natural    deriving (Show, Eq) +  deriving (Hashable) via Natural  newtype NameFor typ f t = NameFor (StableName (PHOASExpr typ Tag f t))    deriving (Eq)    deriving (Hashable) via (StableName (PHOASExpr typ Tag f t)) -instance GEq (NameFor typ f) where -  geq (NameFor n1) (NameFor n2) +instance TestEquality (NameFor typ f) where +  testEquality (NameFor n1) (NameFor n2)      | eqStableName n1 n2 = Just unsafeCoerceRefl      | otherwise = Nothing      where @@ -106,7 +118,7 @@ pruneExpr' = \case        then pure $ PStub (NameFor name) tyf        else do          tag <- state (\(i, mp) -> (Tag i, (i + 1, mp))) -        let body = f (PHOASVar tyarg tag) +        let body = f tag          PLam (NameFor name) tyf tyarg tag <$> pruneExpr' body    PHOASVar ty tag -> pure $ PVar ty tag @@ -181,56 +193,145 @@ liftExpr' totals term =             | otherwise -> error "Term does not exist, yet we have it in hand" --- | Errors on stubs. 
-- | Untyped De Bruijn expression. No more names: there are lets now, and
-- variable references are De Bruijn indices. These indices are not type-safe
-- yet, though.
data UBExpr typ f t where
  -- | An operation node: its type and one @f@-layer of subexpressions.
  UBOp :: typ t -> f (UBExpr typ f) t -> UBExpr typ f t
  -- | Lambda: type of the function, type of the argument, and the body (which
  -- has one more variable in scope).
  UBLam :: typ (a -> b) -> typ a -> UBExpr typ f b -> UBExpr typ f (a -> b)
  -- | Let binding: type of the bound expression, the bound expression, and
  -- the body (which has one more variable in scope).
  UBLet :: typ a -> UBExpr typ f a -> UBExpr typ f b -> UBExpr typ f b
  -- | De Bruijn index
  UBVar :: typ t -> Int -> UBExpr typ f t

-- | Lower a lifted expression to an untyped De Bruijn expression, starting
-- with no names, no tags and no variables in scope.
lowerExpr :: Functor1 f => LExpr typ f t -> UBExpr typ f t
lowerExpr = lowerExpr' mempty mempty 0

-- | A 'Tag' with its type index hidden, so that tags of differently-typed
-- variables can be used as keys of a single 'HashMap'.
data SomeTag = forall t. SomeTag (Tag t)

-- | Tags are compared by their underlying number only.
instance Eq SomeTag where
  SomeTag (Tag n) == SomeTag (Tag m) = n == m

instance Hashable SomeTag where
  hashWithSalt salt (SomeTag tag) = hashWithSalt salt tag

-- | Worker for 'lowerExpr'.
--
-- The bindings lifted to this node are turned into a chain of 'UBLet's (see
-- @buildPrefix@) wrapped around the lowered node itself. Variables are tracked
-- by De Bruijn /level/ (position counted from the outermost binder) in the two
-- maps and converted to De Bruijn /indices/ at use sites via
-- @index = (number of variables in scope) - level - 1@.
--
-- Calls 'error' if a stub or tag is encountered that is not in scope.
lowerExpr' :: forall typ f t. Functor1 f
           => HashMap (SomeNameFor typ f) Int  -- ^ node |-> De Bruijn level of defining binding
           -> HashMap SomeTag Int  -- ^ tag |-> De Bruijn level of defining binding
           -> Int  -- ^ Number of variables already in scope
           -> LExpr typ f t -> UBExpr typ f t
lowerExpr' namelvl taglvl curlvl (LExpr lifted ex) =
  let (namelvl', prefix) = buildPrefix namelvl curlvl lifted
      -- Number of variables in scope /underneath/ the let-prefix.
      curlvl' = curlvl + length lifted
  in prefix $
       case ex of
         LStub name ty ->
           case HM.lookup (SomeNameFor name) namelvl' of
             -- Use curlvl', not curlvl: this reference sits below 'prefix',
             -- which has already brought @length lifted@ more variables into
             -- scope. (The two coincide only when nothing was lifted here.)
             Just lvl -> UBVar ty (curlvl' - lvl - 1)
             Nothing -> error "Name variable out of scope"
         LOp _ ty args ->
           UBOp ty (fmap1 (lowerExpr' namelvl' taglvl curlvl') args)
         LLam _ tyf tyarg tag body ->
           -- The lambda-bound variable gets the next free level, curlvl'.
           UBLam tyf tyarg (lowerExpr' namelvl' (HM.insert (SomeTag tag) curlvl' taglvl) (curlvl' + 1) body)
         LVar ty tag ->
           case HM.lookup (SomeTag tag) taglvl of
             -- Same reasoning as for LStub: count the let-prefix binders too.
             Just lvl -> UBVar ty (curlvl' - lvl - 1)
             Nothing -> error "Tag variable out of scope"
  where
    -- Turn the list of lifted bindings into nested lets. Returns the name map
    -- extended with every binding in the list, together with a function that
    -- wraps a body in the corresponding 'UBLet's. The i-th binding is lowered
    -- with the preceding i bindings in scope and is assigned level @lvl + i@.
    buildPrefix :: forall b.
                   HashMap (SomeNameFor typ f) Int
                -> Int
                -> [Some (LExpr typ f)]
                -> (HashMap (SomeNameFor typ f) Int, UBExpr typ f b -> UBExpr typ f b)
    buildPrefix namelvl' _ [] = (namelvl', id)
    buildPrefix namelvl' lvl (Some rhs@(LExpr _ rhs') : rhss) =
      let name = case rhs' of
                   LStub n _ -> n
                   LOp n _ _ -> n
                   LLam n _ _ _ _ -> n
                   LVar _ _ -> error "Recovering sharing of a tag is useless"
          ty = case rhs' of
                 LStub{} -> error "Recovering sharing of a stub is useless"
                 LOp _ t _ -> t
                 LLam _ t _ _ _ -> t
                 LVar t _ -> t
          prefix = UBLet ty (lowerExpr' namelvl' taglvl lvl rhs)
      in (prefix .) <$> buildPrefix (HM.insert (SomeNameFor name) lvl namelvl') (lvl + 1) rhss

-- | A typed De Bruijn index.
data Idx env t where
  IZ :: Idx (t : env) t
  IS :: Idx env t -> Idx (s : env) t
deriving instance Show (Idx env t)

-- | An environment holding one @f t@ for every type @t@ in the type-level
-- list @env@; the most recently pushed entry corresponds to 'IZ'.
data Env env f where
  ETop :: Env '[] f
  EPush :: Env env f -> f t -> Env (t : env) f

-- | Typed lookup in an 'Env'.
envLookup :: Idx env t -> Env env f -> f t
envLookup IZ (EPush _ x) = x
envLookup (IS i) (EPush e _) = envLookup i e

-- | Untyped lookup in an 'Env'.
--
-- Returns the entry at the given position together with the typed index that
-- points to it, or 'Nothing' if the environment has no entry at that position.
envLookupU :: Int -> Env env f -> Maybe (Some (Product f (Idx env)))
envLookupU = go id
  where
    -- The first argument re-wraps an index into the remaining tail of the
    -- environment as an index into the full environment.
    go :: (forall a. Idx env a -> Idx env' a) -> Int -> Env env f -> Maybe (Some (Product f (Idx env')))
    go !_ !_ ETop = Nothing
    go f 0 (EPush _ t) = Just (Some (Pair t (f IZ)))
    go f i (EPush e _) = go (f . IS) (i - 1) e

-- | Typed De Bruijn expression. This is the final result of sharing recovery.
data BExpr typ env f t where
  -- | An operation node: its type and one @f@-layer of subexpressions.
  BOp :: typ t -> f (BExpr typ env f) t -> BExpr typ env f t
  -- | Lambda: type of the function, type of the argument, and the body, whose
  -- environment is extended with the argument.
  BLam :: typ (a -> b) -> typ a -> BExpr typ (a : env) f b -> BExpr typ env f (a -> b)
  -- | Let binding: type of the bound expression, the bound expression, and the
  -- body, whose environment is extended with the bound variable.
  BLet :: typ a -> BExpr typ env f a -> BExpr typ (a : env) f b -> BExpr typ env f b
  -- | Variable reference by typed De Bruijn index.
  BVar :: typ t -> Idx env t -> BExpr typ env f t
deriving instance (forall a. Show (typ a), forall a r. (forall b. Show (r b)) => Show (f r a))
               => Show (BExpr typ env f t)

-- | Render a closed 'BExpr' as a 'String'.
--
-- The first argument renders a single @f@-layer; it receives a callback for
-- rendering subexpressions at a given precedence, and the precedence of the
-- context it is rendered in. Rendering starts at precedence 0 with the name
-- counter at 0.
prettyBExpr :: (forall m env' a. Monad m => (forall b. Int -> BExpr typ env' f b -> m ShowS)
                                         -> Int -> f (BExpr typ env' f) a -> m ShowS)
            -> BExpr typ '[] f t -> String
prettyBExpr prettyOp e = evalState (prettyBExpr' prettyOp ETop 0 e) 0 ""

-- | Worker for 'prettyBExpr'.
--
-- The 'Env' holds the generated name of every variable in scope, the 'Int'
-- argument is the precedence of the surrounding context, and the 'State'
-- carries the counter used to generate the next fresh name.
prettyBExpr' :: (forall m env' a. Monad m => (forall b. Int -> BExpr typ env' f b -> m ShowS)
                                          -> Int -> f (BExpr typ env' f) a -> m ShowS)
             -> Env env (Const String) -> Int -> BExpr typ env f t -> State Int ShowS
prettyBExpr' prettyOp env d = \case
  BOp _ args ->
    -- Delegate to the user-supplied layer printer, handing it this function
    -- (with the current names) for the subexpressions.
    prettyOp (prettyBExpr' prettyOp env) d args
  BLam _ _ body -> do
    name <- genName
    body' <- prettyBExpr' prettyOp (EPush env (Const name)) 0 body
    -- Parenthesised in any context with precedence above 0.
    return $ showParen (d > 0) $ showString ("λ" ++ name ++ ". ") . body'
  BLet _ rhs body -> do
    name <- genName
    -- The right-hand side is rendered without the new name in scope.
    rhs' <- prettyBExpr' prettyOp env 0 rhs
    body' <- prettyBExpr' prettyOp (EPush env (Const name)) 0 body
    return $ showParen (d > 0) $ showString ("let " ++ name ++ " = ") . rhs' . showString " in " . body'
  BVar _ idx ->
    return $ showString (getConst (envLookup idx env))
  where
    -- Fresh names: "a" .. "z" for the first 26, then "x26", "x27", ...
    genName = do
      i <- state (\i -> (i, i + 1))
      return $ if i < 26 then [chr (ord 'a' + i)] else 'x' : show i

-- | Convert a closed untyped De Bruijn expression into a typed one.
-- Calls 'error' under the same conditions as 'retypeExpr''.
retypeExpr :: (Functor1 f, TestEquality typ) => UBExpr typ f t -> BExpr typ '[] f t
retypeExpr = retypeExpr' ETop

-- | Worker for 'retypeExpr'. The 'Env' records the type of every variable in
-- scope.
--
-- At a variable, the untyped index is looked up in the environment and the
-- type stored there is compared (with 'testEquality') to the type annotated
-- on the variable. Calls 'error' if the index is out of range or the types
-- differ.
retypeExpr' :: (Functor1 f, TestEquality typ) => Env env typ -> UBExpr typ f t -> BExpr typ env f t
retypeExpr' env (UBOp ty args) = BOp ty (fmap1 (retypeExpr' env) args)
retypeExpr' env (UBLam tyf tyarg body) = BLam tyf tyarg (retypeExpr' (EPush env tyarg) body)
retypeExpr' env (UBLet ty rhs body) = BLet ty (retypeExpr' env rhs) (retypeExpr' (EPush env ty) body)
retypeExpr' env (UBVar ty idx) =
  case envLookupU idx env of
    Just (Some (Pair defty tidx)) ->
      case testEquality ty defty of
        Just Refl -> BVar ty tidx
        Nothing -> error "Type mismatch in untyped De Bruijn expression"
    Nothing -> error "Untyped De Bruijn index out of range"


-- | The full pipeline: 'pruneExpr', then 'liftExpr', then 'lowerExpr', then
-- 'retypeExpr'. The input must be polymorphic in the variable type @v@.
sharingRecovery :: (Traversable1 f, TestEquality typ) => (forall v. PHOASExpr typ v f t) -> BExpr typ '[] f t
sharingRecovery e = retypeExpr $ lowerExpr $ uncurry liftExpr $ pruneExpr e

{- ----------------------------------------------------------------------------
   Patch boundary, kept verbatim from the scraped diff:
   diff --git a/test/Main.hs b/test/Main.hs new file mode 100644 index 0000000..e7b303b --- /dev/null +++ b/test/Main.hs @@ -0,0 +1,114 @@
   Everything below is the content of the new file test/Main.hs.
---------------------------------------------------------------------------- -}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE StandaloneDeriving #-}
module Main where

import Data.Type.Equality

import Data.Expr.SharingRecovery


-- | Singleton type representations for the example language.
data Typ t where
  TInt :: Typ Int
  TBool :: Typ Bool
  TPair :: Typ a -> Typ b -> Typ (a, b)
  TFun :: Typ a -> Typ b -> Typ (a -> b)
deriving instance Show (Typ t)

-- | Structural equality on type representations.
instance TestEquality Typ where
  testEquality TInt TInt = Just Refl
  testEquality TBool TBool = Just Refl
  testEquality (TPair a b) (TPair a' b')
    | Just Refl <- testEquality a a'
    , Just Refl <- testEquality b b'
    = Just Refl
  testEquality (TFun a b) (TFun a' b')
    | Just Refl <- testEquality a a'
    , Just Refl <- testEquality b b'
    = Just Refl
  testEquality _ _ = Nothing

-- | Types whose 'Typ' representation can be produced from the type alone.
class KnownType t where τ :: Typ t
instance KnownType Int where τ = TInt
instance KnownType Bool where τ = TBool
instance (KnownType a, KnownType b) => KnownType (a, b) where τ = TPair τ τ
instance (KnownType a, KnownType b) => KnownType (a -> b) where τ = TFun τ τ

-- | Primitive operations, indexed by argument and result type.
data PrimOp a b where
  POAddI :: PrimOp (Int, Int) Int
  POMulI :: PrimOp (Int, Int) Int
  POEqI :: PrimOp (Int, Int) Bool
deriving instance Show (PrimOp a b)

-- | Whether an operator is printed infix or as a parenthesised prefix name.
data Fixity = Infix | Prefix
  deriving (Show)

-- | Precedence of an operator, paired with the precedences at which its left
-- and right operands are printed (see the use in 'prettyArithF').
primOpPrec :: PrimOp a b -> (Int, (Int, Int))
primOpPrec POAddI = (6, (6, 7))
primOpPrec POMulI = (7, (7, 8))
primOpPrec POEqI = (4, (5, 5))

-- | Print the symbol of an operator; in 'Prefix' position the symbol is
-- wrapped in parentheses.
prettyPrimOp :: Fixity -> PrimOp a b -> ShowS
prettyPrimOp fix op =
  let s = case op of
            POAddI -> "+"
            POMulI -> "*"
            POEqI -> "=="
  in showString $ case fix of
       Infix -> s
       Prefix -> "(" ++ s ++ ")"

-- | One layer of the example arithmetic language; @r@ is the type of
-- subexpressions.
data ArithF r t where
  A_Prim :: PrimOp a b -> r a -> ArithF r b
  A_Pair :: r a -> r b -> ArithF r (a, b)
  A_If :: r Bool -> r a -> r a -> ArithF r a
deriving instance (forall a. Show (r a)) => Show (ArithF r t)

-- Empty instance body: 'fmap1' comes from the class's default method, which
-- is defined in terms of 'traverse1'.
instance Functor1 ArithF
instance Traversable1 ArithF where
  traverse1 f (A_Prim op x) = A_Prim op <$> f x
  traverse1 f (A_Pair x y) = A_Pair <$> f x <*> f y
  traverse1 f (A_If x y z) = A_If <$> f x <*> f y <*> f z

-- | Layer printer for 'ArithF', suitable as the first argument of
-- 'prettyBExpr'. @pr@ prints a subexpression at a given precedence; @d@ is the
-- precedence of the surrounding context.
prettyArithF :: Monad m
             => (forall a. Int -> BExpr typ env ArithF a -> m ShowS)
             -> Int -> ArithF (BExpr typ env ArithF) t -> m ShowS
prettyArithF pr d = \case
  -- A primitive applied directly to a pair is printed infix.
  A_Prim op (BOp _ (A_Pair a b)) -> do
    let (dop, (dopL, dopR)) = primOpPrec op
    a' <- pr dopL a
    b' <- pr dopR b
    return $ showParen (d > dop) $ a' . showString " " . prettyPrimOp Infix op . showString " " . b'
  -- Any other argument: print as prefix application, with application
  -- precedence 10.
  A_Prim op arg -> do
    arg' <- pr 11 arg
    return $ showParen (d > 10) $ prettyPrimOp Prefix op . showString " " . arg'
  A_Pair a b -> do
    a' <- pr 0 a
    b' <- pr 0 b
    return $ showString "(" . a' . showString ", " . b' . showString ")"
  A_If a b c -> do
    a' <- pr 0 a
    b' <- pr 0 b
    c' <- pr 0 c
    return $ showParen (d > 0) $ showString "if " . a' . showString " then " . b' . showString " else " . c'

-- λx. x + x
ea_1 :: PHOASExpr Typ v ArithF (Int -> Int)
ea_1 =
  PHOASLam τ τ $ \arg ->
    PHOASOp τ (A_Prim POAddI
      (PHOASOp τ (A_Pair (PHOASVar τ arg) (PHOASVar τ arg))))

-- λx. let y = x + x in y * y
-- (the Haskell-level @let@ makes both uses of @y@ the same heap object, which
-- is the sharing that is to be recovered)
ea_2 :: PHOASExpr Typ v ArithF (Int -> Int)
ea_2 =
  PHOASLam τ τ $ \arg ->
    let y = PHOASOp τ (A_Prim POAddI
              (PHOASOp τ (A_Pair (PHOASVar τ arg) (PHOASVar τ arg))))
    in PHOASOp τ (A_Prim POMulI
         (PHOASOp τ (A_Pair y y)))

-- | Runs sharing recovery on 'ea_2' and prints the pretty-printed result.
main :: IO ()
main = putStrLn $ prettyBExpr prettyArithF (sharingRecovery ea_2)
