aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Expr
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-08-27 23:00:13 +0200
committerTom Smeding <tom@tomsmeding.com>2024-08-27 23:00:13 +0200
commita6f925deaa2044a0fe18a74fc52703e00f111056 (patch)
tree510703e2ec030a24a668ab50dd7abc90062e95db /src/Data/Expr
parent250e3beae7a961fc740f775a563c303b4cc390fe (diff)
WIP lower to de bruijn indices
Diffstat (limited to 'src/Data/Expr')
-rw-r--r--src/Data/Expr/SharingRecovery.hs223
1 files changed, 151 insertions, 72 deletions
diff --git a/src/Data/Expr/SharingRecovery.hs b/src/Data/Expr/SharingRecovery.hs
index 118df1c..e386f4e 100644
--- a/src/Data/Expr/SharingRecovery.hs
+++ b/src/Data/Expr/SharingRecovery.hs
@@ -1,8 +1,10 @@
+{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
module Data.Expr.SharingRecovery where
@@ -22,6 +24,15 @@ import Unsafe.Coerce (unsafeCoerce)
import Data.StableName.Extra
+-- TODO: This is not yet done, see the bottom of this file
+
+-- TODO: This implementation needs extensive documentation. 1. It is written
+-- quite generically, meaning that the actual algorithm is easily obscured to
+-- all but the most willing readers; 2. the original paper leaves something to
+-- be desired in the domain of effective explanation for programmers, and this
+-- is a good opportunity to try to do better.
+
+
class Functor1 f where
fmap1 :: (forall b. g b -> h b) -> f g a -> f h a
@@ -29,19 +40,19 @@ class Functor1 f => Traversable1 f where
traverse1 :: Applicative m => (forall b. g b -> m (h b)) -> f g a -> m (f h a)
-- | Expression in parametric higher-order abstract syntax form
-data PHOASExpr v f t where
- PHOASOp :: f (PHOASExpr v f) t -> PHOASExpr v f t
- PHOASLam :: (PHOASExpr v f a -> PHOASExpr v f b) -> PHOASExpr v f (a -> b)
- PHOASVar :: v t -> PHOASExpr v f t
+data PHOASExpr typ v f t where
+ PHOASOp :: typ t -> f (PHOASExpr typ v f) t -> PHOASExpr typ v f t
+ PHOASLam :: typ (a -> b) -> typ a -> (PHOASExpr typ v f a -> PHOASExpr typ v f b) -> PHOASExpr typ v f (a -> b)
+ PHOASVar :: typ t -> v t -> PHOASExpr typ v f t
newtype Tag t = Tag Natural
deriving (Show, Eq)
-newtype NameFor f t = NameFor (StableName (PHOASExpr Tag f t))
+newtype NameFor typ f t = NameFor (StableName (PHOASExpr typ Tag f t))
deriving (Eq)
- deriving (Hashable) via (StableName (f (PHOASExpr Tag f) t))
+ deriving (Hashable) via (StableName (PHOASExpr typ Tag f t))
-instance GEq (NameFor f) where
+instance GEq (NameFor typ f) where
geq (NameFor n1) (NameFor n2)
| eqStableName n1 n2 = Just unsafeCoerceRefl
| otherwise = Nothing
@@ -49,89 +60,110 @@ instance GEq (NameFor f) where
unsafeCoerceRefl :: a :~: b -- restricted version of unsafeCoerce that only allows punting proofs
unsafeCoerceRefl = unsafeCoerce Refl
--- | Pruned expression
-data PExpr f t where
- PStub :: NameFor f t -> PExpr f t
- POp :: NameFor f t -> f (PExpr f) t -> PExpr f t
- PLam :: NameFor f (a -> b) -> Tag a -> PExpr f b -> PExpr f (a -> b)
- PVar :: Tag a -> PExpr f a
+-- | Pruned expression.
+--
+-- Note that variables do not, and will never, have a name: we don't bother
+-- detecting sharing for variable references, because that would only introduce
+-- a redundant variable indirection.
+data PExpr typ f t where
+ PStub :: NameFor typ f t -> typ t -> PExpr typ f t
+ POp :: NameFor typ f t -> typ t -> f (PExpr typ f) t -> PExpr typ f t
+ PLam :: NameFor typ f (a -> b) -> typ (a -> b) -> typ a -> Tag a -> PExpr typ f b -> PExpr typ f (a -> b)
+ PVar :: typ a -> Tag a -> PExpr typ f a
-data SomeNameFor f = forall t. SomeNameFor {-# UNPACK #-} !(NameFor f t)
+data SomeNameFor typ f = forall t. SomeNameFor {-# UNPACK #-} !(NameFor typ f t)
-instance Eq (SomeNameFor f) where
+instance Eq (SomeNameFor typ f) where
SomeNameFor (NameFor n1) == SomeNameFor (NameFor n2) = eqStableName n1 n2
-instance Hashable (SomeNameFor f) where
+instance Hashable (SomeNameFor typ f) where
hashWithSalt salt (SomeNameFor name) = hashWithSalt salt name
-type OccMap f = HashMap (SomeNameFor f) Natural
+-- | The number of times a particular name is visited in a preorder traversal
+-- of the PHOAS expression, excluding children of nodes upon second or later
+-- visit. That is to say: only the nodes that are visited in a preorder
+-- traversal that skips repeated subtrees, are counted.
+type OccMap typ f = HashMap (SomeNameFor typ f) Natural
-pruneExpr :: Traversable1 f => (forall v. PHOASExpr v f t) -> (OccMap f, PExpr f t)
+pruneExpr :: Traversable1 f => (forall v. PHOASExpr typ v f t) -> (OccMap typ f, PExpr typ f t)
pruneExpr term =
let (term', (_, mp)) = runState (pruneExpr' term) (0, mempty)
in (mp, term')
-pruneExpr' :: Traversable1 f => PHOASExpr Tag f t -> State (Natural, OccMap f) (PExpr f t)
-pruneExpr' orig@(PHOASOp args) = do
- let name = makeStableName' orig
- occmap <- gets snd
- let (seenBefore, occmap') =
- HM.alterF (\case Nothing -> (False, Just 1)
- Just n -> (True, Just (n + 1)))
- (SomeNameFor (NameFor name))
- occmap
- modify (second (const occmap'))
- if seenBefore
- then pure $ PStub (NameFor name)
- else POp (NameFor name) <$> traverse1 pruneExpr' args
-
-pruneExpr' orig@(PHOASLam f) = do
- let name = makeStableName' orig
- tag <- state (\(i, mp) -> (Tag i, (i + 1, mp)))
- let body = f (PHOASVar tag)
- PLam (NameFor name) tag <$> pruneExpr' body
-
-pruneExpr' (PHOASVar tag) = pure $ PVar tag
-
-
--- | Lifted expression: a bunch of to-be let bound expressions on top of an LExpr'
-data LExpr f t = LExpr [Some (LExpr f)] (LExpr' f t)
-data LExpr' f t where -- TODO: this could be an instantiation of (a generalisation of) PExpr
- LStub :: NameFor f t -> LExpr' f t
- LOp :: NameFor f t -> f (LExpr f) t -> LExpr' f t
- LLam :: NameFor f (a -> b) -> Tag a -> LExpr f b -> LExpr' f (a -> b)
- LVar :: Tag a -> LExpr' f a
-
-liftExpr :: Traversable1 f => OccMap f -> PExpr f t -> LExpr f t
-liftExpr totals term =
- let (_, e) = liftExpr' totals term
- in e
-
-newtype FoundMap f = FoundMap
- (HashMap (SomeNameFor f) (Natural -- how many times seen
- ,Maybe (Some (LExpr f)))) -- the lifted subterm (once seen)
-
-instance Semigroup (FoundMap f) where
+pruneExpr' :: Traversable1 f => PHOASExpr typ Tag f t -> State (Natural, OccMap typ f) (PExpr typ f t)
+pruneExpr' = \case
+ orig@(PHOASOp ty args) -> do
+ let name = makeStableName' orig
+ seenBefore <- checkVisited name
+ if seenBefore
+ then pure $ PStub (NameFor name) ty
+ else POp (NameFor name) ty <$> traverse1 pruneExpr' args
+
+ orig@(PHOASLam tyf tyarg f) -> do
+ let name = makeStableName' orig
+ seenBefore <- checkVisited name
+ if seenBefore
+ then pure $ PStub (NameFor name) tyf
+ else do
+ tag <- state (\(i, mp) -> (Tag i, (i + 1, mp)))
+ let body = f (PHOASVar tyarg tag)
+ PLam (NameFor name) tyf tyarg tag <$> pruneExpr' body
+
+ PHOASVar ty tag -> pure $ PVar ty tag
+ where
+ checkVisited name = do
+ occmap <- gets snd
+ let (seenBefore, occmap') =
+ HM.alterF (\case Nothing -> (False, Just 1)
+ Just n -> (True, Just (n + 1)))
+ (SomeNameFor (NameFor name))
+ occmap
+ modify (second (const occmap'))
+ return seenBefore
+
+
+-- | Lifted expression: a bunch of to-be let bound expressions on top of an
+-- LExpr'. Because LExpr' is really just PExpr with the recursive positions
+-- replaced by LExpr, LExpr should be seen as PExpr with a bunch of to-be let
+-- bound expressions on top of every node.
+data LExpr typ f t = LExpr [Some (LExpr typ f)] (LExpr' typ f t)
+data LExpr' typ f t where -- TODO: this could be an instantiation of (a generalisation of) PExpr
+ LStub :: NameFor typ f t -> typ t -> LExpr' typ f t
+ LOp :: NameFor typ f t -> typ t -> f (LExpr typ f) t -> LExpr' typ f t
+ LLam :: NameFor typ f (a -> b) -> typ (a -> b) -> typ a -> Tag a -> LExpr typ f b -> LExpr' typ f (a -> b)
+ LVar :: typ a -> Tag a -> LExpr' typ f a
+
+liftExpr :: Traversable1 f => OccMap typ f -> PExpr typ f t -> LExpr typ f t
+liftExpr totals term = snd (liftExpr' totals term)
+
+newtype FoundMap typ f = FoundMap
+ (HashMap (SomeNameFor typ f) (Natural -- how many times seen
+ ,Maybe (Some (LExpr typ f)))) -- the lifted subterm (once seen)
+
+instance Semigroup (FoundMap typ f) where
FoundMap m1 <> FoundMap m2 = FoundMap $
HM.unionWith (\(n1, me1) (n2, me2) -> (n1 + n2, me1 <|> me2)) m1 m2
-instance Monoid (FoundMap f) where
+instance Monoid (FoundMap typ f) where
mempty = FoundMap HM.empty
-liftExpr' :: Traversable1 f => OccMap f -> PExpr f t -> (FoundMap f, LExpr f t)
-liftExpr' _totals (PStub name) =
- (FoundMap $ HM.singleton (SomeNameFor name) (1, Just (Some (LExpr [] (LStub name))))
- ,LExpr [] (LStub name))
+liftExpr' :: Traversable1 f => OccMap typ f -> PExpr typ f t -> (FoundMap typ f, LExpr typ f t)
+liftExpr' _totals (PStub name ty) =
+ (FoundMap $ HM.singleton (SomeNameFor name) (1, Nothing) -- Just (Some (LExpr [] (LStub name)))
+ ,LExpr [] (LStub name ty))
-liftExpr' _totals (PVar tag) = (mempty, LExpr [] (LVar tag))
+liftExpr' _totals (PVar ty tag) = (mempty, LExpr [] (LVar ty tag))
liftExpr' totals term =
- let (FoundMap foundmap, name, term') = case term of
- POp n args -> let (fm, args') = traverse1 (liftExpr' totals) args
- in (fm, n, LOp n args')
- PLam n tag body -> let (fm, body') = liftExpr' totals body
- in (fm, n, LLam n tag body')
-
+ let (FoundMap foundmap, name, termty, term') = case term of
+ POp n ty args ->
+ let (fm, args') = traverse1 (liftExpr' totals) args
+ in (fm, n, ty, LOp n ty args')
+ PLam n tyf tyarg tag body ->
+ let (fm, body') = liftExpr' totals body
+ in (fm, n, tyf, LLam n tyf tyarg tag body')
+
+ -- TODO: perhaps this HM.toList together with the foldr HM.delete can be a single traversal of the HashMap
saturated = [case mterm of
Just t -> (nm, t)
Nothing -> error "Name saturated but no term found"
@@ -145,13 +177,60 @@ liftExpr' totals term =
in case HM.findWithDefault 0 (SomeNameFor name) totals of
1 -> (FoundMap foundmap', lterm)
tot | tot > 1 -> (FoundMap (HM.insert (SomeNameFor name) (1, Just (Some lterm)) foundmap')
- ,LExpr [] (LStub name))
+ ,LExpr [] (LStub name termty))
| otherwise -> error "Term does not exist, yet we have it in hand"
+-- | Errors on stubs.
+lexprTypeOf :: LExpr typ f t -> typ t
+lexprTypeOf (LExpr _ e) = case e of
+ LStub{} -> error "lexprTypeOf: got a stub"
+ LOp _ t _ -> t
+ LLam _ t _ _ _ -> t
+ LVar t _ -> t
+
+
-- TODO: lower LExpr into a normal expression with let bindings. Every LStub
-- should correspond to some let-bound expression higher up in the tree (if it
-- does not, that's a bug), and should become a De Bruijn variable reference to
-- said let-bound expression. Lambdas should also get proper De Bruijn indices
-- instead of tags, and LVar is also a normal variable (referring to a
-- lambda-abstracted argument).
+
+-- | Untyped De Bruijn expression. No more names: there are lets now, and
+-- variable references are De Bruijn indices. These indices are not type-safe
+-- yet, though.
+data UBExpr typ f t where
+ UBOp :: typ t -> f (UBExpr typ f) t -> UBExpr typ f t
+ UBLam :: typ a -> UBExpr typ f b -> UBExpr typ f (a -> b)
+ UBLet :: typ a -> UBExpr typ f a -> UBExpr typ f b -> UBExpr typ f b
+ -- De Bruijn index
+ UBVar :: typ t -> Int -> UBExpr typ f t
+
+lowerExpr :: LExpr typ f t -> UBExpr typ f t
+lowerExpr = lowerExpr' mempty 0
+
+-- 1. name |-> De Bruijn level of the variable defining that name
+-- 2. Number of variables already in scope
+lowerExpr' :: forall typ f t. Traversable1 f => HashMap (SomeNameFor typ f) Int -> Int -> LExpr typ f t -> UBExpr typ f t
+lowerExpr' namelvl curlvl (LExpr lifted ex) =
+ let prefix = buildPrefix curlvl lifted
+ curlvl' = curlvl + length lifted
+ in case ex of
+ LStub name ty ->
+ case HM.lookup (SomeNameFor name) namelvl of
+ Just lvl -> UBVar ty (curlvl - lvl - 1)
+ Nothing -> error "Variable out of scope"
+ LOp name ty args ->
+ UBOp ty (_ $ traverse1 _ args)
+ where
+ buildPrefix :: forall b. Int -> [Some (LExpr typ f)] -> UBExpr typ f b -> UBExpr typ f b
+ buildPrefix _ [] = id
+ buildPrefix lvl (Some rhs : rhss) =
+ UBLet (lexprTypeOf rhs) (lowerExpr' namelvl lvl rhs)
+ . buildPrefix (lvl + 1) rhss
+
+
+data Idx env t where
+ IZ :: Idx (t : env) t
+ IS :: Idx env t -> Idx (s : env) t