From a6f925deaa2044a0fe18a74fc52703e00f111056 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Tue, 27 Aug 2024 23:00:13 +0200 Subject: WIP lower to de bruijn indices --- src/Data/Expr/SharingRecovery.hs | 223 ++++++++++++++++++++++++++------------- 1 file changed, 151 insertions(+), 72 deletions(-) (limited to 'src/Data/Expr/SharingRecovery.hs') diff --git a/src/Data/Expr/SharingRecovery.hs b/src/Data/Expr/SharingRecovery.hs index 118df1c..e386f4e 100644 --- a/src/Data/Expr/SharingRecovery.hs +++ b/src/Data/Expr/SharingRecovery.hs @@ -1,8 +1,10 @@ +{-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE DerivingVia #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeOperators #-} module Data.Expr.SharingRecovery where @@ -22,6 +24,15 @@ import Unsafe.Coerce (unsafeCoerce) import Data.StableName.Extra +-- TODO: This is not yet done, see the bottom of this file + +-- TODO: This implementation needs extensive documentation. 1. It is written +-- quite generically, meaning that the actual algorithm is easily obscured to +-- all but the most willing readers; 2. the original paper leaves something to +-- be desired in the domain of effective explanation for programmers, and this +-- is a good opportunity to try to do better. + + class Functor1 f where fmap1 :: (forall b. g b -> h b) -> f g a -> f h a @@ -29,19 +40,19 @@ class Functor1 f => Traversable1 f where traverse1 :: Applicative m => (forall b. g b -> m (h b)) -> f g a -> m (f h a) -- | Expression in parametric higher-order abstract syntax form -data PHOASExpr v f t where - PHOASOp :: f (PHOASExpr v f) t -> PHOASExpr v f t - PHOASLam :: (PHOASExpr v f a -> PHOASExpr v f b) -> PHOASExpr v f (a -> b) - PHOASVar :: v t -> PHOASExpr v f t +data PHOASExpr typ v f t where + PHOASOp :: typ t -> f (PHOASExpr typ v f) t -> PHOASExpr typ v f t + PHOASLam :: typ (a -> b) -> typ a -> (PHOASExpr typ v f a -> PHOASExpr typ v f b) -> PHOASExpr typ v f (a -> b) + PHOASVar :: typ t -> v t -> PHOASExpr typ v f t newtype Tag t = Tag Natural deriving (Show, Eq) -newtype NameFor f t = NameFor (StableName (PHOASExpr Tag f t)) +newtype NameFor typ f t = NameFor (StableName (PHOASExpr typ Tag f t)) deriving (Eq) - deriving (Hashable) via (StableName (f (PHOASExpr Tag f) t)) + deriving (Hashable) via (StableName (PHOASExpr typ Tag f t)) -instance GEq (NameFor f) where +instance GEq (NameFor typ f) where geq (NameFor n1) (NameFor n2) | eqStableName n1 n2 = Just unsafeCoerceRefl | otherwise = Nothing @@ -49,89 +60,110 @@ instance GEq (NameFor f) where unsafeCoerceRefl :: a :~: b -- restricted version of unsafeCoerce that only allows punting proofs unsafeCoerceRefl = unsafeCoerce Refl --- | Pruned expression -data PExpr f t where - PStub :: NameFor f t -> PExpr f t - POp :: NameFor f t -> f (PExpr f) t -> PExpr f t - PLam :: NameFor f (a -> b) -> Tag a -> PExpr f b -> PExpr f (a -> b) - PVar :: Tag a -> PExpr f a +-- | Pruned expression. +-- +-- Note that variables do not, and will never, have a name: we don't bother +-- detecting sharing for variable references, because that would only introduce +-- a redundant variable indirection. +data PExpr typ f t where + PStub :: NameFor typ f t -> typ t -> PExpr typ f t + POp :: NameFor typ f t -> typ t -> f (PExpr typ f) t -> PExpr typ f t + PLam :: NameFor typ f (a -> b) -> typ (a -> b) -> typ a -> Tag a -> PExpr typ f b -> PExpr typ f (a -> b) + PVar :: typ a -> Tag a -> PExpr typ f a -data SomeNameFor f = forall t. SomeNameFor {-# UNPACK #-} !(NameFor f t) +data SomeNameFor typ f = forall t. SomeNameFor {-# UNPACK #-} !(NameFor typ f t) -instance Eq (SomeNameFor f) where +instance Eq (SomeNameFor typ f) where SomeNameFor (NameFor n1) == SomeNameFor (NameFor n2) = eqStableName n1 n2 -instance Hashable (SomeNameFor f) where +instance Hashable (SomeNameFor typ f) where hashWithSalt salt (SomeNameFor name) = hashWithSalt salt name -type OccMap f = HashMap (SomeNameFor f) Natural +-- | The number of times a particular name is visited in a preorder traversal +-- of the PHOAS expression, excluding children of nodes upon second or later +-- visit. That is to say: only the nodes that are visited in a preorder +-- traversal that skips repeated subtrees, are counted. +type OccMap typ f = HashMap (SomeNameFor typ f) Natural -pruneExpr :: Traversable1 f => (forall v. PHOASExpr v f t) -> (OccMap f, PExpr f t) +pruneExpr :: Traversable1 f => (forall v. PHOASExpr typ v f t) -> (OccMap typ f, PExpr typ f t) pruneExpr term = let (term', (_, mp)) = runState (pruneExpr' term) (0, mempty) in (mp, term') -pruneExpr' :: Traversable1 f => PHOASExpr Tag f t -> State (Natural, OccMap f) (PExpr f t) -pruneExpr' orig@(PHOASOp args) = do - let name = makeStableName' orig - occmap <- gets snd - let (seenBefore, occmap') = - HM.alterF (\case Nothing -> (False, Just 1) - Just n -> (True, Just (n + 1))) - (SomeNameFor (NameFor name)) - occmap - modify (second (const occmap')) - if seenBefore - then pure $ PStub (NameFor name) - else POp (NameFor name) <$> traverse1 pruneExpr' args - -pruneExpr' orig@(PHOASLam f) = do - let name = makeStableName' orig - tag <- state (\(i, mp) -> (Tag i, (i + 1, mp))) - let body = f (PHOASVar tag) - PLam (NameFor name) tag <$> pruneExpr' body - -pruneExpr' (PHOASVar tag) = pure $ PVar tag - - --- | Lifted expression: a bunch of to-be let bound expressions on top of an LExpr' -data LExpr f t = LExpr [Some (LExpr f)] (LExpr' f t) -data LExpr' f t where -- TODO: this could be an instantiation of (a generalisation of) PExpr - LStub :: NameFor f t -> LExpr' f t - LOp :: NameFor f t -> f (LExpr f) t -> LExpr' f t - LLam :: NameFor f (a -> b) -> Tag a -> LExpr f b -> LExpr' f (a -> b) - LVar :: Tag a -> LExpr' f a - -liftExpr :: Traversable1 f => OccMap f -> PExpr f t -> LExpr f t -liftExpr totals term = - let (_, e) = liftExpr' totals term - in e - -newtype FoundMap f = FoundMap - (HashMap (SomeNameFor f) (Natural -- how many times seen - ,Maybe (Some (LExpr f)))) -- the lifted subterm (once seen) - -instance Semigroup (FoundMap f) where +pruneExpr' :: Traversable1 f => PHOASExpr typ Tag f t -> State (Natural, OccMap typ f) (PExpr typ f t) +pruneExpr' = \case + orig@(PHOASOp ty args) -> do + let name = makeStableName' orig + seenBefore <- checkVisited name + if seenBefore + then pure $ PStub (NameFor name) ty + else POp (NameFor name) ty <$> traverse1 pruneExpr' args + + orig@(PHOASLam tyf tyarg f) -> do + let name = makeStableName' orig + seenBefore <- checkVisited name + if seenBefore + then pure $ PStub (NameFor name) tyf + else do + tag <- state (\(i, mp) -> (Tag i, (i + 1, mp))) + let body = f (PHOASVar tyarg tag) + PLam (NameFor name) tyf tyarg tag <$> pruneExpr' body + + PHOASVar ty tag -> pure $ PVar ty tag + where + checkVisited name = do + occmap <- gets snd + let (seenBefore, occmap') = + HM.alterF (\case Nothing -> (False, Just 1) + Just n -> (True, Just (n + 1))) + (SomeNameFor (NameFor name)) + occmap + modify (second (const occmap')) + return seenBefore + + +-- | Lifted expression: a bunch of to-be let bound expressions on top of an +-- LExpr'. Because LExpr' is really just PExpr with the recursive positions +-- replaced by LExpr, LExpr should be seen as PExpr with a bunch of to-be let +-- bound expressions on top of every node. +data LExpr typ f t = LExpr [Some (LExpr typ f)] (LExpr' typ f t) +data LExpr' typ f t where -- TODO: this could be an instantiation of (a generalisation of) PExpr + LStub :: NameFor typ f t -> typ t -> LExpr' typ f t + LOp :: NameFor typ f t -> typ t -> f (LExpr typ f) t -> LExpr' typ f t + LLam :: NameFor typ f (a -> b) -> typ (a -> b) -> typ a -> Tag a -> LExpr typ f b -> LExpr' typ f (a -> b) + LVar :: typ a -> Tag a -> LExpr' typ f a + +liftExpr :: Traversable1 f => OccMap typ f -> PExpr typ f t -> LExpr typ f t +liftExpr totals term = snd (liftExpr' totals term) + +newtype FoundMap typ f = FoundMap + (HashMap (SomeNameFor typ f) (Natural -- how many times seen + ,Maybe (Some (LExpr typ f)))) -- the lifted subterm (once seen) + +instance Semigroup (FoundMap typ f) where FoundMap m1 <> FoundMap m2 = FoundMap $ HM.unionWith (\(n1, me1) (n2, me2) -> (n1 + n2, me1 <|> me2)) m1 m2 -instance Monoid (FoundMap f) where +instance Monoid (FoundMap typ f) where mempty = FoundMap HM.empty -liftExpr' :: Traversable1 f => OccMap f -> PExpr f t -> (FoundMap f, LExpr f t) -liftExpr' _totals (PStub name) = - (FoundMap $ HM.singleton (SomeNameFor name) (1, Just (Some (LExpr [] (LStub name)))) - ,LExpr [] (LStub name)) +liftExpr' :: Traversable1 f => OccMap typ f -> PExpr typ f t -> (FoundMap typ f, LExpr typ f t) +liftExpr' _totals (PStub name ty) = + (FoundMap $ HM.singleton (SomeNameFor name) (1, Nothing) -- Just (Some (LExpr [] (LStub name))) + ,LExpr [] (LStub name ty)) -liftExpr' _totals (PVar tag) = (mempty, LExpr [] (LVar tag)) +liftExpr' _totals (PVar ty tag) = (mempty, LExpr [] (LVar ty tag)) liftExpr' totals term = - let (FoundMap foundmap, name, term') = case term of - POp n args -> let (fm, args') = traverse1 (liftExpr' totals) args - in (fm, n, LOp n args') - PLam n tag body -> let (fm, body') = liftExpr' totals body - in (fm, n, LLam n tag body') - + let (FoundMap foundmap, name, termty, term') = case term of + POp n ty args -> + let (fm, args') = traverse1 (liftExpr' totals) args + in (fm, n, ty, LOp n ty args') + PLam n tyf tyarg tag body -> + let (fm, body') = liftExpr' totals body + in (fm, n, tyf, LLam n tyf tyarg tag body') + + -- TODO: perhaps this HM.toList together with the foldr HM.delete can be a single traversal of the HashMap saturated = [case mterm of Just t -> (nm, t) Nothing -> error "Name saturated but no term found" @@ -145,13 +177,60 @@ liftExpr' totals term = in case HM.findWithDefault 0 (SomeNameFor name) totals of 1 -> (FoundMap foundmap', lterm) tot | tot > 1 -> (FoundMap (HM.insert (SomeNameFor name) (1, Just (Some lterm)) foundmap') - ,LExpr [] (LStub name)) + ,LExpr [] (LStub name termty)) | otherwise -> error "Term does not exist, yet we have it in hand" +-- | Errors on stubs. +lexprTypeOf :: LExpr typ f t -> typ t +lexprTypeOf (LExpr _ e) = case e of + LStub{} -> error "lexprTypeOf: got a stub" + LOp _ t _ -> t + LLam _ t _ _ _ -> t + LVar t _ -> t + + -- TODO: lower LExpr into a normal expression with let bindings. Every LStub -- should correspond to some let-bound expression higher up in the tree (if it -- does not, that's a bug), and should become a De Bruijn variable reference to -- said let-bound expression. Lambdas should also get proper De Bruijn indices -- instead of tags, and LVar is also a normal variable (referring to a -- lambda-abstracted argument). + +-- | Untyped De Bruijn expression. No more names: there are lets now, and +-- variable references are De Bruijn indices. These indices are not type-safe +-- yet, though. +data UBExpr typ f t where + UBOp :: typ t -> f (UBExpr typ f) t -> UBExpr typ f t + UBLam :: typ a -> UBExpr typ f b -> UBExpr typ f (a -> b) + UBLet :: typ a -> UBExpr typ f a -> UBExpr typ f b -> UBExpr typ f b + -- De Bruijn index + UBVar :: typ t -> Int -> UBExpr typ f t + +lowerExpr :: LExpr typ f t -> UBExpr typ f t +lowerExpr = lowerExpr' mempty 0 + +-- 1. name |-> De Bruijn level of the variable defining that name +-- 2. Number of variables already in scope +lowerExpr' :: forall typ f t. Traversable1 f => HashMap (SomeNameFor typ f) Int -> Int -> LExpr typ f t -> UBExpr typ f t +lowerExpr' namelvl curlvl (LExpr lifted ex) = + let prefix = buildPrefix curlvl lifted + curlvl' = curlvl + length lifted + in case ex of + LStub name ty -> + case HM.lookup (SomeNameFor name) namelvl of + Just lvl -> UBVar ty (curlvl - lvl - 1) + Nothing -> error "Variable out of scope" + LOp name ty args -> + UBOp ty (_ $ traverse1 _ args) + where + buildPrefix :: forall b. Int -> [Some (LExpr typ f)] -> UBExpr typ f b -> UBExpr typ f b + buildPrefix _ [] = id + buildPrefix lvl (Some rhs : rhss) = + UBLet (lexprTypeOf rhs) (lowerExpr' namelvl lvl rhs) + . buildPrefix (lvl + 1) rhss + + +data Idx env t where + IZ :: Idx (t : env) t + IS :: Idx env t -> Idx (s : env) t -- cgit v1.2.3-70-g09d2