aboutsummaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-09-02 12:05:11 +0200
committerTom Smeding <tom@tomsmeding.com>2024-09-02 12:05:11 +0200
commitb4906aed78519919f556053c4ecc0b646df3cece (patch)
treec1a26161c4408d76bebeecf1d42957d10fdfb13a /test
parentecb5f11260e39be824b1b45c5bcb93bb9e137132 (diff)
More WIP THHEADmaster
Diffstat (limited to 'test')
-rw-r--r--test/Arith/NonBase.hs4
-rw-r--r--test/NonBaseTH.hs231
2 files changed, 129 insertions, 106 deletions
diff --git a/test/Arith/NonBase.hs b/test/Arith/NonBase.hs
index ea27863..f5d458e 100644
--- a/test/Arith/NonBase.hs
+++ b/test/Arith/NonBase.hs
@@ -45,4 +45,6 @@ data Arith t where
A_If :: Arith Bool -> Arith a -> Arith a -> Arith a
A_Mono :: Arith Bool -> Arith Bool
-defineBaseAST "ArithF" ''Arith ['A_Var, 'A_Let] (("AF_"++) . drop 2)
+defineBaseAST
+ "ArithF" ''Arith ['A_Var, 'A_Let] (("AF_"++) . drop 2)
+ "arithConv" ''Typ (\_ _ _ -> [| error "Lambda impossible" |])
diff --git a/test/NonBaseTH.hs b/test/NonBaseTH.hs
index c6cf54b..4741ea0 100644
--- a/test/NonBaseTH.hs
+++ b/test/NonBaseTH.hs
@@ -1,40 +1,39 @@
{-# LANGUAGE LambdaCase #-}
-{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TemplateHaskellQuotes #-}
-module NonBaseTH where
+{-# LANGUAGE TupleSections #-}
+module NonBaseTH (
+ defineBaseAST,
+) where
import Control.Monad (when)
import Data.List (sort, foldl', tails)
-import Data.Maybe (fromMaybe)
+import qualified Data.Map.Strict as Map
+import Data.Maybe (catMaybes)
import Language.Haskell.TH
+import Data.Expr.SharingRecovery
+
-- | Non-base AST data type
data NBAST = NBAST String [TyVarBndr ()] [NBCon]
data NBCon = NBCon
- [String] -- ^ Names of the constructors defined with this clause
- [TyVarBndrSpec] -- ^ Type variables in scope in the clause
- Cxt -- ^ Constraints on the constructor type
- [(Bang, NBField)] -- ^ Constructor fields
- [Name] -- ^ All but the last type parameter in the result type
- Type -- ^ The last type parameter in the result type
-
-data NBField
- = NBFRecur ([Type] -> Type)
- -- ^ Context: takes replacement recursive positions and wraps
- -- them in whatever structure they came in (e.g. a tuple).
- [Type]
- -- ^ The last type parameter of all the recursive positions in
- -- this field. (The other type parameters are fixed anyway.)
+ [String] -- ^ Names of the constructors defined with this clause
+ [TyVarBndrSpec] -- ^ Type variables in scope in the clause
+ Cxt -- ^ Constraints on the constructor type
+ [(Bang, NBField Type)] -- ^ Constructor fields
+ [Name] -- ^ All but the last type parameter in the result type
+ Type -- ^ The last type parameter in the result type
+
+data NBField t
+ = NBFList (NBField t)
+ | NBFTuple [NBField t]
+ | NBFRecur t -- ^ The last type parameter of this recursive position. (The
+ -- other type parameters are fixed anyway.)
| NBFConst Type
-wrapNBField :: (Type -> Type) -> NBField -> NBField
-wrapNBField f (NBFRecur ctx param) = NBFRecur (f . ctx) param
-wrapNBField f (NBFConst t) = NBFConst (f t)
-
-parseNBAST :: Info -> Q NBAST
-parseNBAST info = do
+parseNBAST :: [String] -> Info -> Q NBAST
+parseNBAST excludes info = do
(astname, params, constrs) <- case info of
TyConI (DataD [] astname params Nothing constrs _) -> return (astname, params, constrs)
_ -> fail "Unsupported datatype"
@@ -47,29 +46,17 @@ parseNBAST info = do
ConT n
| n == astname ->
if not (null args) && init args == map VarT retpars
- then return (NBFRecur head [last args])
+ then return (NBFRecur (last args))
else fail $ "Field\n " ++ pprint field
++ "\nis recursive, but with different type parameters than "
++ "the return type of this constructor. All but the last type "
++ "parameter of the GADT must be uniform over the entire AST."
ListT
- | [arg] <- args -> wrapNBField (ListT `AppT`) <$> parseField retpars arg
+ | [arg] <- args -> NBFList <$> parseField retpars arg
TupleT k
- | length args == k -> do
- positions <- traverse (parseField retpars) args
- case traverse (\case NBFConst t -> Just t ; _ -> Nothing) positions of
- Just l -> return (NBFConst (foldl' AppT (TupleT k) l))
- Nothing -> do
- let indices = concatMap (\case NBFRecur _ is -> is ; NBFConst _ -> []) positions
- let reconstruct [] [] = []
- reconstruct [] _ = error "Invalid number of replacing recursive positions"
- reconstruct (NBFRecur ctx pars : poss) repls =
- let (pre, post) = splitAt (length pars) repls
- in ctx pre : reconstruct poss post
- reconstruct (NBFConst _ : poss) repls = reconstruct poss repls
- return (NBFRecur (foldl' AppT (TupleT k) . reconstruct positions) indices)
+ | length args == k -> NBFTuple <$> traverse (parseField retpars) args
_ -> do
when (pprint astname `infixOf` pprint field) $
@@ -77,16 +64,23 @@ parseNBAST info = do
++ pprint astname ++ " in unsupported ways; ignoring those occurrences."
return (NBFConst field)
- let parseConstr con = do
- (vars, ctx, names, fields, retty) <- case con of
- ForallC vars ctx (GadtC names fields retty) -> return (vars, ctx, names, fields, retty)
- GadtC names fields retty -> return ([], [], names, fields, retty)
- _ -> fail "Unsupported constructors found"
+ let splitConstr (ForallC vars ctx (GadtC names fields retty))
+ | names'@(_:_) <- filter (`notElem` excludes) (map nameBase names) =
+ return (Just (vars, ctx, names', fields, retty))
+ | otherwise = return Nothing
+ splitConstr c@GadtC{} = splitConstr (ForallC [] [] c)
+ splitConstr c =
+ let names = case c of
+ NormalC n _ -> Just (show n)
+ _ -> Just (show c)
+ in fail $ "Unsupported constructors found" ++ maybe "" (\s -> ": " ++ show s) names
+
+ let parseConstr (vars, ctx, names, fields, retty) = do
(retpars, retindex) <- parseRetty astname (head names) retty
fields' <- traverse (\(ba, t) -> (ba,) <$> parseField retpars t) fields
- return (NBCon (map nameBase names) vars ctx fields' retpars retindex)
+ return (NBCon names vars ctx fields' retpars retindex)
- constrs' <- traverse parseConstr constrs
+ constrs' <- traverse parseConstr =<< catMaybes <$> traverse splitConstr constrs
return (NBAST (nameBase astname) params constrs')
@@ -100,48 +94,91 @@ defineBaseAST
-> Name -- ^ Name of the GADT to process
-> [Name] -- ^ Constructors to exclude (Var and Let, plus any other scoping construct)
-> (String -> String) -- ^ Constructor renaming function
+ -> String -- ^ Name of base -> nonbase conversion function to define
+ -> Name -- ^ Type of singleton types
+ -> (ExpQ -> ExpQ -> ExpQ -> ExpQ) -- ^ Lambda: typ (a -> b) -> typ a -> AST b -> AST (a -> b)
+ -> (ExpQ -> ExpQ -> ExpQ -> ExpQ) -- ^ Let: typ a -> String -> AST a -> AST b -> AST b
+ -> (ExpQ -> ExpQ -> ExpQ -> ExpQ) -- ^ Var: typ t -> String -> AST b -> AST b
-> Q [Dec]
-defineBaseAST basename astname excludes renameConstr = do
- info <- reify astname
- (params, constrs) <- case info of
- TyConI (DataD [] _ params Nothing constrs _) -> return (params, constrs)
- _ -> fail $ "Unsupported datatype: " ++ pprint astname
+defineBaseAST basename astname excludes renameConstr bnConvName typName mkLam mkLet mkVar = do
+ NBAST _ params constrs <- parseNBAST (map nameBase excludes) =<< reify astname
+
+ -- Build the data type
let basename' = mkName basename
- recvar = mkName "r"
+ conNameMap = Map.fromList [(nbname, mkName (renameConstr nbname))
+ | NBCon ns _ _ _ _ _ <- constrs, nbname <- ns]
- let detectRec :: Type -> Q (Maybe Type)
- detectRec field = do
- let (core, args) = splitApps field
- case core of
- ConT n
- | n == astname -> return (Just (VarT recvar `AppT` last args))
- ListT
- | [arg] <- args -> fmap (ListT `AppT`) <$> detectRec arg
- TupleT k
- | length args == k -> fmap (foldl' AppT (TupleT k)) . sequence <$> traverse detectRec args
- _ -> do
- when (pprint astname `infixOf` pprint field) $
- reportWarning $ "Field\n " ++ pprint field ++ "\nseems to refer to "
- ++ pprint astname ++ " in unsupported ways; ignoring those occurrences."
- return Nothing
-
- let processConstr con = do
- (vars, ctx, names, fields, retty) <- case con of
- ForallC vars ctx (GadtC names fields retty) -> return (vars, ctx, names, fields, retty)
- GadtC names fields retty -> return ([], [], names, fields, retty)
- _ -> fail "Unsupported constructors found"
- case filter (`notElem` excludes) names of
- [] -> return []
- names' -> do
- let names'' = map (mkName . renameConstr . nameBase) names'
- retty' <- fixRetty basename' astname (head names) recvar vars retty
- fields' <- traverse (\(ba, t) -> (ba,) . fromMaybe t <$> detectRec t) fields
- return [ForallC (map cleanupBndr vars ++ [PlainTV recvar SpecifiedSpec]) ctx (GadtC names'' fields' retty')]
+ let recvar = mkName "r"
+
+ let processField (NBFRecur idx) = VarT recvar `AppT` idx
+ processField (NBFConst t) = t
+ processField (NBFList f) = AppT ListT (processField f)
+ processField (NBFTuple fs) = foldl' AppT (TupleT (length fs)) (map processField fs)
+
+ let processConstr (NBCon names vars ctx fields retparams retindex) = do
+ let names' = map (conNameMap Map.!) names
+ let fields' = map (\(ba, f) -> (ba, processField f)) fields
+ let retty' = foldl' AppT (ConT basename') (map VarT retparams ++ [VarT recvar, retindex])
+ return [ForallC (map cleanupBndr vars ++ [PlainTV recvar SpecifiedSpec])
+ ctx (GadtC names' fields' retty')]
let params' = map cleanupBndr (init params ++ [PlainTV recvar (), last params])
constrs' <- concat <$> traverse processConstr constrs
- return [DataD [] (mkName basename) params' Nothing constrs' []]
+ let datatype = DataD [] (mkName basename) params' Nothing constrs' []
+
+ -- Build the B->N conversion function
+
+ let bnConvName' = mkName bnConvName
+ envparam = VarT (mkName "env")
+ tparam = VarT (mkName "t")
+ ftype = foldl' AppT (ConT basename') (map (VarT . bndrName) (init params))
+ arrow a b = ArrowT `AppT` a `AppT` b
+
+ let bnConvSig =
+ SigD bnConvName' $
+ (ConT ''BExpr `AppT` ConT typName `AppT` envparam `AppT` ftype `AppT` tparam)
+ `arrow`
+ foldl' AppT (ConT astname) (map (VarT . bndrName) (init params) ++ [tparam])
+
+ let clause' pats ex = Clause pats (NormalB ex) []
+ -- backConMap = Map.fromList [(renameConstr nbname, nbname)
+ -- | NBCon ns _ _ _ _ _ <- constrs, nbname <- ns]
+ reconstructField (NBFRecur _) = do
+ r <- newName "r"
+ return (VarP r, VarE bnConvName' `AppE` VarE r)
+ reconstructField (NBFConst _) = do
+ x <- newName "x"
+ return (VarP x, VarE x)
+ reconstructField (NBFList f) = do
+ (pat, ex) <- reconstructField f
+ l <- newName "l"
+ return (VarP l, VarE 'map `AppE` LamE [pat] ex `AppE` VarE l)
+ reconstructField (NBFTuple fs) = do
+ (pats, exps) <- unzip <$> traverse reconstructField fs
+ return (TupP pats, TupE (map Just exps))
+
+ let tyarg1 = mkName "ty1"
+ tyarg2 = mkName "ty2"
+ astarg = mkName "ast"
+ mkseq a b = VarE 'seq `AppE` a `AppE` b
+ infixr `mkseq`
+
+ bnConvFun <-
+ fmap (FunD bnConvName') . sequence $
+ [do (pats, exps) <- unzip <$> traverse (reconstructField . snd) fields
+ return $
+ clause' [ConP 'BOp [] [WildP, ConP (conNameMap Map.! nbname) [] pats]]
+ (foldl' AppE (ConE (mkName nbname)) exps)
+ | NBCon names _ _ fields _ _ <- constrs
+ , nbname <- names]
+ ++
+ [do body <- mkLam (varE tyarg1) (varE tyarg2) (varE astarg)
+ return $ clause' [ConP 'BLam [] [VarP tyarg1, VarP tyarg2, VarP astarg]]
+ -- put them in a useless seq so that they're not unused
+ (TupE [Just (VarE tyarg1), Just (VarE tyarg2), Just (VarE astarg)] `mkseq` body)]
+
+ return [datatype, bnConvSig, bnConvFun]
-- | Remove `:: Type` annotations because they unnecessarily require the user
-- to enable KindSignatures. Any other annotations we leave, in case the user
@@ -153,36 +190,24 @@ cleanupBndr (KindedTV name x k) | isType k = PlainTV name x
isType _ = False
cleanupBndr b = b
-parseRetty :: Name -> Name -> Type -> Q ([Name], Type)
+bndrName :: TyVarBndr a -> Name
+bndrName (PlainTV name _) = name
+bndrName (KindedTV name _ _) = name
+
+parseRetty :: Name -> String -> Type -> Q ([Name], Type)
parseRetty astname consname retty = do
case splitApps retty of
(ConT name, args)
- | name /= astname -> fail $ "Could not parse return type of constructor " ++ pprint consname
+ | name /= astname -> fail $ "Could not parse return type of constructor " ++ consname
| null args -> fail "Expected GADT to have type parameters"
| Just varnames <- traverse (\case VarT varname -> Just varname ; _ -> Nothing) (init args)
, allDistinct varnames ->
- return (init varnames, last args)
-
- | otherwise -> fail $ "All type parameters but the last one must be uniform over all constructors. "
- ++ "(Return type of constructor " ++ pprint consname ++ ")"
- _ -> fail $ "Could not parse return type of constructor " ++ pprint consname
-
-fixRetty :: Name -> Name -> Name -> Name -> [TyVarBndr a] -> Type -> Q Type
-fixRetty basename astname consname recvar vars retty = do
- case splitApps retty of
- (ConT name, args)
- | name /= astname -> fail $ "Could not parse return type of constructor " ++ pprint consname
- | null args -> fail "Expected GADT to have type parameters"
-
- | Just varnames <- traverse (\case VarT varname -> Just varname ; _ -> Nothing) (init args)
- , allDistinct varnames
- , all (`elem` map bndrName vars) varnames ->
- return (foldl' AppT (ConT basename) (init args ++ [VarT recvar, last args]))
+ return (varnames, last args)
| otherwise -> fail $ "All type parameters but the last one must be uniform over all constructors. "
- ++ "(Return type of constructor " ++ pprint consname ++ ")"
- _ -> fail $ "Could not parse return type of constructor " ++ pprint consname
+ ++ "(Return type of constructor " ++ consname ++ ")"
+ _ -> fail $ "Could not parse return type of constructor " ++ consname
splitApps :: Type -> (Type, [Type])
splitApps = flip go []
@@ -195,10 +220,6 @@ allDistinct l =
let sorted = sort l
in all (uncurry (/=)) (zip sorted (drop 1 sorted))
-bndrName :: TyVarBndr a -> Name
-bndrName (PlainTV n _) = n
-bndrName (KindedTV n _ _) = n
-
infixOf :: Eq a => [a] -> [a] -> Bool
short `infixOf` long = any (`startsWith` short) (tails long)
where a `startsWith` b = take (length b) a == b