From 912d262c8aef92657b8991d05b7fe39dcb5b5fd4 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Wed, 28 Aug 2024 16:41:13 +0200 Subject: Move code to .Internal module, and some haddocs --- src/Data/Expr/SharingRecovery.hs | 428 +-------------------------- src/Data/Expr/SharingRecovery/Internal.hs | 475 ++++++++++++++++++++++++++++++ 2 files changed, 487 insertions(+), 416 deletions(-) create mode 100644 src/Data/Expr/SharingRecovery/Internal.hs (limited to 'src/Data/Expr') diff --git a/src/Data/Expr/SharingRecovery.hs b/src/Data/Expr/SharingRecovery.hs index 11a4709..02b3e3e 100644 --- a/src/Data/Expr/SharingRecovery.hs +++ b/src/Data/Expr/SharingRecovery.hs @@ -1,419 +1,15 @@ -{-# LANGUAGE BangPatterns #-} -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DefaultSignatures #-} -{-# LANGUAGE DerivingVia #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE QuantifiedConstraints #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TypeOperators #-} -module Data.Expr.SharingRecovery where +module Data.Expr.SharingRecovery ( + -- * Sharing recovery + sharingRecovery, -import Control.Applicative ((<|>)) -import Control.Monad.Trans.State.Strict -import Data.Bifunctor (first, second) -import Data.Char (chr, ord) -import Data.Functor.Const -import Data.Functor.Identity -import Data.Functor.Product -import Data.Hashable -import Data.HashMap.Strict (HashMap) -import qualified Data.HashMap.Strict as HM -import Data.List (sortBy, intersperse) -import Data.Maybe (fromMaybe) -import Data.Ord (comparing) -import Data.Some -import Data.Type.Equality -import GHC.StableName -import Numeric.Natural -import Unsafe.Coerce (unsafeCoerce) + -- * Expressions + PHOASExpr(..), + BExpr(..), + Idx(..), --- import Debug.Trace + -- * Traversing indexed structures + Functor1(..), + Traversable1(..), +) where -import Data.StableName.Extra - - --- TODO: This implementation needs extensive documentation. 1. 
It is written --- quite generically, meaning that the actual algorithm is easily obscured to --- all but the most willing readers; 2. the original paper leaves something to --- be desired in the domain of effective explanation for programmers, and this --- is a good opportunity to try to do better. - - -withMoreState :: Functor m => b -> StateT (s, b) m a -> StateT s m (a, b) -withMoreState b0 (StateT f) = - StateT $ \s -> (\(x, (s2, b)) -> ((x, b), s2)) <$> f (s, b0) - -withLessState :: Functor m => (s -> (s', b)) -> (s' -> b -> s) -> StateT s' m a -> StateT s m a -withLessState split restore (StateT f) = - StateT $ \s -> let (s', b) = split s - in second (flip restore b) <$> f s' - - -class Functor1 f where - fmap1 :: (forall b. g b -> h b) -> f g a -> f h a - - default fmap1 :: Traversable1 f => (forall b. g b -> h b) -> f g a -> f h a - fmap1 f x = runIdentity (traverse1 (Identity . f) x) - -class Functor1 f => Traversable1 f where - traverse1 :: Applicative m => (forall b. g b -> m (h b)) -> f g a -> m (f h a) - --- | Expression in parametric higher-order abstract syntax form -data PHOASExpr typ v f t where - PHOASOp :: typ t -> f (PHOASExpr typ v f) t -> PHOASExpr typ v f t - PHOASLam :: typ (a -> b) -> typ a -> (v a -> PHOASExpr typ v f b) -> PHOASExpr typ v f (a -> b) - PHOASVar :: typ t -> v t -> PHOASExpr typ v f t - -newtype Tag t = Tag Natural - deriving (Show, Eq) - deriving (Hashable) via Natural - -newtype NameFor typ f t = NameFor (StableName (PHOASExpr typ Tag f t)) - deriving (Eq) - deriving (Hashable) via (StableName (PHOASExpr typ Tag f t)) - -instance TestEquality (NameFor typ f) where - testEquality (NameFor n1) (NameFor n2) - | eqStableName n1 n2 = Just unsafeCoerceRefl - | otherwise = Nothing - where - unsafeCoerceRefl :: a :~: b -- restricted version of unsafeCoerce that only allows punting proofs - unsafeCoerceRefl = unsafeCoerce Refl - --- | Pruned expression. 
--- --- Note that variables do not, and will never, have a name: we don't bother --- detecting sharing for variable references, because that would only introduce --- a redundant variable indirection. -data PExpr typ f t where - PStub :: NameFor typ f t -> typ t -> PExpr typ f t - POp :: NameFor typ f t -> typ t -> f (PExpr typ f) t -> PExpr typ f t - PLam :: NameFor typ f (a -> b) -> typ (a -> b) -> typ a -> Tag a -> PExpr typ f b -> PExpr typ f (a -> b) - PVar :: typ a -> Tag a -> PExpr typ f a - -data SomeNameFor typ f = forall t. SomeNameFor {-# UNPACK #-} !(NameFor typ f t) - -instance Eq (SomeNameFor typ f) where - SomeNameFor (NameFor n1) == SomeNameFor (NameFor n2) = eqStableName n1 n2 - -instance Hashable (SomeNameFor typ f) where - hashWithSalt salt (SomeNameFor name) = hashWithSalt salt name - -prettyPExpr :: Traversable1 f => Int -> PExpr typ f t -> ShowS -prettyPExpr d = \case - PStub (NameFor name) _ -> showString (showStableName name) - POp (NameFor name) _ args -> - let (argslist, _) = traverse1 (\arg -> ([Some arg], Const ())) args - argslist' = map (\(Some arg) -> prettyPExpr 0 arg) argslist - in showParen (d > 10) $ - showString ("<" ++ showStableName name ++ ">(") - . foldr (.) id (intersperse (showString ", ") argslist') - . showString ")" - PLam (NameFor name) _ _ (Tag tag) body -> - showParen (d > 0) $ - showString ("λ" ++ showStableName name ++ " x" ++ show tag ++ ". ") . prettyPExpr 0 body - PVar _ (Tag tag) -> showString ("x" ++ show tag) - --- | For each name: --- --- 1. The number of times the name is visited in a preorder traversal of the --- PHOAS expression, excluding children of nodes upon second or later visit. --- That is to say: only the nodes that are visited in a preorder traversal --- that skips repeated subtrees, are counted. --- 2. The height of the expression indicated by the name. --- --- Missing names have not been seen yet, and have unknown height. 
-type OccMap typ f = HashMap (SomeNameFor typ f) (Natural, Natural) - -pruneExpr :: Traversable1 f => (forall v. PHOASExpr typ v f t) -> (OccMap typ f, PExpr typ f t) -pruneExpr term = - let ((term', _), (_, mp)) = runState (pruneExpr' term) (0, mempty) - in (mp, term') - --- | Returns pruned expression with its height. -pruneExpr' :: Traversable1 f => PHOASExpr typ Tag f t -> State (Natural, OccMap typ f) (PExpr typ f t, Natural) -pruneExpr' = \case - orig@(PHOASOp ty args) -> do - let name = makeStableName' orig - mheight <- gets (fmap snd . HM.lookup (SomeNameFor (NameFor name)) . snd) - case mheight of - -- already visited - Just height -> do - modify (second (HM.adjust (first (+1)) (SomeNameFor (NameFor name)))) - pure (PStub (NameFor name) ty, height) - -- first visit - Nothing -> do - -- Traverse the arguments, collecting the maximum height in an - -- additional piece of state. - (args', maxhei) <- - withMoreState 0 $ - traverse1 (\arg -> do - (arg', hei) <- withLessState id (,) (pruneExpr' arg) - modify (second (hei `max`)) - return arg') - args - -- Record this node - modify (second (HM.insert (SomeNameFor (NameFor name)) (1, 1 + maxhei))) - pure (POp (NameFor name) ty args', 1 + maxhei) - - orig@(PHOASLam tyf tyarg f) -> do - let name = makeStableName' orig - mheight <- gets (fmap snd . HM.lookup (SomeNameFor (NameFor name)) . snd) - case mheight of - -- already visited - Just height -> do - modify (second (HM.adjust (first (+1)) (SomeNameFor (NameFor name)))) - pure (PStub (NameFor name) tyf, height) - -- first visit - Nothing -> do - tag <- Tag <$> gets fst - modify (first (+1)) - let body = f tag - (body', bodyhei) <- pruneExpr' body - modify (second (HM.insert (SomeNameFor (NameFor name)) (1, 1 + bodyhei))) - pure (PLam (NameFor name) tyf tyarg tag body', 1 + bodyhei) - - PHOASVar ty tag -> pure (PVar ty tag, 1) - - --- | Floated expression: a bunch of to-be let bound expressions on top of an --- LExpr'. 
Because LExpr' is really just PExpr with the recursive positions --- replaced by LExpr, LExpr should be seen as PExpr with a bunch of to-be let --- bound expressions on top of every node. -data LExpr typ f t = LExpr [Some (LExpr typ f)] (LExpr' typ f t) -data LExpr' typ f t where -- TODO: this could be an instantiation of (a generalisation of) PExpr - LStub :: NameFor typ f t -> typ t -> LExpr' typ f t - LOp :: NameFor typ f t -> typ t -> f (LExpr typ f) t -> LExpr' typ f t - LLam :: NameFor typ f (a -> b) -> typ (a -> b) -> typ a -> Tag a -> LExpr typ f b -> LExpr' typ f (a -> b) - LVar :: typ a -> Tag a -> LExpr' typ f a - -prettyLExpr :: Traversable1 f => Int -> LExpr typ f t -> ShowS -prettyLExpr d (LExpr [] e) = prettyLExpr' d e -prettyLExpr d (LExpr floated e) = - showString "[" - . foldr (.) id (intersperse (showString ", ") (map (\(Some e') -> prettyLExpr 0 e') floated)) - . showString "] " . prettyLExpr' d e - -prettyLExpr' :: Traversable1 f => Int -> LExpr' typ f t -> ShowS -prettyLExpr' d = \case - LStub (NameFor name) _ -> showString (showStableName name) - LOp (NameFor name) _ args -> - let (argslist, _) = traverse1 (\arg -> ([Some arg], Const ())) args - argslist' = map (\(Some arg) -> prettyLExpr 0 arg) argslist - in showParen (d > 10) $ - showString ("<" ++ showStableName name ++ ">(") - . foldr (.) id (intersperse (showString ", ") argslist') - . showString ")" - LLam (NameFor name) _ _ (Tag tag) body -> - showParen (d > 0) $ - showString ("λ" ++ showStableName name ++ " x" ++ show tag ++ ". ") . 
prettyLExpr 0 body - LVar _ (Tag tag) -> showString ("x" ++ show tag) - -floatExpr :: Traversable1 f => OccMap typ f -> PExpr typ f t -> LExpr typ f t -floatExpr totals term = snd (floatExpr' totals term) - -newtype FoundMap typ f = FoundMap - (HashMap (SomeNameFor typ f) - (Natural -- how many times seen - ,Maybe (Some (LExpr typ f), Natural))) -- the floated subterm with its height (once seen) - -instance Semigroup (FoundMap typ f) where - FoundMap m1 <> FoundMap m2 = FoundMap $ - HM.unionWith (\(n1, me1) (n2, me2) -> (n1 + n2, me1 <|> me2)) m1 m2 - -instance Monoid (FoundMap typ f) where - mempty = FoundMap HM.empty - -floatExpr' :: Traversable1 f => OccMap typ f -> PExpr typ f t -> (FoundMap typ f, LExpr typ f t) -floatExpr' _totals (PStub name ty) = - -- trace ("Found stub: " ++ (case name of NameFor n -> showStableName n)) $ - (FoundMap $ HM.singleton (SomeNameFor name) (1, Nothing) - ,LExpr [] (LStub name ty)) - -floatExpr' _totals (PVar ty tag) = - -- trace ("Found var: " ++ show tag) $ - (mempty, LExpr [] (LVar ty tag)) - -floatExpr' totals term = - let (FoundMap foundmap, name, termty, term') = case term of - POp n ty args -> - let (fm, args') = traverse1 (floatExpr' totals) args - in (fm, n, ty, LOp n ty args') - PLam n tyf tyarg tag body -> - let (fm, body') = floatExpr' totals body - in (fm, n, tyf, LLam n tyf tyarg tag body') - - -- TODO: perhaps this HM.toList together with the foldr HM.delete can be a single traversal of the HashMap - saturated = [case mterm of - Just t -> (nm, t) - Nothing -> case nm of - SomeNameFor (NameFor n) -> - error $ "Name saturated (count=" ++ show count ++ ", totalcount=" ++ show totalcount ++ ") but no term found: " ++ showStableName n - | (nm, (count, mterm)) <- HM.toList foundmap - , let totalcount = fromMaybe 0 (fst <$> HM.lookup nm totals) - , count == totalcount] - - foundmap' = foldr HM.delete foundmap (map fst saturated) - - lterm = LExpr (map fst (sortBy (comparing snd) (map snd saturated))) term' - - in case 
HM.findWithDefault (0, undefined) (SomeNameFor name) totals of - (1, _) -> (FoundMap foundmap', lterm) - (tot, height) - | tot > 1 -> -- trace ("Inserting " ++ (case name of NameFor n -> showStableName n) ++ " into foundmap") $ - (FoundMap (HM.insert (SomeNameFor name) (1, Just (Some lterm, height)) foundmap') - ,LExpr [] (LStub name termty)) - | otherwise -> error "Term does not exist, yet we have it in hand" - - --- | Untyped De Bruijn expression. No more names: there are lets now, and --- variable references are De Bruijn indices. These indices are not type-safe --- yet, though. -data UBExpr typ f t where - UBOp :: typ t -> f (UBExpr typ f) t -> UBExpr typ f t - UBLam :: typ (a -> b) -> typ a -> UBExpr typ f b -> UBExpr typ f (a -> b) - UBLet :: typ a -> UBExpr typ f a -> UBExpr typ f b -> UBExpr typ f b - -- | De Bruijn index - UBVar :: typ t -> Int -> UBExpr typ f t - -lowerExpr :: Functor1 f => LExpr typ f t -> UBExpr typ f t -lowerExpr = lowerExpr' mempty mempty 0 - -data SomeTag = forall t. SomeTag (Tag t) - -instance Eq SomeTag where - SomeTag (Tag n) == SomeTag (Tag m) = n == m - -instance Hashable SomeTag where - hashWithSalt salt (SomeTag tag) = hashWithSalt salt tag - -lowerExpr' :: forall typ f t. 
Functor1 f - => HashMap (SomeNameFor typ f) Int -- ^ node |-> De Bruijn level of defining binding - -> HashMap SomeTag Int -- ^ tag |-> De Bruijn level of defining binding - -> Int -- ^ Number of variables already in scope - -> LExpr typ f t -> UBExpr typ f t -lowerExpr' namelvl taglvl curlvl (LExpr floated ex) = - let (namelvl', prefix) = buildPrefix namelvl curlvl floated - curlvl' = curlvl + length floated - in prefix $ - case ex of - LStub name ty -> - case HM.lookup (SomeNameFor name) namelvl' of - Just lvl -> UBVar ty (curlvl - lvl - 1) - Nothing -> error "Name variable out of scope" - LOp _ ty args -> - UBOp ty (fmap1 (lowerExpr' namelvl' taglvl curlvl') args) - LLam _ tyf tyarg tag body -> - UBLam tyf tyarg (lowerExpr' namelvl' (HM.insert (SomeTag tag) curlvl' taglvl) (curlvl' + 1) body) - LVar ty tag -> - case HM.lookup (SomeTag tag) taglvl of - Just lvl -> UBVar ty (curlvl - lvl - 1) - Nothing -> error "Tag variable out of scope" - where - buildPrefix :: forall b. - HashMap (SomeNameFor typ f) Int - -> Int - -> [Some (LExpr typ f)] - -> (HashMap (SomeNameFor typ f) Int, UBExpr typ f b -> UBExpr typ f b) - buildPrefix namelvl' _ [] = (namelvl', id) - buildPrefix namelvl' lvl (Some rhs@(LExpr _ rhs') : rhss) = - let name = case rhs' of - LStub n _ -> n - LOp n _ _ -> n - LLam n _ _ _ _ -> n - LVar _ _ -> error "Recovering sharing of a tag is useless" - ty = case rhs' of - LStub{} -> error "Recovering sharing of a stub is useless" - LOp _ t _ -> t - LLam _ t _ _ _ -> t - LVar t _ -> t - prefix = UBLet ty (lowerExpr' namelvl' taglvl lvl rhs) - in (prefix .) <$> buildPrefix (HM.insert (SomeNameFor name) lvl namelvl') (lvl + 1) rhss - - --- | A typed De Bruijn index. 
-data Idx env t where - IZ :: Idx (t : env) t - IS :: Idx env t -> Idx (s : env) t -deriving instance Show (Idx env t) - -data Env env f where - ETop :: Env '[] f - EPush :: Env env f -> f t -> Env (t : env) f - -envLookup :: Idx env t -> Env env f -> f t -envLookup IZ (EPush _ x) = x -envLookup (IS i) (EPush e _) = envLookup i e - --- | Untyped lookup in an 'Env'. -envLookupU :: Int -> Env env f -> Maybe (Some (Product f (Idx env))) -envLookupU = go id - where - go :: (forall a. Idx env a -> Idx env' a) -> Int -> Env env f -> Maybe (Some (Product f (Idx env'))) - go !_ !_ ETop = Nothing - go f 0 (EPush _ t) = Just (Some (Pair t (f IZ))) - go f i (EPush e _) = go (f . IS) (i - 1) e - --- | Typed De Bruijn expression. This is the final result of sharing recovery. -data BExpr typ env f t where - BOp :: typ t -> f (BExpr typ env f) t -> BExpr typ env f t - BLam :: typ (a -> b) -> typ a -> BExpr typ (a : env) f b -> BExpr typ env f (a -> b) - BLet :: typ a -> BExpr typ env f a -> BExpr typ (a : env) f b -> BExpr typ env f b - BVar :: typ t -> Idx env t -> BExpr typ env f t -deriving instance (forall a. Show (typ a), forall a r. (forall b. Show (r b)) => Show (f r a)) - => Show (BExpr typ env f t) - -prettyBExpr :: (forall m env' a. Monad m => (forall b. Int -> BExpr typ env' f b -> m ShowS) - -> Int -> f (BExpr typ env' f) a -> m ShowS) - -> BExpr typ '[] f t -> String -prettyBExpr prettyOp e = evalState (prettyBExpr' prettyOp ETop 0 e) 0 "" - -prettyBExpr' :: (forall m env' a. Monad m => (forall b. Int -> BExpr typ env' f b -> m ShowS) - -> Int -> f (BExpr typ env' f) a -> m ShowS) - -> Env env (Const String) -> Int -> BExpr typ env f t -> State Int ShowS -prettyBExpr' prettyOp env d = \case - BOp _ args -> - prettyOp (prettyBExpr' prettyOp env) d args - BLam _ _ body -> do - name <- genName - body' <- prettyBExpr' prettyOp (EPush env (Const name)) 0 body - return $ showParen (d > 0) $ showString ("λ" ++ name ++ ". ") . 
body' - BLet _ rhs body -> do - name <- genName - rhs' <- prettyBExpr' prettyOp env 0 rhs - body' <- prettyBExpr' prettyOp (EPush env (Const name)) 0 body - return $ showParen (d > 0) $ showString ("let " ++ name ++ " = ") . rhs' . showString " in " . body' - BVar _ idx -> - return $ showString (getConst (envLookup idx env)) - where - genName = do - i <- state (\i -> (i, i + 1)) - return $ if i < 26 then [chr (ord 'a' + i)] else 'x' : show i - -retypeExpr :: (Functor1 f, TestEquality typ) => UBExpr typ f t -> BExpr typ '[] f t -retypeExpr = retypeExpr' ETop - -retypeExpr' :: (Functor1 f, TestEquality typ) => Env env typ -> UBExpr typ f t -> BExpr typ env f t -retypeExpr' env (UBOp ty args) = BOp ty (fmap1 (retypeExpr' env) args) -retypeExpr' env (UBLam tyf tyarg body) = BLam tyf tyarg (retypeExpr' (EPush env tyarg) body) -retypeExpr' env (UBLet ty rhs body) = BLet ty (retypeExpr' env rhs) (retypeExpr' (EPush env ty) body) -retypeExpr' env (UBVar ty idx) = - case envLookupU idx env of - Just (Some (Pair defty tidx)) -> - case testEquality ty defty of - Just Refl -> BVar ty tidx - Nothing -> error "Type mismatch in untyped De Bruijn expression" - Nothing -> error "Untyped De Bruijn index out of range" - - -sharingRecovery :: (Traversable1 f, TestEquality typ) => (forall v. 
PHOASExpr typ v f t) -> BExpr typ '[] f t -sharingRecovery e = - let (occmap, pexpr) = pruneExpr e - lexpr = floatExpr occmap pexpr - ubexpr = lowerExpr lexpr - in -- trace ("PExpr: " ++ prettyPExpr 0 pexpr "") $ - -- trace ("LExpr: " ++ prettyLExpr 0 lexpr "") $ - retypeExpr ubexpr +import Data.Expr.SharingRecovery.Internal diff --git a/src/Data/Expr/SharingRecovery/Internal.hs b/src/Data/Expr/SharingRecovery/Internal.hs new file mode 100644 index 0000000..0089454 --- /dev/null +++ b/src/Data/Expr/SharingRecovery/Internal.hs @@ -0,0 +1,475 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DefaultSignatures #-} +{-# LANGUAGE DerivingVia #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeOperators #-} +module Data.Expr.SharingRecovery.Internal where + +import Control.Applicative ((<|>)) +import Control.Monad.Trans.State.Strict +import Data.Bifunctor (first, second) +import Data.Char (chr, ord) +import Data.Functor.Const +import Data.Functor.Identity +import Data.Functor.Product +import Data.Hashable +import Data.HashMap.Strict (HashMap) +import qualified Data.HashMap.Strict as HM +import Data.List (sortBy, intersperse) +import Data.Maybe (fromMaybe) +import Data.Ord (comparing) +import Data.Some +import Data.Type.Equality +import GHC.StableName +import Numeric.Natural +import Unsafe.Coerce (unsafeCoerce) + +-- import Debug.Trace + +import Data.StableName.Extra + + +-- TODO: This implementation needs extensive documentation. 1. It is written +-- quite generically, meaning that the actual algorithm is easily obscured to +-- all but the most willing readers; 2. the original paper leaves something to +-- be desired in the domain of effective explanation for programmers, and this +-- is a good opportunity to try to do better. 
+ + +withMoreState :: Functor m => b -> StateT (s, b) m a -> StateT s m (a, b) +withMoreState b0 (StateT f) = + StateT $ \s -> (\(x, (s2, b)) -> ((x, b), s2)) <$> f (s, b0) + +withLessState :: Functor m => (s -> (s', b)) -> (s' -> b -> s) -> StateT s' m a -> StateT s m a +withLessState split restore (StateT f) = + StateT $ \s -> let (s', b) = split s + in second (flip restore b) <$> f s' + + +-- | 'Functor' on the second-to-last type parameter. +class Functor1 f where + fmap1 :: (forall b. g b -> h b) -> f g a -> f h a + + default fmap1 :: Traversable1 f => (forall b. g b -> h b) -> f g a -> f h a + fmap1 f x = runIdentity (traverse1 (Identity . f) x) + +-- | 'Traversable' on the second-to-last type parameter. +class Functor1 f => Traversable1 f where + traverse1 :: Applicative m => (forall b. g b -> m (h b)) -> f g a -> m (f h a) + +-- | Expression in parametric higher-order abstract syntax form. +-- +-- * @typ@ should be a singleton GADT that describes the @t@ type parameter. It +-- should implement 'TestEquality'. For example, for a simple language with +-- only @Int@, pairs and functions as types, the following would suffice: +-- +-- @ +-- data Typ t where +-- TInt :: Typ Int +-- TPair :: Typ a -> Typ b -> Typ (a, b) +-- TFun :: Typ a -> Typ b -> Typ (a -> b) +-- @ +-- +-- * @v@ is the type of variables in the expression. A PHOAS expression is +-- required to be parametric in the @v@ parameter; the only place you will +-- obtain a @v@ is inside a @PHOASLam@ function body. +-- +-- * @f@ should be your type of operations for your language. It is indexed by +-- the type of subexpressions and the result type of the operation; thus, it +-- is a "base functor" indexed by one additional parameter (@t@). 
For +-- example, for a simple language that supports only integer constants, +-- integer addition, lambda abstraction and function application: +-- +-- @ +-- data Oper r t where +-- OConst :: Int -> Oper r Int +-- OAdd :: r Int -> r Int -> Oper r Int +-- OApp :: r (a -> b) -> r a -> Oper r b +-- @ +-- +-- Note that lambda abstraction is not an operation, because 'PHOASExpr' +-- already represents lambda abstraction as 'PHOASLam'. The reason lambdas +-- are part of 'PHOASExpr' is that 'sharingRecovery' must be able to inspect +-- lambdas and analyse their bodies. +-- +-- Note furthermore that @Oper@ is /not/ a recursive type. Subexpressions +-- are again 'PHOASExpr's, and 'sharingRecovery' needs to be able to see +-- them. Hence, you should call back to @r@ instead of recursing +-- manually. +-- +-- * @t@ is the result type of this expression. +data PHOASExpr typ v f t where + PHOASOp :: typ t -> f (PHOASExpr typ v f) t -> PHOASExpr typ v f t + PHOASLam :: typ (a -> b) -> typ a -> (v a -> PHOASExpr typ v f b) -> PHOASExpr typ v f (a -> b) + PHOASVar :: typ t -> v t -> PHOASExpr typ v f t + +newtype Tag t = Tag Natural + deriving (Show, Eq) + deriving (Hashable) via Natural + +newtype NameFor typ f t = NameFor (StableName (PHOASExpr typ Tag f t)) + deriving (Eq) + deriving (Hashable) via (StableName (PHOASExpr typ Tag f t)) + +instance TestEquality (NameFor typ f) where + testEquality (NameFor n1) (NameFor n2) + | eqStableName n1 n2 = Just unsafeCoerceRefl + | otherwise = Nothing + where + unsafeCoerceRefl :: a :~: b -- restricted version of unsafeCoerce that only allows punting proofs + unsafeCoerceRefl = unsafeCoerce Refl + +-- | Pruned expression. 
+data PExpr typ f t where + PStub :: NameFor typ f t -> typ t -> PExpr typ f t + POp :: NameFor typ f t -> typ t -> f (PExpr typ f) t -> PExpr typ f t + PLam :: NameFor typ f (a -> b) -> typ (a -> b) -> typ a -> Tag a -> PExpr typ f b -> PExpr typ f (a -> b) + PVar :: typ a -> Tag a -> PExpr typ f a + +data SomeNameFor typ f = forall t. SomeNameFor {-# UNPACK #-} !(NameFor typ f t) + +instance Eq (SomeNameFor typ f) where + SomeNameFor (NameFor n1) == SomeNameFor (NameFor n2) = eqStableName n1 n2 + +instance Hashable (SomeNameFor typ f) where + hashWithSalt salt (SomeNameFor name) = hashWithSalt salt name + +prettyPExpr :: Traversable1 f => Int -> PExpr typ f t -> ShowS +prettyPExpr d = \case + PStub (NameFor name) _ -> showString (showStableName name) + POp (NameFor name) _ args -> + let (argslist, _) = traverse1 (\arg -> ([Some arg], Const ())) args + argslist' = map (\(Some arg) -> prettyPExpr 0 arg) argslist + in showParen (d > 10) $ + showString ("<" ++ showStableName name ++ ">(") + . foldr (.) id (intersperse (showString ", ") argslist') + . showString ")" + PLam (NameFor name) _ _ (Tag tag) body -> + showParen (d > 0) $ + showString ("λ" ++ showStableName name ++ " x" ++ show tag ++ ". ") . prettyPExpr 0 body + PVar _ (Tag tag) -> showString ("x" ++ show tag) + +-- | For each name: +-- +-- 1. The number of times the name is visited in a preorder traversal of the +-- PHOAS expression, excluding children of nodes upon second or later visit. +-- That is to say: only the nodes that are visited in a preorder traversal +-- that skips repeated subtrees, are counted. +-- 2. The height of the expression indicated by the name. +-- +-- Missing names have not been seen yet, and have unknown height. +type OccMap typ f = HashMap (SomeNameFor typ f) (Natural, Natural) + +pruneExpr :: Traversable1 f => (forall v. 
PHOASExpr typ v f t) -> (OccMap typ f, PExpr typ f t) +pruneExpr term = + let ((term', _), (_, mp)) = runState (pruneExpr' term) (0, mempty) + in (mp, term') + +-- | Returns pruned expression with its height. +pruneExpr' :: Traversable1 f => PHOASExpr typ Tag f t -> State (Natural, OccMap typ f) (PExpr typ f t, Natural) +pruneExpr' = \case + orig@(PHOASOp ty args) -> do + let name = makeStableName' orig + mheight <- gets (fmap snd . HM.lookup (SomeNameFor (NameFor name)) . snd) + case mheight of + -- already visited + Just height -> do + modify (second (HM.adjust (first (+1)) (SomeNameFor (NameFor name)))) + pure (PStub (NameFor name) ty, height) + -- first visit + Nothing -> do + -- Traverse the arguments, collecting the maximum height in an + -- additional piece of state. + (args', maxhei) <- + withMoreState 0 $ + traverse1 (\arg -> do + (arg', hei) <- withLessState id (,) (pruneExpr' arg) + modify (second (hei `max`)) + return arg') + args + -- Record this node + modify (second (HM.insert (SomeNameFor (NameFor name)) (1, 1 + maxhei))) + pure (POp (NameFor name) ty args', 1 + maxhei) + + orig@(PHOASLam tyf tyarg f) -> do + let name = makeStableName' orig + mheight <- gets (fmap snd . HM.lookup (SomeNameFor (NameFor name)) . snd) + case mheight of + -- already visited + Just height -> do + modify (second (HM.adjust (first (+1)) (SomeNameFor (NameFor name)))) + pure (PStub (NameFor name) tyf, height) + -- first visit + Nothing -> do + tag <- Tag <$> gets fst + modify (first (+1)) + let body = f tag + (body', bodyhei) <- pruneExpr' body + modify (second (HM.insert (SomeNameFor (NameFor name)) (1, 1 + bodyhei))) + pure (PLam (NameFor name) tyf tyarg tag body', 1 + bodyhei) + + PHOASVar ty tag -> pure (PVar ty tag, 1) + + +-- | Floated expression: a bunch of to-be let bound expressions on top of an +-- LExpr'. 
Because LExpr' is really just PExpr with the recursive positions +-- replaced by LExpr, LExpr should be seen as PExpr with a bunch of to-be let +-- bound expressions on top of every node. +data LExpr typ f t = LExpr [Some (LExpr typ f)] (LExpr' typ f t) +data LExpr' typ f t where -- TODO: this could be an instantiation of (a generalisation of) PExpr + LStub :: NameFor typ f t -> typ t -> LExpr' typ f t + LOp :: NameFor typ f t -> typ t -> f (LExpr typ f) t -> LExpr' typ f t + LLam :: NameFor typ f (a -> b) -> typ (a -> b) -> typ a -> Tag a -> LExpr typ f b -> LExpr' typ f (a -> b) + LVar :: typ a -> Tag a -> LExpr' typ f a + +prettyLExpr :: Traversable1 f => Int -> LExpr typ f t -> ShowS +prettyLExpr d (LExpr [] e) = prettyLExpr' d e +prettyLExpr d (LExpr floated e) = + showString "[" + . foldr (.) id (intersperse (showString ", ") (map (\(Some e') -> prettyLExpr 0 e') floated)) + . showString "] " . prettyLExpr' d e + +prettyLExpr' :: Traversable1 f => Int -> LExpr' typ f t -> ShowS +prettyLExpr' d = \case + LStub (NameFor name) _ -> showString (showStableName name) + LOp (NameFor name) _ args -> + let (argslist, _) = traverse1 (\arg -> ([Some arg], Const ())) args + argslist' = map (\(Some arg) -> prettyLExpr 0 arg) argslist + in showParen (d > 10) $ + showString ("<" ++ showStableName name ++ ">(") + . foldr (.) id (intersperse (showString ", ") argslist') + . showString ")" + LLam (NameFor name) _ _ (Tag tag) body -> + showParen (d > 0) $ + showString ("λ" ++ showStableName name ++ " x" ++ show tag ++ ". ") . 
prettyLExpr 0 body + LVar _ (Tag tag) -> showString ("x" ++ show tag) + +floatExpr :: Traversable1 f => OccMap typ f -> PExpr typ f t -> LExpr typ f t +floatExpr totals term = snd (floatExpr' totals term) + +newtype FoundMap typ f = FoundMap + (HashMap (SomeNameFor typ f) + (Natural -- how many times seen + ,Maybe (Some (LExpr typ f), Natural))) -- the floated subterm with its height (once seen) + +instance Semigroup (FoundMap typ f) where + FoundMap m1 <> FoundMap m2 = FoundMap $ + HM.unionWith (\(n1, me1) (n2, me2) -> (n1 + n2, me1 <|> me2)) m1 m2 + +instance Monoid (FoundMap typ f) where + mempty = FoundMap HM.empty + +floatExpr' :: Traversable1 f => OccMap typ f -> PExpr typ f t -> (FoundMap typ f, LExpr typ f t) +floatExpr' _totals (PStub name ty) = + -- trace ("Found stub: " ++ (case name of NameFor n -> showStableName n)) $ + (FoundMap $ HM.singleton (SomeNameFor name) (1, Nothing) + ,LExpr [] (LStub name ty)) + +floatExpr' _totals (PVar ty tag) = + -- trace ("Found var: " ++ show tag) $ + (mempty, LExpr [] (LVar ty tag)) + +floatExpr' totals term = + let (FoundMap foundmap, name, termty, term') = case term of + POp n ty args -> + let (fm, args') = traverse1 (floatExpr' totals) args + in (fm, n, ty, LOp n ty args') + PLam n tyf tyarg tag body -> + let (fm, body') = floatExpr' totals body + in (fm, n, tyf, LLam n tyf tyarg tag body') + + -- TODO: perhaps this HM.toList together with the foldr HM.delete can be a single traversal of the HashMap + saturated = [case mterm of + Just t -> (nm, t) + Nothing -> case nm of + SomeNameFor (NameFor n) -> + error $ "Name saturated (count=" ++ show count ++ ", totalcount=" ++ show totalcount ++ ") but no term found: " ++ showStableName n + | (nm, (count, mterm)) <- HM.toList foundmap + , let totalcount = fromMaybe 0 (fst <$> HM.lookup nm totals) + , count == totalcount] + + foundmap' = foldr HM.delete foundmap (map fst saturated) + + lterm = LExpr (map fst (sortBy (comparing snd) (map snd saturated))) term' + + in case 
HM.findWithDefault (0, undefined) (SomeNameFor name) totals of + (1, _) -> (FoundMap foundmap', lterm) + (tot, height) + | tot > 1 -> -- trace ("Inserting " ++ (case name of NameFor n -> showStableName n) ++ " into foundmap") $ + (FoundMap (HM.insert (SomeNameFor name) (1, Just (Some lterm, height)) foundmap') + ,LExpr [] (LStub name termty)) + | otherwise -> error "Term does not exist, yet we have it in hand" + + +-- | Untyped De Bruijn expression. No more names: there are lets now, and +-- variable references are De Bruijn indices. These indices are not type-safe +-- yet, though. +data UBExpr typ f t where + UBOp :: typ t -> f (UBExpr typ f) t -> UBExpr typ f t + UBLam :: typ (a -> b) -> typ a -> UBExpr typ f b -> UBExpr typ f (a -> b) + UBLet :: typ a -> UBExpr typ f a -> UBExpr typ f b -> UBExpr typ f b + -- | De Bruijn index + UBVar :: typ t -> Int -> UBExpr typ f t + +lowerExpr :: Functor1 f => LExpr typ f t -> UBExpr typ f t +lowerExpr = lowerExpr' mempty mempty 0 + +data SomeTag = forall t. SomeTag (Tag t) + +instance Eq SomeTag where + SomeTag (Tag n) == SomeTag (Tag m) = n == m + +instance Hashable SomeTag where + hashWithSalt salt (SomeTag tag) = hashWithSalt salt tag + +lowerExpr' :: forall typ f t. 
Functor1 f + => HashMap (SomeNameFor typ f) Int -- ^ node |-> De Bruijn level of defining binding + -> HashMap SomeTag Int -- ^ tag |-> De Bruijn level of defining binding + -> Int -- ^ Number of variables already in scope + -> LExpr typ f t -> UBExpr typ f t +lowerExpr' namelvl taglvl curlvl (LExpr floated ex) = + let (namelvl', prefix) = buildPrefix namelvl curlvl floated + curlvl' = curlvl + length floated + in prefix $ + case ex of + LStub name ty -> + case HM.lookup (SomeNameFor name) namelvl' of + Just lvl -> UBVar ty (curlvl - lvl - 1) + Nothing -> error "Name variable out of scope" + LOp _ ty args -> + UBOp ty (fmap1 (lowerExpr' namelvl' taglvl curlvl') args) + LLam _ tyf tyarg tag body -> + UBLam tyf tyarg (lowerExpr' namelvl' (HM.insert (SomeTag tag) curlvl' taglvl) (curlvl' + 1) body) + LVar ty tag -> + case HM.lookup (SomeTag tag) taglvl of + Just lvl -> UBVar ty (curlvl - lvl - 1) + Nothing -> error "Tag variable out of scope" + where + buildPrefix :: forall b. + HashMap (SomeNameFor typ f) Int + -> Int + -> [Some (LExpr typ f)] + -> (HashMap (SomeNameFor typ f) Int, UBExpr typ f b -> UBExpr typ f b) + buildPrefix namelvl' _ [] = (namelvl', id) + buildPrefix namelvl' lvl (Some rhs@(LExpr _ rhs') : rhss) = + let name = case rhs' of + LStub n _ -> n + LOp n _ _ -> n + LLam n _ _ _ _ -> n + LVar _ _ -> error "Recovering sharing of a tag is useless" + ty = case rhs' of + LStub{} -> error "Recovering sharing of a stub is useless" + LOp _ t _ -> t + LLam _ t _ _ _ -> t + LVar t _ -> t + prefix = UBLet ty (lowerExpr' namelvl' taglvl lvl rhs) + in (prefix .) <$> buildPrefix (HM.insert (SomeNameFor name) lvl namelvl') (lvl + 1) rhss + + +-- | A typed De Bruijn index. 
data Idx env t where
  -- | The most recently bound variable.
  IZ :: Idx (t : env) t
  -- | Skip over the most recently bound variable.
  IS :: Idx env t -> Idx (s : env) t
deriving instance Show (Idx env t)

-- | An environment holding one @f t@ for every type @t@ in the type-level
-- list @env@.
data Env env f where
  ETop :: Env '[] f
  EPush :: Env env f -> f t -> Env (t : env) f

-- | Typed lookup: fetch the entry that the index points at.
envLookup :: Idx env t -> Env env f -> f t
envLookup idx env =
  case (idx, env) of
    (IZ, EPush _ x) -> x
    (IS i, EPush e _) -> envLookup i e

-- | Untyped lookup in an 'Env'. On success, the entry is returned together
-- with the typed index of that position; 'Nothing' is returned if the end of
-- the environment is reached before the position is found.
envLookupU :: Int -> Env env f -> Maybe (Some (Product f (Idx env)))
envLookupU = walk id
  where
    -- The first argument re-expresses an index into the remaining
    -- environment as an index into the full environment we started with.
    walk :: (forall a. Idx env a -> Idx env' a) -> Int -> Env env f -> Maybe (Some (Product f (Idx env')))
    walk !_ !_ ETop = Nothing
    walk weaken n (EPush below here)
      | n == 0 = Just (Some (Pair here (weaken IZ)))
      | otherwise = walk (weaken . IS) (n - 1) below

-- | Typed De Bruijn expression. This is the result of sharing recovery. It is
-- no longer higher-order, and it has explicit let-bindings ('BLet') that
-- denote the sharing inside the term. This is a normal AST.
--
-- * @env@ is a type-level list containing the types of all variables in scope
--   in the expression. The bottom-most variable (i.e. the one defined most
--   recently) is at the head of the list. 'Idx' is a De Bruijn index that
--   indexes into this list, to ensure that the whole expression is well-typed
--   and well-scoped.
--
-- * @typ@, @f@ and @t@ are exactly as in 'PHOASExpr'.
data BExpr typ env f t where
  BOp :: typ t -> f (BExpr typ env f) t -> BExpr typ env f t
  BLam :: typ (a -> b) -> typ a -> BExpr typ (a : env) f b -> BExpr typ env f (a -> b)
  BLet :: typ a -> BExpr typ env f a -> BExpr typ (a : env) f b -> BExpr typ env f b
  BVar :: typ t -> Idx env t -> BExpr typ env f t
deriving instance (forall a. Show (typ a), forall a r. (forall b. Show (r b)) => Show (f r a))
               => Show (BExpr typ env f t)

-- | Render a closed expression as a 'String'. The callback renders one
-- operation node; it receives a renderer for subexpressions, the current
-- precedence level, and the node itself. Variable names are generated
-- starting from a counter of zero.
prettyBExpr :: (forall m env' a. Monad m => (forall b. Int -> BExpr typ env' f b -> m ShowS)
                                         -> Int -> f (BExpr typ env' f) a -> m ShowS)
            -> BExpr typ '[] f t -> String
prettyBExpr prettyOp e =
  let render = evalState (prettyBExpr' prettyOp ETop 0 e) 0
  in render ""

-- | Worker for 'prettyBExpr'. The environment holds the generated name of
-- every variable in scope; the 'Int' is the precedence level (a lambda or let
-- is parenthesised when it is positive); the state is the counter used for
-- generating fresh names.
prettyBExpr' :: (forall m env' a. Monad m => (forall b. Int -> BExpr typ env' f b -> m ShowS)
                                          -> Int -> f (BExpr typ env' f) a -> m ShowS)
             -> Env env (Const String) -> Int -> BExpr typ env f t -> State Int ShowS
prettyBExpr' prettyOp env d expr =
  case expr of
    BOp _ args ->
      prettyOp (prettyBExpr' prettyOp env) d args
    BLam _ _ body -> do
      var <- freshName
      shownBody <- prettyBExpr' prettyOp (EPush env (Const var)) 0 body
      return $ showParen (d > 0) $ showString ("λ" ++ var ++ ". ") . shownBody
    BLet _ rhs body -> do
      var <- freshName
      shownRhs <- prettyBExpr' prettyOp env 0 rhs
      shownBody <- prettyBExpr' prettyOp (EPush env (Const var)) 0 body
      return $ showParen (d > 0) $ showString ("let " ++ var ++ " = ") . shownRhs . showString " in " . shownBody
    BVar _ idx ->
      return $ showString (getConst (envLookup idx env))
  where
    -- Hand out the name for the current counter value and bump the counter.
    freshName = state $ \i -> (nameOf i, i + 1)

    -- The first 26 names are single letters; after that, 'x' plus the number.
    nameOf i
      | i < 26 = [chr (ord 'a' + i)]
      | otherwise = 'x' : show i

-- | Convert a closed untyped De Bruijn expression into a typed one.
retypeExpr :: (Functor1 f, TestEquality typ) => UBExpr typ f t -> BExpr typ '[] f t
retypeExpr = retypeExpr' ETop

-- | Worker for 'retypeExpr'. The environment records the type of every
-- variable in scope. Calls 'error' if a variable index does not point into
-- the environment, or if the type stored at a variable reference does not
-- match the type recorded for its binding.
retypeExpr' :: (Functor1 f, TestEquality typ) => Env env typ -> UBExpr typ f t -> BExpr typ env f t
retypeExpr' env = \case
  UBOp ty args ->
    BOp ty (fmap1 (retypeExpr' env) args)
  UBLam tyf tyarg body ->
    BLam tyf tyarg (retypeExpr' (EPush env tyarg) body)
  UBLet ty rhs body ->
    BLet ty (retypeExpr' env rhs) (retypeExpr' (EPush env ty) body)
  UBVar ty idx ->
    case envLookupU idx env of
      Nothing -> error "Untyped De Bruijn index out of range"
      Just (Some (Pair defty tidx)) ->
        case testEquality ty defty of
          Nothing -> error "Type mismatch in untyped De Bruijn expression"
          Just Refl -> BVar ty tidx


-- | By observing internal sharing using 'StableName's (in
-- 'System.IO.Unsafe.unsafePerformIO'), convert an expression in higher-order
-- abstract syntax form to a well-typed well-scoped De Bruijn expression with
-- explicit let-bindings.
sharingRecovery :: (Traversable1 f, TestEquality typ) => (forall v. PHOASExpr typ v f t) -> BExpr typ '[] f t
sharingRecovery e =
  -- trace ("PExpr: " ++ prettyPExpr 0 pexpr "") $
  -- trace ("LExpr: " ++ prettyLExpr 0 lexpr "") $
  retypeExpr (lowerExpr lexpr)
  where
    -- The pipeline: prune, float, lower to untyped De Bruijn, then retype.
    (occmap, pexpr) = pruneExpr e
    lexpr = floatExpr occmap pexpr
-- cgit v1.2.3-70-g09d2