aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--test/NonBaseTH.hs96
1 files changed, 96 insertions, 0 deletions
diff --git a/test/NonBaseTH.hs b/test/NonBaseTH.hs
index f3c34f3..c6cf54b 100644
--- a/test/NonBaseTH.hs
+++ b/test/NonBaseTH.hs
@@ -9,6 +9,87 @@ import Data.Maybe (fromMaybe)
import Language.Haskell.TH
+-- | Non-base AST data type
+data NBAST = NBAST String [TyVarBndr ()] [NBCon]
+
+data NBCon = NBCon
+ [String] -- ^ Names of the constructors defined with this clause
+ [TyVarBndrSpec] -- ^ Type variables in scope in the clause
+ Cxt -- ^ Constraints on the constructor type
+ [(Bang, NBField)] -- ^ Constructor fields
+ [Name] -- ^ All but the last type parameter in the result type
+ Type -- ^ The last type parameter in the result type
+
+data NBField
+ = NBFRecur ([Type] -> Type)
+ -- ^ Context: takes replacement recursive positions and wraps
+ -- them in whatever structure they came in (e.g. a tuple).
+ [Type]
+ -- ^ The last type parameter of all the recursive positions in
+ -- this field. (The other type parameters are fixed anyway.)
+ | NBFConst Type
+
+wrapNBField :: (Type -> Type) -> NBField -> NBField
+wrapNBField f (NBFRecur ctx param) = NBFRecur (f . ctx) param
+wrapNBField f (NBFConst t) = NBFConst (f t)
+
+parseNBAST :: Info -> Q NBAST
+parseNBAST info = do
+ (astname, params, constrs) <- case info of
+ TyConI (DataD [] astname params Nothing constrs _) -> return (astname, params, constrs)
+ _ -> fail "Unsupported datatype"
+
+ -- In this function we use parameter/index terminology: parameters are
+ -- uniform, indices vary inside an AST.
+ let parseField retpars field = do
+ let (core, args) = splitApps field
+ case core of
+ ConT n
+ | n == astname ->
+ if not (null args) && init args == map VarT retpars
+ then return (NBFRecur head [last args])
+ else fail $ "Field\n " ++ pprint field
+ ++ "\nis recursive, but with different type parameters than "
+ ++ "the return type of this constructor. All but the last type "
+ ++ "parameter of the GADT must be uniform over the entire AST."
+
+ ListT
+ | [arg] <- args -> wrapNBField (ListT `AppT`) <$> parseField retpars arg
+
+ TupleT k
+ | length args == k -> do
+ positions <- traverse (parseField retpars) args
+ case traverse (\case NBFConst t -> Just t ; _ -> Nothing) positions of
+ Just l -> return (NBFConst (foldl' AppT (TupleT k) l))
+ Nothing -> do
+ let indices = concatMap (\case NBFRecur _ is -> is ; NBFConst _ -> []) positions
+ let reconstruct [] [] = []
+ reconstruct [] _ = error "Invalid number of replacing recursive positions"
+ reconstruct (NBFRecur ctx pars : poss) repls =
+ let (pre, post) = splitAt (length pars) repls
+ in ctx pre : reconstruct poss post
+ reconstruct (NBFConst _ : poss) repls = reconstruct poss repls
+ return (NBFRecur (foldl' AppT (TupleT k) . reconstruct positions) indices)
+
+ _ -> do
+ when (pprint astname `infixOf` pprint field) $
+ reportWarning $ "Field\n " ++ pprint field ++ "\nseems to refer to "
+ ++ pprint astname ++ " in unsupported ways; ignoring those occurrences."
+ return (NBFConst field)
+
+ let parseConstr con = do
+ (vars, ctx, names, fields, retty) <- case con of
+ ForallC vars ctx (GadtC names fields retty) -> return (vars, ctx, names, fields, retty)
+ GadtC names fields retty -> return ([], [], names, fields, retty)
+ _ -> fail "Unsupported constructors found"
+ (retpars, retindex) <- parseRetty astname (head names) retty
+ fields' <- traverse (\(ba, t) -> (ba,) <$> parseField retpars t) fields
+ return (NBCon (map nameBase names) vars ctx fields' retpars retindex)
+
+ constrs' <- traverse parseConstr constrs
+ return (NBAST (nameBase astname) params constrs')
+
+
-- | Define a new GADT that is a base-functor-like version of a given existing
-- GADT AST.
--
@@ -72,6 +153,21 @@ cleanupBndr (KindedTV name x k) | isType k = PlainTV name x
isType _ = False
cleanupBndr b = b
+parseRetty :: Name -> Name -> Type -> Q ([Name], Type)
+parseRetty astname consname retty = do
+ case splitApps retty of
+ (ConT name, args)
+ | name /= astname -> fail $ "Could not parse return type of constructor " ++ pprint consname
+ | null args -> fail "Expected GADT to have type parameters"
+
+ | Just varnames <- traverse (\case VarT varname -> Just varname ; _ -> Nothing) (init args)
+ , allDistinct varnames ->
+ return (init varnames, last args)
+
+ | otherwise -> fail $ "All type parameters but the last one must be uniform over all constructors. "
+ ++ "(Return type of constructor " ++ pprint consname ++ ")"
+ _ -> fail $ "Could not parse return type of constructor " ++ pprint consname
+
fixRetty :: Name -> Name -> Name -> Name -> [TyVarBndr a] -> Type -> Q Type
fixRetty basename astname consname recvar vars retty = do
case splitApps retty of