diff options
-rw-r--r-- | test/NonBaseTH.hs | 96 |
1 files changed, 96 insertions, 0 deletions
diff --git a/test/NonBaseTH.hs b/test/NonBaseTH.hs index f3c34f3..c6cf54b 100644 --- a/test/NonBaseTH.hs +++ b/test/NonBaseTH.hs @@ -9,6 +9,87 @@ import Data.Maybe (fromMaybe) import Language.Haskell.TH +-- | Non-base AST data type +data NBAST = NBAST String [TyVarBndr ()] [NBCon] + +data NBCon = NBCon + [String] -- ^ Names of the constructors defined with this clause + [TyVarBndrSpec] -- ^ Type variables in scope in the clause + Cxt -- ^ Constraints on the constructor type + [(Bang, NBField)] -- ^ Constructor fields + [Name] -- ^ All but the last type parameter in the result type + Type -- ^ The last type parameter in the result type + +data NBField + = NBFRecur ([Type] -> Type) + -- ^ Context: takes replacement recursive positions and wraps + -- them in whatever structure they came in (e.g. a tuple). + [Type] + -- ^ The last type parameter of all the recursive positions in + -- this field. (The other type parameters are fixed anyway.) + | NBFConst Type + +wrapNBField :: (Type -> Type) -> NBField -> NBField +wrapNBField f (NBFRecur ctx param) = NBFRecur (f . ctx) param +wrapNBField f (NBFConst t) = NBFConst (f t) + +parseNBAST :: Info -> Q NBAST +parseNBAST info = do + (astname, params, constrs) <- case info of + TyConI (DataD [] astname params Nothing constrs _) -> return (astname, params, constrs) + _ -> fail "Unsupported datatype" + + -- In this function we use parameter/index terminology: parameters are + -- uniform, indices vary inside an AST. + let parseField retpars field = do + let (core, args) = splitApps field + case core of + ConT n + | n == astname -> + if not (null args) && init args == map VarT retpars + then return (NBFRecur head [last args]) + else fail $ "Field\n " ++ pprint field + ++ "\nis recursive, but with different type parameters than " + ++ "the return type of this constructor. All but the last type " + ++ "parameter of the GADT must be uniform over the entire AST." + + ListT + | [arg] <- args -> wrapNBField (ListT `AppT`) <$> parseField retpars arg + + TupleT k + | length args == k -> do + positions <- traverse (parseField retpars) args + case traverse (\case NBFConst t -> Just t ; _ -> Nothing) positions of + Just l -> return (NBFConst (foldl' AppT (TupleT k) l)) + Nothing -> do + let indices = concatMap (\case NBFRecur _ is -> is ; NBFConst _ -> []) positions + let reconstruct [] [] = [] + reconstruct [] _ = error "Invalid number of replacing recursive positions" + reconstruct (NBFRecur ctx pars : poss) repls = + let (pre, post) = splitAt (length pars) repls + in ctx pre : reconstruct poss post + reconstruct (NBFConst _ : poss) repls = reconstruct poss repls + return (NBFRecur (foldl' AppT (TupleT k) . reconstruct positions) indices) + + _ -> do + when (pprint astname `infixOf` pprint field) $ + reportWarning $ "Field\n " ++ pprint field ++ "\nseems to refer to " + ++ pprint astname ++ " in unsupported ways; ignoring those occurrences." + return (NBFConst field) + + let parseConstr con = do + (vars, ctx, names, fields, retty) <- case con of + ForallC vars ctx (GadtC names fields retty) -> return (vars, ctx, names, fields, retty) + GadtC names fields retty -> return ([], [], names, fields, retty) + _ -> fail "Unsupported constructors found" + (retpars, retindex) <- parseRetty astname (head names) retty + fields' <- traverse (\(ba, t) -> (ba,) <$> parseField retpars t) fields + return (NBCon (map nameBase names) vars ctx fields' retpars retindex) + + constrs' <- traverse parseConstr constrs + return (NBAST (nameBase astname) params constrs') + + -- | Define a new GADT that is a base-functor-like version of a given existing -- GADT AST. -- @@ -72,6 +153,21 @@ cleanupBndr (KindedTV name x k) | isType k = PlainTV name x isType _ = False cleanupBndr b = b +parseRetty :: Name -> Name -> Type -> Q ([Name], Type) +parseRetty astname consname retty = do + case splitApps retty of + (ConT name, args) + | name /= astname -> fail $ "Could not parse return type of constructor " ++ pprint consname + | null args -> fail "Expected GADT to have type parameters" + + | Just varnames <- traverse (\case VarT varname -> Just varname ; _ -> Nothing) (init args) + , allDistinct varnames -> + return (init varnames, last args) + + | otherwise -> fail $ "All type parameters but the last one must be uniform over all constructors. " + ++ "(Return type of constructor " ++ pprint consname ++ ")" + _ -> fail $ "Could not parse return type of constructor " ++ pprint consname + fixRetty :: Name -> Name -> Name -> Name -> [TyVarBndr a] -> Type -> Q Type fixRetty basename astname consname recvar vars retty = do case splitApps retty of |