{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TemplateHaskellQuotes #-}
module NonBaseTH where

import Control.Monad (when)
import Data.List (sort, foldl', tails)
import Data.Maybe (fromMaybe)
import Language.Haskell.TH


-- | Non-base AST data type: the base name of the data type, its type
-- variable binders, and its parsed constructor clauses.
data NBAST = NBAST String [TyVarBndr ()] [NBCon]

-- | One parsed GADT constructor clause.
data NBCon = NBCon
  [String]           -- ^ Names of the constructors defined with this clause
  [TyVarBndrSpec]    -- ^ Type variables in scope in the clause
  Cxt                -- ^ Constraints on the constructor type
  [(Bang, NBField)]  -- ^ Constructor fields
  [Name]             -- ^ All but the last type parameter in the result type
  Type               -- ^ The last type parameter in the result type

-- | A constructor field, classified by whether it contains recursive
-- positions (occurrences of the AST type itself).
data NBField
  = NBFRecur
      ([Type] -> Type)  -- ^ Context: takes replacement recursive positions and wraps
                        -- them in whatever structure they came in (e.g. a tuple).
      [Type]            -- ^ The last type parameter of all the recursive positions in
                        -- this field. (The other type parameters are fixed anyway.)
  | NBFConst Type

-- | Wrap the type described by a field in an outer type constructor,
-- preserving whether the field is recursive or constant.
wrapNBField :: (Type -> Type) -> NBField -> NBField
wrapNBField f (NBFRecur ctx param) = NBFRecur (f . ctx) param
wrapNBField f (NBFConst t) = NBFConst (f t)

-- | Parse the reified 'Info' of a GADT-syntax data type into an 'NBAST'.
--
-- Only plain @data@ declarations without a datatype context or kind
-- signature are accepted, and every constructor must be a (possibly
-- quantified) 'GadtC'; anything else makes the computation 'fail'.
parseNBAST :: Info -> Q NBAST
parseNBAST info = do
  (astname, params, constrs) <- case info of
    TyConI (DataD [] astname params Nothing constrs _) -> return (astname, params, constrs)
    _ -> fail "Unsupported datatype"

  -- In this function we use parameter/index terminology: parameters are
  -- uniform, indices vary inside an AST.

  let parseField retpars field = do
        let (core, args) = splitApps field
        case core of
          -- A direct recursive occurrence: all parameters except the index
          -- must coincide with those of the constructor's return type.
          ConT n
            | n == astname ->
                if not (null args) && init args == map VarT retpars
                  then return (NBFRecur head [last args])
                  else fail $ "Field\n " ++ pprint field ++ "\nis recursive, but with different type parameters than " ++
                              "the return type of this constructor. All but the last type " ++
                              "parameter of the GADT must be uniform over the entire AST."

          -- Lists: recurse into the element type and re-wrap the result.
          ListT
            | [arg] <- args ->
                wrapNBField (ListT `AppT`) <$> parseField retpars arg

          -- Fully applied tuples: recurse into every component.
          TupleT k
            | length args == k -> do
                positions <- traverse (parseField retpars) args
                case traverse (\case NBFConst t -> Just t ; _ -> Nothing) positions of
                  -- No component is recursive, so the tuple as a whole is constant.
                  Just l -> return (NBFConst (foldl' AppT (TupleT k) l))
                  Nothing -> do
                    let indices = concatMap (\case NBFRecur _ is -> is ; NBFConst _ -> []) positions
                    -- Rebuild all k components: recursive ones consume as many
                    -- replacements as they have indices, constant ones are
                    -- re-emitted unchanged so the tuple keeps its arity.
                    let reconstruct [] [] = []
                        reconstruct [] _ = error "Invalid number of replacing recursive positions"
                        reconstruct (NBFRecur ctx pars : poss) repls =
                          let (pre, post) = splitAt (length pars) repls
                          in ctx pre : reconstruct poss post
                        reconstruct (NBFConst t : poss) repls = t : reconstruct poss repls
                    return (NBFRecur (foldl' AppT (TupleT k) . reconstruct positions) indices)

          -- Anything else is treated as constant; warn if the AST name still
          -- shows up textually in the pretty-printed field.
          _ -> do
            when (pprint astname `infixOf` pprint field) $
              reportWarning $ "Field\n " ++ pprint field ++ "\nseems to refer to "
                              ++ pprint astname ++ " in unsupported ways; ignoring those occurrences."
            return (NBFConst field)

  let parseConstr con = do
        (vars, ctx, names, fields, retty) <- case con of
          ForallC vars ctx (GadtC names fields retty) -> return (vars, ctx, names, fields, retty)
          GadtC names fields retty -> return ([], [], names, fields, retty)
          _ -> fail "Unsupported constructors found"
        -- The first constructor name of the clause is only used in error messages.
        (retpars, retindex) <- parseRetty astname (head names) retty
        fields' <- traverse (\(ba, t) -> (ba,) <$> parseField retpars t) fields
        return (NBCon (map nameBase names) vars ctx fields' retpars retindex)

  constrs' <- traverse parseConstr constrs
  return (NBAST (nameBase astname) params constrs')

-- | Define a new GADT that is a base-functor-like version of a given existing
-- GADT AST.
--
-- Remember to use 'lookupTypeName' or 'lookupValueName' instead of normal
-- quotes in case of punning of data types and constructors.
defineBaseAST
  :: String  -- ^ Name of the (base-functor-like) data type to define
  -> Name  -- ^ Name of the GADT to process
  -> [Name]  -- ^ Constructors to exclude (Var and Let, plus any other scoping construct)
  -> (String -> String)  -- ^ Constructor renaming function
  -> Q [Dec]
defineBaseAST basename astname excludes renameConstr = do
  info <- reify astname
  (params, constrs) <- case info of
    TyConI (DataD [] _ params Nothing constrs _) -> return (params, constrs)
    _ -> fail $ "Unsupported datatype: " ++ pprint astname

  -- The recursion variable is inserted just before the last parameter, so
  -- there has to be at least one parameter.
  when (null params) $ fail "Expected GADT to have type parameters"

  let basename' = mkName basename
      recvar = mkName "r"

  -- Returns @Just field'@ if the field contains recursive positions (with
  -- each of them replaced by the recursion variable applied to its last
  -- type argument), and @Nothing@ if nothing needs to change.
  let detectRec :: Type -> Q (Maybe Type)
      detectRec field = do
        let (core, args) = splitApps field
        case core of
          ConT n
            | n == astname ->
                return (Just (VarT recvar `AppT` last args))
          ListT
            | [arg] <- args ->
                fmap (ListT `AppT`) <$> detectRec arg
          TupleT k
            | length args == k -> do
                results <- traverse detectRec args
                -- A tuple changes as soon as one component changes;
                -- components without recursive positions are kept as-is.
                let anyRecursive = any (\case Just _ -> True ; Nothing -> False) results
                return $
                  if anyRecursive
                    then Just (foldl' AppT (TupleT k) (zipWith fromMaybe args results))
                    else Nothing
          _ -> do
            when (pprint astname `infixOf` pprint field) $
              reportWarning $ "Field\n " ++ pprint field ++ "\nseems to refer to "
                              ++ pprint astname ++ " in unsupported ways; ignoring those occurrences."
            return Nothing

  -- Produces zero or one constructor clause: clauses whose names are all
  -- excluded disappear entirely.
  let processConstr con = do
        (vars, ctx, names, fields, retty) <- case con of
          ForallC vars ctx (GadtC names fields retty) -> return (vars, ctx, names, fields, retty)
          GadtC names fields retty -> return ([], [], names, fields, retty)
          _ -> fail "Unsupported constructors found"
        case filter (`notElem` excludes) names of
          [] -> return []
          names' -> do
            let names'' = map (mkName . renameConstr . nameBase) names'
            -- The first (original) constructor name is only used in error messages.
            retty' <- fixRetty basename' astname (head names) recvar vars retty
            fields' <- traverse (\(ba, t) -> (ba,) . fromMaybe t <$> detectRec t) fields
            return [ForallC (map cleanupBndr vars ++ [PlainTV recvar SpecifiedSpec]) ctx (GadtC names'' fields' retty')]

  let params' = map cleanupBndr (init params ++ [PlainTV recvar (), last params])
  constrs' <- concat <$> traverse processConstr constrs
  return [DataD [] basename' params' Nothing constrs' []]

-- | Remove `:: Type` annotations because they unnecessarily require the user
-- to enable KindSignatures. Any other annotations we leave, in case the user
-- wrote them and they are necessary.
--
-- NOTE(review): the only @Type@ brought into scope by the imports visible in
-- this module is the Template Haskell AST type from "Language.Haskell.TH", so
-- @''Type@ presumably does not name @Data.Kind.Type@ and the 'ConT' case may
-- never fire -- confirm, and import "Data.Kind" qualified if so.
cleanupBndr :: TyVarBndr a -> TyVarBndr a
cleanupBndr (KindedTV name x k) | isType k = PlainTV name x
  where isType StarT = True
        isType (ConT n) | n == ''Type = True
        isType _ = False
cleanupBndr b = b

-- | Parse the return type of a GADT constructor.
--
-- Given the name of the AST type, the name of the constructor (used only in
-- error messages) and the constructor's return type, returns the names of
-- all type parameters except the last one, together with the last type
-- argument. Fails if the return type is not an application of the AST type,
-- if it has no arguments, or if the leading arguments are not pairwise
-- distinct type variables.
parseRetty :: Name -> Name -> Type -> Q ([Name], Type)
parseRetty astname consname retty = do
  case splitApps retty of
    (ConT name, args)
      | name /= astname -> fail $ "Could not parse return type of constructor " ++ pprint consname
      | null args -> fail "Expected GADT to have type parameters"
      | Just varnames <- traverse (\case VarT varname -> Just varname ; _ -> Nothing) (init args)
      , allDistinct varnames ->
          -- 'varnames' already excludes the last argument.
          return (varnames, last args)
      | otherwise -> fail $ "All type parameters but the last one must be uniform over all constructors. " ++
                            "(Return type of constructor " ++ pprint consname ++ ")"
    _ -> fail $ "Could not parse return type of constructor " ++ pprint consname

-- | Rewrite the return type of a GADT constructor for the base-functor-like
-- data type: the head is replaced by @basename@ and the recursion variable is
-- inserted just before the last type argument.
--
-- Performs the same validation as 'parseRetty', and additionally requires
-- all leading type variables to be bound by the given binders.
fixRetty :: Name -> Name -> Name -> Name -> [TyVarBndr a] -> Type -> Q Type
fixRetty basename astname consname recvar vars retty = do
  case splitApps retty of
    (ConT name, args)
      | name /= astname -> fail $ "Could not parse return type of constructor " ++ pprint consname
      | null args -> fail "Expected GADT to have type parameters"
      | Just varnames <- traverse (\case VarT varname -> Just varname ; _ -> Nothing) (init args)
      , allDistinct varnames
      , all (`elem` map bndrName vars) varnames ->
          return (foldl' AppT (ConT basename) (init args ++ [VarT recvar, last args]))
      | otherwise -> fail $ "All type parameters but the last one must be uniform over all constructors. " ++
                            "(Return type of constructor " ++ pprint consname ++ ")"
    _ -> fail $ "Could not parse return type of constructor " ++ pprint consname

-- | Split a (possibly parenthesised) type application into its head and the
-- list of arguments, in left-to-right order.
splitApps :: Type -> (Type, [Type])
splitApps = flip go []
  where go (ParensT t) tl = go t tl
        go (AppT t arg) tl = go t (arg : tl)
        go t tl = (t, tl)

-- | Whether all elements of the list are pairwise different.
allDistinct :: Ord a => [a] -> Bool
allDistinct l =
  -- After sorting, equal elements are adjacent.
  let sorted = sort l
  in all (uncurry (/=)) (zip sorted (drop 1 sorted))

-- | The name bound by a type variable binder.
bndrName :: TyVarBndr a -> Name
bndrName (PlainTV n _) = n
bndrName (KindedTV n _ _) = n

-- | Whether the first list occurs as a contiguous sublist of the second.
infixOf :: Eq a => [a] -> [a] -> Bool
short `infixOf` long = any (`startsWith` short) (tails long)
  where a `startsWith` b = take (length b) a == b