aboutsummaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-10-03 23:05:24 +0200
committerTom Smeding <tom@tomsmeding.com>2025-10-03 23:05:24 +0200
commit4772025626d78127536c341c38052d23ca953ae3 (patch)
tree56374b80987c42598b63b785ba8207bc290cc835 /test
parentbb44859684ee8f241da6d2d0a4ebed1639b11b81 (diff)
Move TH experiments to test-th
Diffstat (limited to 'test')
-rw-r--r--test/Arith.hs103
-rw-r--r--test/Arith/NonBase.hs50
-rw-r--r--test/Main.hs51
-rw-r--r--test/NonBaseTH.hs225
4 files changed, 0 insertions, 429 deletions
diff --git a/test/Arith.hs b/test/Arith.hs
deleted file mode 100644
index c34baa8..0000000
--- a/test/Arith.hs
+++ /dev/null
@@ -1,103 +0,0 @@
-{-# LANGUAGE RankNTypes #-}
-{-# LANGUAGE LambdaCase #-}
-{-# LANGUAGE StandaloneDeriving #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE QuantifiedConstraints #-}
-module Arith where
-
-import Data.Type.Equality
-
-import Data.Expr.SharingRecovery
-
-
-data Typ t where
- TInt :: Typ Int
- TBool :: Typ Bool
- TPair :: Typ a -> Typ b -> Typ (a, b)
- TFun :: Typ a -> Typ b -> Typ (a -> b)
-deriving instance Show (Typ t)
-
-instance TestEquality Typ where
- testEquality TInt TInt = Just Refl
- testEquality TBool TBool = Just Refl
- testEquality (TPair a b) (TPair a' b')
- | Just Refl <- testEquality a a'
- , Just Refl <- testEquality b b'
- = Just Refl
- testEquality (TFun a b) (TFun a' b')
- | Just Refl <- testEquality a a'
- , Just Refl <- testEquality b b'
- = Just Refl
- testEquality _ _ = Nothing
-
-class KnownType t where τ :: Typ t
-instance KnownType Int where τ = TInt
-instance KnownType Bool where τ = TBool
-instance (KnownType a, KnownType b) => KnownType (a, b) where τ = TPair τ τ
-instance (KnownType a, KnownType b) => KnownType (a -> b) where τ = TFun τ τ
-
-data PrimOp a b where
- POAddI :: PrimOp (Int, Int) Int
- POMulI :: PrimOp (Int, Int) Int
- POEqI :: PrimOp (Int, Int) Bool
-deriving instance Show (PrimOp a b)
-
-opType2 :: PrimOp a b -> Typ b
-opType2 = \case
- POAddI -> TInt
- POMulI -> TInt
- POEqI -> TBool
-
-data Fixity = Infix | Prefix
- deriving (Show)
-
-primOpPrec :: PrimOp a b -> (Int, (Int, Int))
-primOpPrec POAddI = (6, (6, 7))
-primOpPrec POMulI = (7, (7, 8))
-primOpPrec POEqI = (4, (5, 5))
-
-prettyPrimOp :: Fixity -> PrimOp a b -> ShowS
-prettyPrimOp fix op =
- let s = case op of
- POAddI -> "+"
- POMulI -> "*"
- POEqI -> "=="
- in showString $ case fix of
- Infix -> s
- Prefix -> "(" ++ s ++ ")"
-
-data ArithF r t where
- A_Prim :: PrimOp a b -> r a -> ArithF r b
- A_Pair :: r a -> r b -> ArithF r (a, b)
- A_If :: r Bool -> r a -> r a -> ArithF r a
-deriving instance (forall a. Show (r a)) => Show (ArithF r t)
-
-instance Functor1 ArithF
-instance Traversable1 ArithF where
- traverse1 f (A_Prim op x) = A_Prim op <$> f x
- traverse1 f (A_Pair x y) = A_Pair <$> f x <*> f y
- traverse1 f (A_If x y z) = A_If <$> f x <*> f y <*> f z
-
-prettyArithF :: Monad m
- => (forall a. Int -> BExpr Typ env ArithF a -> m ShowS)
- -> Int -> ArithF (BExpr Typ env ArithF) t -> m ShowS
-prettyArithF pr d = \case
- A_Prim op (BOp _ (A_Pair a b)) -> do
- let (dop, (dopL, dopR)) = primOpPrec op
- a' <- pr dopL a
- b' <- pr dopR b
- return $ showParen (d > dop) $ a' . showString " " . prettyPrimOp Infix op . showString " " . b'
- A_Prim op (BLet ty rhs e) ->
- pr d (BLet ty rhs (BOp (opType2 op) (A_Prim op e)))
- A_Prim op arg -> do
- arg' <- pr 11 arg
- return $ showParen (d > 10) $ prettyPrimOp Prefix op . showString " " . arg'
- A_Pair a b -> do
- a' <- pr 0 a
- b' <- pr 0 b
- return $ showString "(" . a' . showString ", " . b' . showString ")"
- A_If a b c -> do
- a' <- pr 0 a
- b' <- pr 0 b
- c' <- pr 0 c
- return $ showParen (d > 0) $ showString "if " . a' . showString " then " . b' . showString " else " . c'
diff --git a/test/Arith/NonBase.hs b/test/Arith/NonBase.hs
deleted file mode 100644
index f5d458e..0000000
--- a/test/Arith/NonBase.hs
+++ /dev/null
@@ -1,50 +0,0 @@
-{-# LANGUAGE StandaloneDeriving #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE StandaloneKindSignatures #-}
-{-# LANGUAGE TemplateHaskell #-}
-module Arith.NonBase where
-
-import Data.Kind
-import Data.Type.Equality
-
-import NonBaseTH
-
-
-data Typ t where
- TInt :: Typ Int
- TBool :: Typ Bool
- TPair :: Typ a -> Typ b -> Typ (a, b)
- TFun :: Typ a -> Typ b -> Typ (a -> b)
-deriving instance Show (Typ t)
-
-instance TestEquality Typ where
- testEquality TInt TInt = Just Refl
- testEquality TBool TBool = Just Refl
- testEquality (TPair a b) (TPair a' b')
- | Just Refl <- testEquality a a'
- , Just Refl <- testEquality b b'
- = Just Refl
- testEquality (TFun a b) (TFun a' b')
- | Just Refl <- testEquality a a'
- , Just Refl <- testEquality b b'
- = Just Refl
- testEquality _ _ = Nothing
-
-data PrimOp a b where
- POAddI :: PrimOp (Int, Int) Int
- POMulI :: PrimOp (Int, Int) Int
- POEqI :: PrimOp (Int, Int) Bool
-deriving instance Show (PrimOp a b)
-
-type Arith :: Type -> Type
-data Arith t where
- A_Var :: Typ t -> String -> Arith t
- A_Let :: String -> Typ a -> Arith a -> Arith b -> Arith b
- A_Prim :: PrimOp a b -> Arith a -> Arith b
- A_Pair :: Arith a -> Arith b -> Arith (a, b)
- A_If :: Arith Bool -> Arith a -> Arith a -> Arith a
- A_Mono :: Arith Bool -> Arith Bool
-
-defineBaseAST
- "ArithF" ''Arith ['A_Var, 'A_Let] (("AF_"++) . drop 2)
- "arithConv" ''Typ (\_ _ _ -> [| error "Lambda impossible" |])
diff --git a/test/Main.hs b/test/Main.hs
deleted file mode 100644
index 5a4d335..0000000
--- a/test/Main.hs
+++ /dev/null
@@ -1,51 +0,0 @@
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE LambdaCase #-}
-{-# LANGUAGE QuantifiedConstraints #-}
-{-# LANGUAGE RankNTypes #-}
-{-# LANGUAGE StandaloneDeriving #-}
-module Main where
-
-import Data.Expr.SharingRecovery
-import Data.Expr.SharingRecovery.Internal
-
-import Arith
-
-
--- TODO: test cyclic expressions
-
-
-a_bin :: (KnownType a, KnownType b, KnownType c)
- => PrimOp (a, b) c
- -> PHOASExpr Typ v ArithF a
- -> PHOASExpr Typ v ArithF b
- -> PHOASExpr Typ v ArithF c
-a_bin op a b = PHOASOp τ (A_Prim op (PHOASOp τ (A_Pair a b)))
-
-lam :: (KnownType a, KnownType b)
- => (PHOASExpr Typ v f a -> PHOASExpr Typ v f b) -> PHOASExpr Typ v f (a -> b)
-lam f = PHOASLam τ τ $ \arg -> f (PHOASVar τ arg)
-
-(+!) :: PHOASExpr Typ v ArithF Int -> PHOASExpr Typ v ArithF Int -> PHOASExpr Typ v ArithF Int
-(+!) = a_bin POAddI
-
-(*!) :: PHOASExpr Typ v ArithF Int -> PHOASExpr Typ v ArithF Int -> PHOASExpr Typ v ArithF Int
-(*!) = a_bin POMulI
-
--- λx. x + x
-ea_1 :: PHOASExpr Typ v ArithF (Int -> Int)
-ea_1 = lam $ \arg -> arg +! arg
-
--- λx. let y = x + x in y * y
-ea_2 :: PHOASExpr Typ v ArithF (Int -> Int)
-ea_2 = lam $ \arg -> let y = arg +! arg
- in y *! y
-
-ea_3 :: PHOASExpr Typ v ArithF (Int -> Int)
-ea_3 = lam $ \arg ->
- let y = arg +! arg
- x = y *! arg
- -- in (y +! x) +! (x +! y)
- in (x +! y) +! (y +! x)
-
-main :: IO ()
-main = putStrLn $ prettyBExpr prettyArithF (sharingRecovery ea_3)
diff --git a/test/NonBaseTH.hs b/test/NonBaseTH.hs
deleted file mode 100644
index 4741ea0..0000000
--- a/test/NonBaseTH.hs
+++ /dev/null
@@ -1,225 +0,0 @@
-{-# LANGUAGE LambdaCase #-}
-{-# LANGUAGE TemplateHaskellQuotes #-}
-{-# LANGUAGE TupleSections #-}
-module NonBaseTH (
- defineBaseAST,
-) where
-
-import Control.Monad (when)
-import Data.List (sort, foldl', tails)
-import qualified Data.Map.Strict as Map
-import Data.Maybe (catMaybes)
-import Language.Haskell.TH
-
-import Data.Expr.SharingRecovery
-
-
--- | Non-base AST data type
-data NBAST = NBAST String [TyVarBndr ()] [NBCon]
-
-data NBCon = NBCon
- [String] -- ^ Names of the constructors defined with this clause
- [TyVarBndrSpec] -- ^ Type variables in scope in the clause
- Cxt -- ^ Constraints on the constructor type
- [(Bang, NBField Type)] -- ^ Constructor fields
- [Name] -- ^ All but the last type parameter in the result type
- Type -- ^ The last type parameter in the result type
-
-data NBField t
- = NBFList (NBField t)
- | NBFTuple [NBField t]
- | NBFRecur t -- ^ The last type parameter of this recursive position. (The
- -- other type parameters are fixed anyway.)
- | NBFConst Type
-
-parseNBAST :: [String] -> Info -> Q NBAST
-parseNBAST excludes info = do
- (astname, params, constrs) <- case info of
- TyConI (DataD [] astname params Nothing constrs _) -> return (astname, params, constrs)
- _ -> fail "Unsupported datatype"
-
- -- In this function we use parameter/index terminology: parameters are
- -- uniform, indices vary inside an AST.
- let parseField retpars field = do
- let (core, args) = splitApps field
- case core of
- ConT n
- | n == astname ->
- if not (null args) && init args == map VarT retpars
- then return (NBFRecur (last args))
- else fail $ "Field\n " ++ pprint field
- ++ "\nis recursive, but with different type parameters than "
- ++ "the return type of this constructor. All but the last type "
- ++ "parameter of the GADT must be uniform over the entire AST."
-
- ListT
- | [arg] <- args -> NBFList <$> parseField retpars arg
-
- TupleT k
- | length args == k -> NBFTuple <$> traverse (parseField retpars) args
-
- _ -> do
- when (pprint astname `infixOf` pprint field) $
- reportWarning $ "Field\n " ++ pprint field ++ "\nseems to refer to "
- ++ pprint astname ++ " in unsupported ways; ignoring those occurrences."
- return (NBFConst field)
-
- let splitConstr (ForallC vars ctx (GadtC names fields retty))
- | names'@(_:_) <- filter (`notElem` excludes) (map nameBase names) =
- return (Just (vars, ctx, names', fields, retty))
- | otherwise = return Nothing
- splitConstr c@GadtC{} = splitConstr (ForallC [] [] c)
- splitConstr c =
- let names = case c of
- NormalC n _ -> Just (show n)
- _ -> Just (show c)
- in fail $ "Unsupported constructors found" ++ maybe "" (\s -> ": " ++ show s) names
-
- let parseConstr (vars, ctx, names, fields, retty) = do
- (retpars, retindex) <- parseRetty astname (head names) retty
- fields' <- traverse (\(ba, t) -> (ba,) <$> parseField retpars t) fields
- return (NBCon names vars ctx fields' retpars retindex)
-
- constrs' <- traverse parseConstr =<< catMaybes <$> traverse splitConstr constrs
- return (NBAST (nameBase astname) params constrs')
-
-
--- | Define a new GADT that is a base-functor-like version of a given existing
--- GADT AST.
---
--- Remember to use 'lookupTypeName' or 'lookupValueName' instead of normal
--- quotes in case of punning of data types and constructors.
-defineBaseAST
- :: String -- ^ Name of the (base-functor-like) data type to define
- -> Name -- ^ Name of the GADT to process
- -> [Name] -- ^ Constructors to exclude (Var and Let, plus any other scoping construct)
- -> (String -> String) -- ^ Constructor renaming function
- -> String -- ^ Name of base -> nonbase conversion function to define
- -> Name -- ^ Type of singleton types
- -> (ExpQ -> ExpQ -> ExpQ -> ExpQ) -- ^ Lambda: typ (a -> b) -> typ a -> AST b -> AST (a -> b)
- -> (ExpQ -> ExpQ -> ExpQ -> ExpQ) -- ^ Let: typ a -> String -> AST a -> AST b -> AST b
- -> (ExpQ -> ExpQ -> ExpQ -> ExpQ) -- ^ Var: typ t -> String -> AST b -> AST b
- -> Q [Dec]
-defineBaseAST basename astname excludes renameConstr bnConvName typName mkLam mkLet mkVar = do
- NBAST _ params constrs <- parseNBAST (map nameBase excludes) =<< reify astname
-
- -- Build the data type
-
- let basename' = mkName basename
- conNameMap = Map.fromList [(nbname, mkName (renameConstr nbname))
- | NBCon ns _ _ _ _ _ <- constrs, nbname <- ns]
-
- let recvar = mkName "r"
-
- let processField (NBFRecur idx) = VarT recvar `AppT` idx
- processField (NBFConst t) = t
- processField (NBFList f) = AppT ListT (processField f)
- processField (NBFTuple fs) = foldl' AppT (TupleT (length fs)) (map processField fs)
-
- let processConstr (NBCon names vars ctx fields retparams retindex) = do
- let names' = map (conNameMap Map.!) names
- let fields' = map (\(ba, f) -> (ba, processField f)) fields
- let retty' = foldl' AppT (ConT basename') (map VarT retparams ++ [VarT recvar, retindex])
- return [ForallC (map cleanupBndr vars ++ [PlainTV recvar SpecifiedSpec])
- ctx (GadtC names' fields' retty')]
-
- let params' = map cleanupBndr (init params ++ [PlainTV recvar (), last params])
- constrs' <- concat <$> traverse processConstr constrs
- let datatype = DataD [] (mkName basename) params' Nothing constrs' []
-
- -- Build the B->N conversion function
-
- let bnConvName' = mkName bnConvName
- envparam = VarT (mkName "env")
- tparam = VarT (mkName "t")
- ftype = foldl' AppT (ConT basename') (map (VarT . bndrName) (init params))
- arrow a b = ArrowT `AppT` a `AppT` b
-
- let bnConvSig =
- SigD bnConvName' $
- (ConT ''BExpr `AppT` ConT typName `AppT` envparam `AppT` ftype `AppT` tparam)
- `arrow`
- foldl' AppT (ConT astname) (map (VarT . bndrName) (init params) ++ [tparam])
-
- let clause' pats ex = Clause pats (NormalB ex) []
- -- backConMap = Map.fromList [(renameConstr nbname, nbname)
- -- | NBCon ns _ _ _ _ _ <- constrs, nbname <- ns]
- reconstructField (NBFRecur _) = do
- r <- newName "r"
- return (VarP r, VarE bnConvName' `AppE` VarE r)
- reconstructField (NBFConst _) = do
- x <- newName "x"
- return (VarP x, VarE x)
- reconstructField (NBFList f) = do
- (pat, ex) <- reconstructField f
- l <- newName "l"
- return (VarP l, VarE 'map `AppE` LamE [pat] ex `AppE` VarE l)
- reconstructField (NBFTuple fs) = do
- (pats, exps) <- unzip <$> traverse reconstructField fs
- return (TupP pats, TupE (map Just exps))
-
- let tyarg1 = mkName "ty1"
- tyarg2 = mkName "ty2"
- astarg = mkName "ast"
- mkseq a b = VarE 'seq `AppE` a `AppE` b
- infixr `mkseq`
-
- bnConvFun <-
- fmap (FunD bnConvName') . sequence $
- [do (pats, exps) <- unzip <$> traverse (reconstructField . snd) fields
- return $
- clause' [ConP 'BOp [] [WildP, ConP (conNameMap Map.! nbname) [] pats]]
- (foldl' AppE (ConE (mkName nbname)) exps)
- | NBCon names _ _ fields _ _ <- constrs
- , nbname <- names]
- ++
- [do body <- mkLam (varE tyarg1) (varE tyarg2) (varE astarg)
- return $ clause' [ConP 'BLam [] [VarP tyarg1, VarP tyarg2, VarP astarg]]
- -- put them in a useless seq so that they're not unused
- (TupE [Just (VarE tyarg1), Just (VarE tyarg2), Just (VarE astarg)] `mkseq` body)]
-
- return [datatype, bnConvSig, bnConvFun]
-
--- | Remove `:: Type` annotations because they unnecessarily require the user
--- to enable KindSignatures. Any other annotations we leave, in case the user
--- wrote them and they are necessary.
-cleanupBndr :: TyVarBndr a -> TyVarBndr a
-cleanupBndr (KindedTV name x k) | isType k = PlainTV name x
- where isType StarT = True
- isType (ConT n) | n == ''Type = True
- isType _ = False
-cleanupBndr b = b
-
-bndrName :: TyVarBndr a -> Name
-bndrName (PlainTV name _) = name
-bndrName (KindedTV name _ _) = name
-
-parseRetty :: Name -> String -> Type -> Q ([Name], Type)
-parseRetty astname consname retty = do
- case splitApps retty of
- (ConT name, args)
- | name /= astname -> fail $ "Could not parse return type of constructor " ++ consname
- | null args -> fail "Expected GADT to have type parameters"
-
- | Just varnames <- traverse (\case VarT varname -> Just varname ; _ -> Nothing) (init args)
- , allDistinct varnames ->
- return (varnames, last args)
-
- | otherwise -> fail $ "All type parameters but the last one must be uniform over all constructors. "
- ++ "(Return type of constructor " ++ consname ++ ")"
- _ -> fail $ "Could not parse return type of constructor " ++ consname
-
-splitApps :: Type -> (Type, [Type])
-splitApps = flip go []
- where go (ParensT t) tl = go t tl
- go (AppT t arg) tl = go t (arg : tl)
- go t tl = (t, tl)
-
-allDistinct :: Ord a => [a] -> Bool
-allDistinct l =
- let sorted = sort l
- in all (uncurry (/=)) (zip sorted (drop 1 sorted))
-
-infixOf :: Eq a => [a] -> [a] -> Bool
-short `infixOf` long = any (`startsWith` short) (tails long)
- where a `startsWith` b = take (length b) a == b