diff options
| -rw-r--r-- | src/Data/Expr/SharingRecovery.hs | 207 | 
1 files changed, 143 insertions, 64 deletions
diff --git a/src/Data/Expr/SharingRecovery.hs b/src/Data/Expr/SharingRecovery.hs index 118df1c..e386f4e 100644 --- a/src/Data/Expr/SharingRecovery.hs +++ b/src/Data/Expr/SharingRecovery.hs @@ -1,8 +1,10 @@ +{-# LANGUAGE DataKinds #-}  {-# LANGUAGE DeriveFunctor #-}  {-# LANGUAGE DerivingVia #-}  {-# LANGUAGE GADTs #-}  {-# LANGUAGE LambdaCase #-}  {-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-}  {-# LANGUAGE TypeOperators #-}  module Data.Expr.SharingRecovery where @@ -22,6 +24,15 @@ import Unsafe.Coerce (unsafeCoerce)  import Data.StableName.Extra +-- TODO: This is not yet done, see the bottom of this file + +-- TODO: This implementation needs extensive documentation. 1. It is written +-- quite generically, meaning that the actual algorithm is easily obscured to +-- all but the most willing readers; 2. the original paper leaves something to +-- be desired in the domain of effective explanation for programmers, and this +-- is a good opportunity to try to do better. + +  class Functor1 f where    fmap1 :: (forall b. g b -> h b) -> f g a -> f h a @@ -29,19 +40,19 @@ class Functor1 f => Traversable1 f where    traverse1 :: Applicative m => (forall b. g b -> m (h b)) -> f g a -> m (f h a)  -- | Expression in parametric higher-order abstract syntax form -data PHOASExpr v f t where -  PHOASOp :: f (PHOASExpr v f) t -> PHOASExpr v f t -  PHOASLam :: (PHOASExpr v f a -> PHOASExpr v f b) -> PHOASExpr v f (a -> b) -  PHOASVar :: v t -> PHOASExpr v f t +data PHOASExpr typ v f t where +  PHOASOp :: typ t -> f (PHOASExpr typ v f) t -> PHOASExpr typ v f t +  PHOASLam :: typ (a -> b) -> typ a -> (PHOASExpr typ v f a -> PHOASExpr typ v f b) -> PHOASExpr typ v f (a -> b) +  PHOASVar :: typ t -> v t -> PHOASExpr typ v f t  newtype Tag t = Tag Natural    deriving (Show, Eq) -newtype NameFor f t = NameFor (StableName (PHOASExpr Tag f t)) +newtype NameFor typ f t = NameFor (StableName (PHOASExpr typ Tag f t))    deriving (Eq) -  deriving (Hashable) via (StableName (f (PHOASExpr Tag f) t)) +  deriving (Hashable) via (StableName (PHOASExpr typ Tag f t)) -instance GEq (NameFor f) where +instance GEq (NameFor typ f) where    geq (NameFor n1) (NameFor n2)      | eqStableName n1 n2 = Just unsafeCoerceRefl      | otherwise = Nothing @@ -49,89 +60,110 @@ instance GEq (NameFor f) where        unsafeCoerceRefl :: a :~: b  -- restricted version of unsafeCoerce that only allows punting proofs        unsafeCoerceRefl = unsafeCoerce Refl --- | Pruned expression -data PExpr f t where -  PStub :: NameFor f t -> PExpr f t -  POp :: NameFor f t -> f (PExpr f) t -> PExpr f t -  PLam :: NameFor f (a -> b) -> Tag a -> PExpr f b -> PExpr f (a -> b) -  PVar :: Tag a -> PExpr f a +-- | Pruned expression. +-- +-- Note that variables do not, and will never, have a name: we don't bother +-- detecting sharing for variable references, because that would only introduce +-- a redundant variable indirection. +data PExpr typ f t where +  PStub :: NameFor typ f t -> typ t -> PExpr typ f t +  POp :: NameFor typ f t -> typ t -> f (PExpr typ f) t -> PExpr typ f t +  PLam :: NameFor typ f (a -> b) -> typ (a -> b) -> typ a -> Tag a -> PExpr typ f b -> PExpr typ f (a -> b) +  PVar :: typ a -> Tag a -> PExpr typ f a -data SomeNameFor f = forall t. SomeNameFor {-# UNPACK #-} !(NameFor f t) +data SomeNameFor typ f = forall t. SomeNameFor {-# UNPACK #-} !(NameFor typ f t) -instance Eq (SomeNameFor f) where +instance Eq (SomeNameFor typ f) where    SomeNameFor (NameFor n1) == SomeNameFor (NameFor n2) = eqStableName n1 n2 -instance Hashable (SomeNameFor f) where +instance Hashable (SomeNameFor typ f) where    hashWithSalt salt (SomeNameFor name) = hashWithSalt salt name -type OccMap f = HashMap (SomeNameFor f) Natural +-- | The number of times a particular name is visited in a preorder traversal +-- of the PHOAS expression, excluding children of nodes upon second or later +-- visit. That is to say: only the nodes that are visited in a preorder +-- traversal that skips repeated subtrees, are counted. +type OccMap typ f = HashMap (SomeNameFor typ f) Natural -pruneExpr :: Traversable1 f => (forall v. PHOASExpr v f t) -> (OccMap f, PExpr f t) +pruneExpr :: Traversable1 f => (forall v. PHOASExpr typ v f t) -> (OccMap typ f, PExpr typ f t)  pruneExpr term =    let (term', (_, mp)) = runState (pruneExpr' term) (0, mempty)    in (mp, term') -pruneExpr' :: Traversable1 f => PHOASExpr Tag f t -> State (Natural, OccMap f) (PExpr f t) -pruneExpr' orig@(PHOASOp args) = do -  let name = makeStableName' orig -  occmap <- gets snd -  let (seenBefore, occmap') = -        HM.alterF (\case Nothing -> (False, Just 1) -                         Just n -> (True, Just (n + 1))) -                  (SomeNameFor (NameFor name)) -                  occmap -  modify (second (const occmap')) -  if seenBefore -    then pure $ PStub (NameFor name) -    else POp (NameFor name) <$> traverse1 pruneExpr' args +pruneExpr' :: Traversable1 f => PHOASExpr typ Tag f t -> State (Natural, OccMap typ f) (PExpr typ f t) +pruneExpr' = \case +  orig@(PHOASOp ty args) -> do +    let name = makeStableName' orig +    seenBefore <- checkVisited name +    if seenBefore +      then pure $ PStub (NameFor name) ty +      else POp (NameFor name) ty <$> traverse1 pruneExpr' args -pruneExpr' orig@(PHOASLam f) = do -  let name = makeStableName' orig -  tag <- state (\(i, mp) -> (Tag i, (i + 1, mp))) -  let body = f (PHOASVar tag) -  PLam (NameFor name) tag <$> pruneExpr' body +  orig@(PHOASLam tyf tyarg f) -> do +    let name = makeStableName' orig +    seenBefore <- checkVisited name +    if seenBefore +      then pure $ PStub (NameFor name) tyf +      else do +        tag <- state (\(i, mp) -> (Tag i, (i + 1, mp))) +        let body = f (PHOASVar tyarg tag) +        PLam (NameFor name) tyf tyarg tag <$> pruneExpr' body -pruneExpr' (PHOASVar tag) = pure $ PVar tag +  PHOASVar ty tag -> pure $ PVar ty tag +  where +    checkVisited name = do +      occmap <- gets snd +      let (seenBefore, occmap') = +            HM.alterF (\case Nothing -> (False, Just 1) +                             Just n -> (True, Just (n + 1))) +                      (SomeNameFor (NameFor name)) +                      occmap +      modify (second (const occmap')) +      return seenBefore --- | Lifted expression: a bunch of to-be let bound expressions on top of an LExpr' -data LExpr f t = LExpr [Some (LExpr f)] (LExpr' f t) -data LExpr' f t where  -- TODO: this could be an instantiation of (a generalisation of) PExpr -  LStub :: NameFor f t -> LExpr' f t -  LOp :: NameFor f t -> f (LExpr f) t -> LExpr' f t -  LLam :: NameFor f (a -> b) -> Tag a -> LExpr f b -> LExpr' f (a -> b) -  LVar :: Tag a -> LExpr' f a +-- | Lifted expression: a bunch of to-be let bound expressions on top of an +-- LExpr'. Because LExpr' is really just PExpr with the recursive positions +-- replaced by LExpr, LExpr should be seen as PExpr with a bunch of to-be let +-- bound expressions on top of every node. +data LExpr typ f t = LExpr [Some (LExpr typ f)] (LExpr' typ f t) +data LExpr' typ f t where  -- TODO: this could be an instantiation of (a generalisation of) PExpr +  LStub :: NameFor typ f t -> typ t -> LExpr' typ f t +  LOp :: NameFor typ f t -> typ t -> f (LExpr typ f) t -> LExpr' typ f t +  LLam :: NameFor typ f (a -> b) -> typ (a -> b) -> typ a -> Tag a -> LExpr typ f b -> LExpr' typ f (a -> b) +  LVar :: typ a -> Tag a -> LExpr' typ f a -liftExpr :: Traversable1 f => OccMap f -> PExpr f t -> LExpr f t -liftExpr totals term = -  let (_, e) = liftExpr' totals term -  in e +liftExpr :: Traversable1 f => OccMap typ f -> PExpr typ f t -> LExpr typ f t +liftExpr totals term = snd (liftExpr' totals term) -newtype FoundMap f = FoundMap -  (HashMap (SomeNameFor f) (Natural  -- how many times seen -                           ,Maybe (Some (LExpr f))))  -- the lifted subterm (once seen) +newtype FoundMap typ f = FoundMap +  (HashMap (SomeNameFor typ f) (Natural  -- how many times seen +                               ,Maybe (Some (LExpr typ f))))  -- the lifted subterm (once seen) -instance Semigroup (FoundMap f) where +instance Semigroup (FoundMap typ f) where    FoundMap m1 <> FoundMap m2 = FoundMap $      HM.unionWith (\(n1, me1) (n2, me2) -> (n1 + n2, me1 <|> me2)) m1 m2 -instance Monoid (FoundMap f) where +instance Monoid (FoundMap typ f) where    mempty = FoundMap HM.empty -liftExpr' :: Traversable1 f => OccMap f -> PExpr f t -> (FoundMap f, LExpr f t) -liftExpr' _totals (PStub name) = -  (FoundMap $ HM.singleton (SomeNameFor name) (1, Just (Some (LExpr [] (LStub name)))) -  ,LExpr [] (LStub name)) +liftExpr' :: Traversable1 f => OccMap typ f -> PExpr typ f t -> (FoundMap typ f, LExpr typ f t) +liftExpr' _totals (PStub name ty) = +  (FoundMap $ HM.singleton (SomeNameFor name) (1, Nothing)  -- Just (Some (LExpr [] (LStub name))) +  ,LExpr [] (LStub name ty)) -liftExpr' _totals (PVar tag) = (mempty, LExpr [] (LVar tag)) +liftExpr' _totals (PVar ty tag) = (mempty, LExpr [] (LVar ty tag))  liftExpr' totals term = -  let (FoundMap foundmap, name, term') = case term of -        POp n args -> let (fm, args') = traverse1 (liftExpr' totals) args -                      in (fm, n, LOp n args') -        PLam n tag body -> let (fm, body') = liftExpr' totals body -                           in (fm, n, LLam n tag body') +  let (FoundMap foundmap, name, termty, term') = case term of +        POp n ty args -> +          let (fm, args') = traverse1 (liftExpr' totals) args +          in (fm, n, ty, LOp n ty args') +        PLam n tyf tyarg tag body -> +          let (fm, body') = liftExpr' totals body +          in (fm, n, tyf, LLam n tyf tyarg tag body') +      -- TODO: perhaps this HM.toList together with the foldr HM.delete can be a single traversal of the HashMap        saturated = [case mterm of                       Just t -> (nm, t)                       Nothing -> error "Name saturated but no term found" @@ -145,13 +177,60 @@ liftExpr' totals term =    in case HM.findWithDefault 0 (SomeNameFor name) totals of         1 -> (FoundMap foundmap', lterm)         tot | tot > 1 -> (FoundMap (HM.insert (SomeNameFor name) (1, Just (Some lterm)) foundmap') -                        ,LExpr [] (LStub name)) +                        ,LExpr [] (LStub name termty))             | otherwise -> error "Term does not exist, yet we have it in hand" +-- | Errors on stubs. +lexprTypeOf :: LExpr typ f t -> typ t +lexprTypeOf (LExpr _ e) = case e of +  LStub{} -> error "lexprTypeOf: got a stub" +  LOp _ t _ -> t +  LLam _ t _ _ _ -> t +  LVar t _ -> t + +  -- TODO: lower LExpr into a normal expression with let bindings. Every LStub  -- should correspond to some let-bound expression higher up in the tree (if it  -- does not, that's a bug), and should become a De Bruijn variable reference to  -- said let-bound expression. Lambdas should also get proper De Bruijn indices  -- instead of tags, and LVar is also a normal variable (referring to a  -- lambda-abstracted argument). + +-- | Untyped De Bruijn expression. No more names: there are lets now, and +-- variable references are De Bruijn indices. These indices are not type-safe +-- yet, though. +data UBExpr typ f t where +  UBOp :: typ t -> f (UBExpr typ f) t -> UBExpr typ f t +  UBLam :: typ a -> UBExpr typ f b -> UBExpr typ f (a -> b) +  UBLet :: typ a -> UBExpr typ f a -> UBExpr typ f b -> UBExpr typ f b +  -- De Bruijn index +  UBVar :: typ t -> Int -> UBExpr typ f t + +lowerExpr :: LExpr typ f t -> UBExpr typ f t +lowerExpr = lowerExpr' mempty 0 + +-- 1. name |-> De Bruijn level of the variable defining that name +-- 2. Number of variables already in scope +lowerExpr' :: forall typ f t. Traversable1 f => HashMap (SomeNameFor typ f) Int -> Int -> LExpr typ f t -> UBExpr typ f t +lowerExpr' namelvl curlvl (LExpr lifted ex) = +  let prefix = buildPrefix curlvl lifted +      curlvl' = curlvl + length lifted +  in case ex of +       LStub name ty -> +         case HM.lookup (SomeNameFor name) namelvl of +           Just lvl -> UBVar ty (curlvl - lvl - 1) +           Nothing -> error "Variable out of scope" +       LOp name ty args -> +         UBOp ty (_ $ traverse1 _ args) +  where +    buildPrefix :: forall b. Int -> [Some (LExpr typ f)] -> UBExpr typ f b -> UBExpr typ f b +    buildPrefix _ [] = id +    buildPrefix lvl (Some rhs : rhss) = +      UBLet (lexprTypeOf rhs) (lowerExpr' namelvl lvl rhs) +      . buildPrefix (lvl + 1) rhss + + +data Idx env t where +  IZ :: Idx (t : env) t +  IS :: Idx env t -> Idx (s : env) t  | 
