From 4772025626d78127536c341c38052d23ca953ae3 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Fri, 3 Oct 2025 23:05:24 +0200 Subject: Move TH experiments to test-th --- sharing-recovery.cabal | 4 +- test-th/Arith.hs | 103 ++++++++++++++++++++++ test-th/Arith/NonBase.hs | 50 +++++++++++ test-th/Main.hs | 51 +++++++++++ test-th/NonBaseTH.hs | 225 +++++++++++++++++++++++++++++++++++++++++++++++ test/Arith.hs | 103 ---------------------- test/Arith/NonBase.hs | 50 ----------- test/Main.hs | 51 ----------- test/NonBaseTH.hs | 225 ----------------------------------------------- 9 files changed, 431 insertions(+), 431 deletions(-) create mode 100644 test-th/Arith.hs create mode 100644 test-th/Arith/NonBase.hs create mode 100644 test-th/Main.hs create mode 100644 test-th/NonBaseTH.hs delete mode 100644 test/Arith.hs delete mode 100644 test/Arith/NonBase.hs delete mode 100644 test/Main.hs delete mode 100644 test/NonBaseTH.hs diff --git a/sharing-recovery.cabal b/sharing-recovery.cabal index 9e34bab..42f233c 100644 --- a/sharing-recovery.cabal +++ b/sharing-recovery.cabal @@ -21,14 +21,14 @@ library default-language: Haskell2010 ghc-options: -Wall -test-suite test +test-suite test-th type: exitcode-stdio-1.0 main-is: Main.hs other-modules: Arith Arith.NonBase NonBaseTH - hs-source-dirs: test + hs-source-dirs: test-th build-depends: sharing-recovery, base, diff --git a/test-th/Arith.hs b/test-th/Arith.hs new file mode 100644 index 0000000..c34baa8 --- /dev/null +++ b/test-th/Arith.hs @@ -0,0 +1,103 @@ +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE QuantifiedConstraints #-} +module Arith where + +import Data.Type.Equality + +import Data.Expr.SharingRecovery + + +data Typ t where + TInt :: Typ Int + TBool :: Typ Bool + TPair :: Typ a -> Typ b -> Typ (a, b) + TFun :: Typ a -> Typ b -> Typ (a -> b) +deriving instance Show (Typ t) + +instance TestEquality Typ where + testEquality TInt TInt = Just Refl + testEquality TBool TBool = Just Refl + testEquality (TPair a b) (TPair a' b') + | Just Refl <- testEquality a a' + , Just Refl <- testEquality b b' + = Just Refl + testEquality (TFun a b) (TFun a' b') + | Just Refl <- testEquality a a' + , Just Refl <- testEquality b b' + = Just Refl + testEquality _ _ = Nothing + +class KnownType t where τ :: Typ t +instance KnownType Int where τ = TInt +instance KnownType Bool where τ = TBool +instance (KnownType a, KnownType b) => KnownType (a, b) where τ = TPair τ τ +instance (KnownType a, KnownType b) => KnownType (a -> b) where τ = TFun τ τ + +data PrimOp a b where + POAddI :: PrimOp (Int, Int) Int + POMulI :: PrimOp (Int, Int) Int + POEqI :: PrimOp (Int, Int) Bool +deriving instance Show (PrimOp a b) + +opType2 :: PrimOp a b -> Typ b +opType2 = \case + POAddI -> TInt + POMulI -> TInt + POEqI -> TBool + +data Fixity = Infix | Prefix + deriving (Show) + +primOpPrec :: PrimOp a b -> (Int, (Int, Int)) +primOpPrec POAddI = (6, (6, 7)) +primOpPrec POMulI = (7, (7, 8)) +primOpPrec POEqI = (4, (5, 5)) + +prettyPrimOp :: Fixity -> PrimOp a b -> ShowS +prettyPrimOp fix op = + let s = case op of + POAddI -> "+" + POMulI -> "*" + POEqI -> "==" + in showString $ case fix of + Infix -> s + Prefix -> "(" ++ s ++ ")" + +data ArithF r t where + A_Prim :: PrimOp a b -> r a -> ArithF r b + A_Pair :: r a -> r b -> ArithF r (a, b) + A_If :: r Bool -> r a -> r a -> ArithF r a +deriving instance (forall a. Show (r a)) => Show (ArithF r t) + +instance Functor1 ArithF +instance Traversable1 ArithF where + traverse1 f (A_Prim op x) = A_Prim op <$> f x + traverse1 f (A_Pair x y) = A_Pair <$> f x <*> f y + traverse1 f (A_If x y z) = A_If <$> f x <*> f y <*> f z + +prettyArithF :: Monad m + => (forall a. Int -> BExpr Typ env ArithF a -> m ShowS) + -> Int -> ArithF (BExpr Typ env ArithF) t -> m ShowS +prettyArithF pr d = \case + A_Prim op (BOp _ (A_Pair a b)) -> do + let (dop, (dopL, dopR)) = primOpPrec op + a' <- pr dopL a + b' <- pr dopR b + return $ showParen (d > dop) $ a' . showString " " . prettyPrimOp Infix op . showString " " . b' + A_Prim op (BLet ty rhs e) -> + pr d (BLet ty rhs (BOp (opType2 op) (A_Prim op e))) + A_Prim op arg -> do + arg' <- pr 11 arg + return $ showParen (d > 10) $ prettyPrimOp Prefix op . showString " " . arg' + A_Pair a b -> do + a' <- pr 0 a + b' <- pr 0 b + return $ showString "(" . a' . showString ", " . b' . showString ")" + A_If a b c -> do + a' <- pr 0 a + b' <- pr 0 b + c' <- pr 0 c + return $ showParen (d > 0) $ showString "if " . a' . showString " then " . b' . showString " else " . c' diff --git a/test-th/Arith/NonBase.hs b/test-th/Arith/NonBase.hs new file mode 100644 index 0000000..f5d458e --- /dev/null +++ b/test-th/Arith/NonBase.hs @@ -0,0 +1,50 @@ +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TemplateHaskell #-} +module Arith.NonBase where + +import Data.Kind +import Data.Type.Equality + +import NonBaseTH + + +data Typ t where + TInt :: Typ Int + TBool :: Typ Bool + TPair :: Typ a -> Typ b -> Typ (a, b) + TFun :: Typ a -> Typ b -> Typ (a -> b) +deriving instance Show (Typ t) + +instance TestEquality Typ where + testEquality TInt TInt = Just Refl + testEquality TBool TBool = Just Refl + testEquality (TPair a b) (TPair a' b') + | Just Refl <- testEquality a a' + , Just Refl <- testEquality b b' + = Just Refl + testEquality (TFun a b) (TFun a' b') + | Just Refl <- testEquality a a' + , Just Refl <- testEquality b b' + = Just Refl + testEquality _ _ = Nothing + +data PrimOp a b where + POAddI :: PrimOp (Int, Int) Int + POMulI :: PrimOp (Int, Int) Int + POEqI :: PrimOp (Int, Int) Bool +deriving instance Show (PrimOp a b) + +type Arith :: Type -> Type +data Arith t where + A_Var :: Typ t -> String -> Arith t + A_Let :: String -> Typ a -> Arith a -> Arith b -> Arith b + A_Prim :: PrimOp a b -> Arith a -> Arith b + A_Pair :: Arith a -> Arith b -> Arith (a, b) + A_If :: Arith Bool -> Arith a -> Arith a -> Arith a + A_Mono :: Arith Bool -> Arith Bool + +defineBaseAST + "ArithF" ''Arith ['A_Var, 'A_Let] (("AF_"++) . drop 2) + "arithConv" ''Typ (\_ _ _ -> [| error "Lambda impossible" |]) diff --git a/test-th/Main.hs b/test-th/Main.hs new file mode 100644 index 0000000..5a4d335 --- /dev/null +++ b/test-th/Main.hs @@ -0,0 +1,51 @@ +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE StandaloneDeriving #-} +module Main where + +import Data.Expr.SharingRecovery +import Data.Expr.SharingRecovery.Internal + +import Arith + + +-- TODO: test cyclic expressions + + +a_bin :: (KnownType a, KnownType b, KnownType c) + => PrimOp (a, b) c + -> PHOASExpr Typ v ArithF a + -> PHOASExpr Typ v ArithF b + -> PHOASExpr Typ v ArithF c +a_bin op a b = PHOASOp τ (A_Prim op (PHOASOp τ (A_Pair a b))) + +lam :: (KnownType a, KnownType b) + => (PHOASExpr Typ v f a -> PHOASExpr Typ v f b) -> PHOASExpr Typ v f (a -> b) +lam f = PHOASLam τ τ $ \arg -> f (PHOASVar τ arg) + +(+!) :: PHOASExpr Typ v ArithF Int -> PHOASExpr Typ v ArithF Int -> PHOASExpr Typ v ArithF Int +(+!) = a_bin POAddI + +(*!) :: PHOASExpr Typ v ArithF Int -> PHOASExpr Typ v ArithF Int -> PHOASExpr Typ v ArithF Int +(*!) = a_bin POMulI + +-- λx. x + x +ea_1 :: PHOASExpr Typ v ArithF (Int -> Int) +ea_1 = lam $ \arg -> arg +! arg + +-- λx. let y = x + x in y * y +ea_2 :: PHOASExpr Typ v ArithF (Int -> Int) +ea_2 = lam $ \arg -> let y = arg +! arg + in y *! y + +ea_3 :: PHOASExpr Typ v ArithF (Int -> Int) +ea_3 = lam $ \arg -> + let y = arg +! arg + x = y *! arg + -- in (y +! x) +! (x +! y) + in (x +! y) +! (y +! x) + +main :: IO () +main = putStrLn $ prettyBExpr prettyArithF (sharingRecovery ea_3) diff --git a/test-th/NonBaseTH.hs b/test-th/NonBaseTH.hs new file mode 100644 index 0000000..4741ea0 --- /dev/null +++ b/test-th/NonBaseTH.hs @@ -0,0 +1,225 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE TemplateHaskellQuotes #-} +{-# LANGUAGE TupleSections #-} +module NonBaseTH ( + defineBaseAST, +) where + +import Control.Monad (when) +import Data.List (sort, foldl', tails) +import qualified Data.Map.Strict as Map +import Data.Maybe (catMaybes) +import Language.Haskell.TH + +import Data.Expr.SharingRecovery + + +-- | Non-base AST data type +data NBAST = NBAST String [TyVarBndr ()] [NBCon] + +data NBCon = NBCon + [String] -- ^ Names of the constructors defined with this clause + [TyVarBndrSpec] -- ^ Type variables in scope in the clause + Cxt -- ^ Constraints on the constructor type + [(Bang, NBField Type)] -- ^ Constructor fields + [Name] -- ^ All but the last type parameter in the result type + Type -- ^ The last type parameter in the result type + +data NBField t + = NBFList (NBField t) + | NBFTuple [NBField t] + | NBFRecur t -- ^ The last type parameter of this recursive position. (The + -- other type parameters are fixed anyway.) + | NBFConst Type + +parseNBAST :: [String] -> Info -> Q NBAST +parseNBAST excludes info = do + (astname, params, constrs) <- case info of + TyConI (DataD [] astname params Nothing constrs _) -> return (astname, params, constrs) + _ -> fail "Unsupported datatype" + + -- In this function we use parameter/index terminology: parameters are + -- uniform, indices vary inside an AST. + let parseField retpars field = do + let (core, args) = splitApps field + case core of + ConT n + | n == astname -> + if not (null args) && init args == map VarT retpars + then return (NBFRecur (last args)) + else fail $ "Field\n " ++ pprint field + ++ "\nis recursive, but with different type parameters than " + ++ "the return type of this constructor. All but the last type " + ++ "parameter of the GADT must be uniform over the entire AST." + + ListT + | [arg] <- args -> NBFList <$> parseField retpars arg + + TupleT k + | length args == k -> NBFTuple <$> traverse (parseField retpars) args + + _ -> do + when (pprint astname `infixOf` pprint field) $ + reportWarning $ "Field\n " ++ pprint field ++ "\nseems to refer to " + ++ pprint astname ++ " in unsupported ways; ignoring those occurrences." + return (NBFConst field) + + let splitConstr (ForallC vars ctx (GadtC names fields retty)) + | names'@(_:_) <- filter (`notElem` excludes) (map nameBase names) = + return (Just (vars, ctx, names', fields, retty)) + | otherwise = return Nothing + splitConstr c@GadtC{} = splitConstr (ForallC [] [] c) + splitConstr c = + let names = case c of + NormalC n _ -> Just (show n) + _ -> Just (show c) + in fail $ "Unsupported constructors found" ++ maybe "" (\s -> ": " ++ show s) names + + let parseConstr (vars, ctx, names, fields, retty) = do + (retpars, retindex) <- parseRetty astname (head names) retty + fields' <- traverse (\(ba, t) -> (ba,) <$> parseField retpars t) fields + return (NBCon names vars ctx fields' retpars retindex) + + constrs' <- traverse parseConstr =<< catMaybes <$> traverse splitConstr constrs + return (NBAST (nameBase astname) params constrs') + + +-- | Define a new GADT that is a base-functor-like version of a given existing +-- GADT AST. +-- +-- Remember to use 'lookupTypeName' or 'lookupValueName' instead of normal +-- quotes in case of punning of data types and constructors. +defineBaseAST + :: String -- ^ Name of the (base-functor-like) data type to define + -> Name -- ^ Name of the GADT to process + -> [Name] -- ^ Constructors to exclude (Var and Let, plus any other scoping construct) + -> (String -> String) -- ^ Constructor renaming function + -> String -- ^ Name of base -> nonbase conversion function to define + -> Name -- ^ Type of singleton types + -> (ExpQ -> ExpQ -> ExpQ -> ExpQ) -- ^ Lambda: typ (a -> b) -> typ a -> AST b -> AST (a -> b) + -> (ExpQ -> ExpQ -> ExpQ -> ExpQ) -- ^ Let: typ a -> String -> AST a -> AST b -> AST b + -> (ExpQ -> ExpQ -> ExpQ -> ExpQ) -- ^ Var: typ t -> String -> AST b -> AST b + -> Q [Dec] +defineBaseAST basename astname excludes renameConstr bnConvName typName mkLam mkLet mkVar = do + NBAST _ params constrs <- parseNBAST (map nameBase excludes) =<< reify astname + + -- Build the data type + + let basename' = mkName basename + conNameMap = Map.fromList [(nbname, mkName (renameConstr nbname)) + | NBCon ns _ _ _ _ _ <- constrs, nbname <- ns] + + let recvar = mkName "r" + + let processField (NBFRecur idx) = VarT recvar `AppT` idx + processField (NBFConst t) = t + processField (NBFList f) = AppT ListT (processField f) + processField (NBFTuple fs) = foldl' AppT (TupleT (length fs)) (map processField fs) + + let processConstr (NBCon names vars ctx fields retparams retindex) = do + let names' = map (conNameMap Map.!) names + let fields' = map (\(ba, f) -> (ba, processField f)) fields + let retty' = foldl' AppT (ConT basename') (map VarT retparams ++ [VarT recvar, retindex]) + return [ForallC (map cleanupBndr vars ++ [PlainTV recvar SpecifiedSpec]) + ctx (GadtC names' fields' retty')] + + let params' = map cleanupBndr (init params ++ [PlainTV recvar (), last params]) + constrs' <- concat <$> traverse processConstr constrs + let datatype = DataD [] (mkName basename) params' Nothing constrs' [] + + -- Build the B->N conversion function + + let bnConvName' = mkName bnConvName + envparam = VarT (mkName "env") + tparam = VarT (mkName "t") + ftype = foldl' AppT (ConT basename') (map (VarT . bndrName) (init params)) + arrow a b = ArrowT `AppT` a `AppT` b + + let bnConvSig = + SigD bnConvName' $ + (ConT ''BExpr `AppT` ConT typName `AppT` envparam `AppT` ftype `AppT` tparam) + `arrow` + foldl' AppT (ConT astname) (map (VarT . bndrName) (init params) ++ [tparam]) + + let clause' pats ex = Clause pats (NormalB ex) [] + -- backConMap = Map.fromList [(renameConstr nbname, nbname) + -- | NBCon ns _ _ _ _ _ <- constrs, nbname <- ns] + reconstructField (NBFRecur _) = do + r <- newName "r" + return (VarP r, VarE bnConvName' `AppE` VarE r) + reconstructField (NBFConst _) = do + x <- newName "x" + return (VarP x, VarE x) + reconstructField (NBFList f) = do + (pat, ex) <- reconstructField f + l <- newName "l" + return (VarP l, VarE 'map `AppE` LamE [pat] ex `AppE` VarE l) + reconstructField (NBFTuple fs) = do + (pats, exps) <- unzip <$> traverse reconstructField fs + return (TupP pats, TupE (map Just exps)) + + let tyarg1 = mkName "ty1" + tyarg2 = mkName "ty2" + astarg = mkName "ast" + mkseq a b = VarE 'seq `AppE` a `AppE` b + infixr `mkseq` + + bnConvFun <- + fmap (FunD bnConvName') . sequence $ + [do (pats, exps) <- unzip <$> traverse (reconstructField . snd) fields + return $ + clause' [ConP 'BOp [] [WildP, ConP (conNameMap Map.! nbname) [] pats]] + (foldl' AppE (ConE (mkName nbname)) exps) + | NBCon names _ _ fields _ _ <- constrs + , nbname <- names] + ++ + [do body <- mkLam (varE tyarg1) (varE tyarg2) (varE astarg) + return $ clause' [ConP 'BLam [] [VarP tyarg1, VarP tyarg2, VarP astarg]] + -- put them in a useless seq so that they're not unused + (TupE [Just (VarE tyarg1), Just (VarE tyarg2), Just (VarE astarg)] `mkseq` body)] + + return [datatype, bnConvSig, bnConvFun] + +-- | Remove `:: Type` annotations because they unnecessarily require the user +-- to enable KindSignatures. Any other annotations we leave, in case the user +-- wrote them and they are necessary. +cleanupBndr :: TyVarBndr a -> TyVarBndr a +cleanupBndr (KindedTV name x k) | isType k = PlainTV name x + where isType StarT = True + isType (ConT n) | n == ''Type = True + isType _ = False +cleanupBndr b = b + +bndrName :: TyVarBndr a -> Name +bndrName (PlainTV name _) = name +bndrName (KindedTV name _ _) = name + +parseRetty :: Name -> String -> Type -> Q ([Name], Type) +parseRetty astname consname retty = do + case splitApps retty of + (ConT name, args) + | name /= astname -> fail $ "Could not parse return type of constructor " ++ consname + | null args -> fail "Expected GADT to have type parameters" + + | Just varnames <- traverse (\case VarT varname -> Just varname ; _ -> Nothing) (init args) + , allDistinct varnames -> + return (varnames, last args) + + | otherwise -> fail $ "All type parameters but the last one must be uniform over all constructors. " + ++ "(Return type of constructor " ++ consname ++ ")" + _ -> fail $ "Could not parse return type of constructor " ++ consname + +splitApps :: Type -> (Type, [Type]) +splitApps = flip go [] + where go (ParensT t) tl = go t tl + go (AppT t arg) tl = go t (arg : tl) + go t tl = (t, tl) + +allDistinct :: Ord a => [a] -> Bool +allDistinct l = + let sorted = sort l + in all (uncurry (/=)) (zip sorted (drop 1 sorted)) + +infixOf :: Eq a => [a] -> [a] -> Bool +short `infixOf` long = any (`startsWith` short) (tails long) + where a `startsWith` b = take (length b) a == b diff --git a/test/Arith.hs b/test/Arith.hs deleted file mode 100644 index c34baa8..0000000 --- a/test/Arith.hs +++ /dev/null @@ -1,103 +0,0 @@ -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE QuantifiedConstraints #-} -module Arith where - -import Data.Type.Equality - -import Data.Expr.SharingRecovery - - -data Typ t where - TInt :: Typ Int - TBool :: Typ Bool - TPair :: Typ a -> Typ b -> Typ (a, b) - TFun :: Typ a -> Typ b -> Typ (a -> b) -deriving instance Show (Typ t) - -instance TestEquality Typ where - testEquality TInt TInt = Just Refl - testEquality TBool TBool = Just Refl - testEquality (TPair a b) (TPair a' b') - | Just Refl <- testEquality a a' - , Just Refl <- testEquality b b' - = Just Refl - testEquality (TFun a b) (TFun a' b') - | Just Refl <- testEquality a a' - , Just Refl <- testEquality b b' - = Just Refl - testEquality _ _ = Nothing - -class KnownType t where τ :: Typ t -instance KnownType Int where τ = TInt -instance KnownType Bool where τ = TBool -instance (KnownType a, KnownType b) => KnownType (a, b) where τ = TPair τ τ -instance (KnownType a, KnownType b) => KnownType (a -> b) where τ = TFun τ τ - -data PrimOp a b where - POAddI :: PrimOp (Int, Int) Int - POMulI :: PrimOp (Int, Int) Int - POEqI :: PrimOp (Int, Int) Bool -deriving instance Show (PrimOp a b) - -opType2 :: PrimOp a b -> Typ b -opType2 = \case - POAddI -> TInt - POMulI -> TInt - POEqI -> TBool - -data Fixity = Infix | Prefix - deriving (Show) - -primOpPrec :: PrimOp a b -> (Int, (Int, Int)) -primOpPrec POAddI = (6, (6, 7)) -primOpPrec POMulI = (7, (7, 8)) -primOpPrec POEqI = (4, (5, 5)) - -prettyPrimOp :: Fixity -> PrimOp a b -> ShowS -prettyPrimOp fix op = - let s = case op of - POAddI -> "+" - POMulI -> "*" - POEqI -> "==" - in showString $ case fix of - Infix -> s - Prefix -> "(" ++ s ++ ")" - -data ArithF r t where - A_Prim :: PrimOp a b -> r a -> ArithF r b - A_Pair :: r a -> r b -> ArithF r (a, b) - A_If :: r Bool -> r a -> r a -> ArithF r a -deriving instance (forall a. Show (r a)) => Show (ArithF r t) - -instance Functor1 ArithF -instance Traversable1 ArithF where - traverse1 f (A_Prim op x) = A_Prim op <$> f x - traverse1 f (A_Pair x y) = A_Pair <$> f x <*> f y - traverse1 f (A_If x y z) = A_If <$> f x <*> f y <*> f z - -prettyArithF :: Monad m - => (forall a. Int -> BExpr Typ env ArithF a -> m ShowS) - -> Int -> ArithF (BExpr Typ env ArithF) t -> m ShowS -prettyArithF pr d = \case - A_Prim op (BOp _ (A_Pair a b)) -> do - let (dop, (dopL, dopR)) = primOpPrec op - a' <- pr dopL a - b' <- pr dopR b - return $ showParen (d > dop) $ a' . showString " " . prettyPrimOp Infix op . showString " " . b' - A_Prim op (BLet ty rhs e) -> - pr d (BLet ty rhs (BOp (opType2 op) (A_Prim op e))) - A_Prim op arg -> do - arg' <- pr 11 arg - return $ showParen (d > 10) $ prettyPrimOp Prefix op . showString " " . arg' - A_Pair a b -> do - a' <- pr 0 a - b' <- pr 0 b - return $ showString "(" . a' . showString ", " . b' . showString ")" - A_If a b c -> do - a' <- pr 0 a - b' <- pr 0 b - c' <- pr 0 c - return $ showParen (d > 0) $ showString "if " . a' . showString " then " . b' . showString " else " . c' diff --git a/test/Arith/NonBase.hs b/test/Arith/NonBase.hs deleted file mode 100644 index f5d458e..0000000 --- a/test/Arith/NonBase.hs +++ /dev/null @@ -1,50 +0,0 @@ -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE TemplateHaskell #-} -module Arith.NonBase where - -import Data.Kind -import Data.Type.Equality - -import NonBaseTH - - -data Typ t where - TInt :: Typ Int - TBool :: Typ Bool - TPair :: Typ a -> Typ b -> Typ (a, b) - TFun :: Typ a -> Typ b -> Typ (a -> b) -deriving instance Show (Typ t) - -instance TestEquality Typ where - testEquality TInt TInt = Just Refl - testEquality TBool TBool = Just Refl - testEquality (TPair a b) (TPair a' b') - | Just Refl <- testEquality a a' - , Just Refl <- testEquality b b' - = Just Refl - testEquality (TFun a b) (TFun a' b') - | Just Refl <- testEquality a a' - , Just Refl <- testEquality b b' - = Just Refl - testEquality _ _ = Nothing - -data PrimOp a b where - POAddI :: PrimOp (Int, Int) Int - POMulI :: PrimOp (Int, Int) Int - POEqI :: PrimOp (Int, Int) Bool -deriving instance Show (PrimOp a b) - -type Arith :: Type -> Type -data Arith t where - A_Var :: Typ t -> String -> Arith t - A_Let :: String -> Typ a -> Arith a -> Arith b -> Arith b - A_Prim :: PrimOp a b -> Arith a -> Arith b - A_Pair :: Arith a -> Arith b -> Arith (a, b) - A_If :: Arith Bool -> Arith a -> Arith a -> Arith a - A_Mono :: Arith Bool -> Arith Bool - -defineBaseAST - "ArithF" ''Arith ['A_Var, 'A_Let] (("AF_"++) . drop 2) - "arithConv" ''Typ (\_ _ _ -> [| error "Lambda impossible" |]) diff --git a/test/Main.hs b/test/Main.hs deleted file mode 100644 index 5a4d335..0000000 --- a/test/Main.hs +++ /dev/null @@ -1,51 +0,0 @@ -{-# LANGUAGE GADTs #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE QuantifiedConstraints #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE StandaloneDeriving #-} -module Main where - -import Data.Expr.SharingRecovery -import Data.Expr.SharingRecovery.Internal - -import Arith - - --- TODO: test cyclic expressions - - -a_bin :: (KnownType a, KnownType b, KnownType c) - => PrimOp (a, b) c - -> PHOASExpr Typ v ArithF a - -> PHOASExpr Typ v ArithF b - -> PHOASExpr Typ v ArithF c -a_bin op a b = PHOASOp τ (A_Prim op (PHOASOp τ (A_Pair a b))) - -lam :: (KnownType a, KnownType b) - => (PHOASExpr Typ v f a -> PHOASExpr Typ v f b) -> PHOASExpr Typ v f (a -> b) -lam f = PHOASLam τ τ $ \arg -> f (PHOASVar τ arg) - -(+!) :: PHOASExpr Typ v ArithF Int -> PHOASExpr Typ v ArithF Int -> PHOASExpr Typ v ArithF Int -(+!) = a_bin POAddI - -(*!) :: PHOASExpr Typ v ArithF Int -> PHOASExpr Typ v ArithF Int -> PHOASExpr Typ v ArithF Int -(*!) = a_bin POMulI - --- λx. x + x -ea_1 :: PHOASExpr Typ v ArithF (Int -> Int) -ea_1 = lam $ \arg -> arg +! arg - --- λx. let y = x + x in y * y -ea_2 :: PHOASExpr Typ v ArithF (Int -> Int) -ea_2 = lam $ \arg -> let y = arg +! arg - in y *! y - -ea_3 :: PHOASExpr Typ v ArithF (Int -> Int) -ea_3 = lam $ \arg -> - let y = arg +! arg - x = y *! arg - -- in (y +! x) +! (x +! y) - in (x +! y) +! (y +! x) - -main :: IO () -main = putStrLn $ prettyBExpr prettyArithF (sharingRecovery ea_3) diff --git a/test/NonBaseTH.hs b/test/NonBaseTH.hs deleted file mode 100644 index 4741ea0..0000000 --- a/test/NonBaseTH.hs +++ /dev/null @@ -1,225 +0,0 @@ -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE TemplateHaskellQuotes #-} -{-# LANGUAGE TupleSections #-} -module NonBaseTH ( - defineBaseAST, -) where - -import Control.Monad (when) -import Data.List (sort, foldl', tails) -import qualified Data.Map.Strict as Map -import Data.Maybe (catMaybes) -import Language.Haskell.TH - -import Data.Expr.SharingRecovery - - --- | Non-base AST data type -data NBAST = NBAST String [TyVarBndr ()] [NBCon] - -data NBCon = NBCon - [String] -- ^ Names of the constructors defined with this clause - [TyVarBndrSpec] -- ^ Type variables in scope in the clause - Cxt -- ^ Constraints on the constructor type - [(Bang, NBField Type)] -- ^ Constructor fields - [Name] -- ^ All but the last type parameter in the result type - Type -- ^ The last type parameter in the result type - -data NBField t - = NBFList (NBField t) - | NBFTuple [NBField t] - | NBFRecur t -- ^ The last type parameter of this recursive position. (The - -- other type parameters are fixed anyway.) - | NBFConst Type - -parseNBAST :: [String] -> Info -> Q NBAST -parseNBAST excludes info = do - (astname, params, constrs) <- case info of - TyConI (DataD [] astname params Nothing constrs _) -> return (astname, params, constrs) - _ -> fail "Unsupported datatype" - - -- In this function we use parameter/index terminology: parameters are - -- uniform, indices vary inside an AST. - let parseField retpars field = do - let (core, args) = splitApps field - case core of - ConT n - | n == astname -> - if not (null args) && init args == map VarT retpars - then return (NBFRecur (last args)) - else fail $ "Field\n " ++ pprint field - ++ "\nis recursive, but with different type parameters than " - ++ "the return type of this constructor. All but the last type " - ++ "parameter of the GADT must be uniform over the entire AST." - - ListT - | [arg] <- args -> NBFList <$> parseField retpars arg - - TupleT k - | length args == k -> NBFTuple <$> traverse (parseField retpars) args - - _ -> do - when (pprint astname `infixOf` pprint field) $ - reportWarning $ "Field\n " ++ pprint field ++ "\nseems to refer to " - ++ pprint astname ++ " in unsupported ways; ignoring those occurrences." - return (NBFConst field) - - let splitConstr (ForallC vars ctx (GadtC names fields retty)) - | names'@(_:_) <- filter (`notElem` excludes) (map nameBase names) = - return (Just (vars, ctx, names', fields, retty)) - | otherwise = return Nothing - splitConstr c@GadtC{} = splitConstr (ForallC [] [] c) - splitConstr c = - let names = case c of - NormalC n _ -> Just (show n) - _ -> Just (show c) - in fail $ "Unsupported constructors found" ++ maybe "" (\s -> ": " ++ show s) names - - let parseConstr (vars, ctx, names, fields, retty) = do - (retpars, retindex) <- parseRetty astname (head names) retty - fields' <- traverse (\(ba, t) -> (ba,) <$> parseField retpars t) fields - return (NBCon names vars ctx fields' retpars retindex) - - constrs' <- traverse parseConstr =<< catMaybes <$> traverse splitConstr constrs - return (NBAST (nameBase astname) params constrs') - - --- | Define a new GADT that is a base-functor-like version of a given existing --- GADT AST. --- --- Remember to use 'lookupTypeName' or 'lookupValueName' instead of normal --- quotes in case of punning of data types and constructors. -defineBaseAST - :: String -- ^ Name of the (base-functor-like) data type to define - -> Name -- ^ Name of the GADT to process - -> [Name] -- ^ Constructors to exclude (Var and Let, plus any other scoping construct) - -> (String -> String) -- ^ Constructor renaming function - -> String -- ^ Name of base -> nonbase conversion function to define - -> Name -- ^ Type of singleton types - -> (ExpQ -> ExpQ -> ExpQ -> ExpQ) -- ^ Lambda: typ (a -> b) -> typ a -> AST b -> AST (a -> b) - -> (ExpQ -> ExpQ -> ExpQ -> ExpQ) -- ^ Let: typ a -> String -> AST a -> AST b -> AST b - -> (ExpQ -> ExpQ -> ExpQ -> ExpQ) -- ^ Var: typ t -> String -> AST b -> AST b - -> Q [Dec] -defineBaseAST basename astname excludes renameConstr bnConvName typName mkLam mkLet mkVar = do - NBAST _ params constrs <- parseNBAST (map nameBase excludes) =<< reify astname - - -- Build the data type - - let basename' = mkName basename - conNameMap = Map.fromList [(nbname, mkName (renameConstr nbname)) - | NBCon ns _ _ _ _ _ <- constrs, nbname <- ns] - - let recvar = mkName "r" - - let processField (NBFRecur idx) = VarT recvar `AppT` idx - processField (NBFConst t) = t - processField (NBFList f) = AppT ListT (processField f) - processField (NBFTuple fs) = foldl' AppT (TupleT (length fs)) (map processField fs) - - let processConstr (NBCon names vars ctx fields retparams retindex) = do - let names' = map (conNameMap Map.!) names - let fields' = map (\(ba, f) -> (ba, processField f)) fields - let retty' = foldl' AppT (ConT basename') (map VarT retparams ++ [VarT recvar, retindex]) - return [ForallC (map cleanupBndr vars ++ [PlainTV recvar SpecifiedSpec]) - ctx (GadtC names' fields' retty')] - - let params' = map cleanupBndr (init params ++ [PlainTV recvar (), last params]) - constrs' <- concat <$> traverse processConstr constrs - let datatype = DataD [] (mkName basename) params' Nothing constrs' [] - - -- Build the B->N conversion function - - let bnConvName' = mkName bnConvName - envparam = VarT (mkName "env") - tparam = VarT (mkName "t") - ftype = foldl' AppT (ConT basename') (map (VarT . bndrName) (init params)) - arrow a b = ArrowT `AppT` a `AppT` b - - let bnConvSig = - SigD bnConvName' $ - (ConT ''BExpr `AppT` ConT typName `AppT` envparam `AppT` ftype `AppT` tparam) - `arrow` - foldl' AppT (ConT astname) (map (VarT . bndrName) (init params) ++ [tparam]) - - let clause' pats ex = Clause pats (NormalB ex) [] - -- backConMap = Map.fromList [(renameConstr nbname, nbname) - -- | NBCon ns _ _ _ _ _ <- constrs, nbname <- ns] - reconstructField (NBFRecur _) = do - r <- newName "r" - return (VarP r, VarE bnConvName' `AppE` VarE r) - reconstructField (NBFConst _) = do - x <- newName "x" - return (VarP x, VarE x) - reconstructField (NBFList f) = do - (pat, ex) <- reconstructField f - l <- newName "l" - return (VarP l, VarE 'map `AppE` LamE [pat] ex `AppE` VarE l) - reconstructField (NBFTuple fs) = do - (pats, exps) <- unzip <$> traverse reconstructField fs - return (TupP pats, TupE (map Just exps)) - - let tyarg1 = mkName "ty1" - tyarg2 = mkName "ty2" - astarg = mkName "ast" - mkseq a b = VarE 'seq `AppE` a `AppE` b - infixr `mkseq` - - bnConvFun <- - fmap (FunD bnConvName') . sequence $ - [do (pats, exps) <- unzip <$> traverse (reconstructField . snd) fields - return $ - clause' [ConP 'BOp [] [WildP, ConP (conNameMap Map.! nbname) [] pats]] - (foldl' AppE (ConE (mkName nbname)) exps) - | NBCon names _ _ fields _ _ <- constrs - , nbname <- names] - ++ - [do body <- mkLam (varE tyarg1) (varE tyarg2) (varE astarg) - return $ clause' [ConP 'BLam [] [VarP tyarg1, VarP tyarg2, VarP astarg]] - -- put them in a useless seq so that they're not unused - (TupE [Just (VarE tyarg1), Just (VarE tyarg2), Just (VarE astarg)] `mkseq` body)] - - return [datatype, bnConvSig, bnConvFun] - --- | Remove `:: Type` annotations because they unnecessarily require the user --- to enable KindSignatures. Any other annotations we leave, in case the user --- wrote them and they are necessary. -cleanupBndr :: TyVarBndr a -> TyVarBndr a -cleanupBndr (KindedTV name x k) | isType k = PlainTV name x - where isType StarT = True - isType (ConT n) | n == ''Type = True - isType _ = False -cleanupBndr b = b - -bndrName :: TyVarBndr a -> Name -bndrName (PlainTV name _) = name -bndrName (KindedTV name _ _) = name - -parseRetty :: Name -> String -> Type -> Q ([Name], Type) -parseRetty astname consname retty = do - case splitApps retty of - (ConT name, args) - | name /= astname -> fail $ "Could not parse return type of constructor " ++ consname - | null args -> fail "Expected GADT to have type parameters" - - | Just varnames <- traverse (\case VarT varname -> Just varname ; _ -> Nothing) (init args) - , allDistinct varnames -> - return (varnames, last args) - - | otherwise -> fail $ "All type parameters but the last one must be uniform over all constructors. " - ++ "(Return type of constructor " ++ consname ++ ")" - _ -> fail $ "Could not parse return type of constructor " ++ consname - -splitApps :: Type -> (Type, [Type]) -splitApps = flip go [] - where go (ParensT t) tl = go t tl - go (AppT t arg) tl = go t (arg : tl) - go t tl = (t, tl) - -allDistinct :: Ord a => [a] -> Bool -allDistinct l = - let sorted = sort l - in all (uncurry (/=)) (zip sorted (drop 1 sorted)) - -infixOf :: Eq a => [a] -> [a] -> Bool -short `infixOf` long = any (`startsWith` short) (tails long) - where a `startsWith` b = take (length b) a == b -- cgit v1.2.3-70-g09d2