{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TemplateHaskellQuotes #-}
{-# LANGUAGE TupleSections #-}
module NonBaseTH (
  defineBaseAST,
) where

import Control.Monad (when)
import Data.List (sort, foldl', tails)
import qualified Data.Map.Strict as Map
import Data.Maybe (catMaybes)
import Language.Haskell.TH

import Data.Expr.SharingRecovery


-- | Non-base AST data type
data NBAST = NBAST String [TyVarBndr ()] [NBCon]

data NBCon = NBCon
  [String]                -- ^ Names of the constructors defined with this clause
  [TyVarBndrSpec]         -- ^ Type variables in scope in the clause
  Cxt                     -- ^ Constraints on the constructor type
  [(Bang, NBField Type)]  -- ^ Constructor fields
  [Name]                  -- ^ All but the last type parameter in the result type
  Type                    -- ^ The last type parameter in the result type

data NBField t
  = NBFList (NBField t)
  | NBFTuple [NBField t]
  | NBFRecur t
      -- ^ The last type parameter of this recursive position. (The
      -- other type parameters are fixed anyway.)
  | NBFConst Type

parseNBAST :: [String] -> Info -> Q NBAST
parseNBAST excludes info = do
  (astname, params, constrs) <- case info of
    TyConI (DataD [] astname params Nothing constrs _) -> return (astname, params, constrs)
    _ -> fail "Unsupported datatype"

  -- In this function we use parameter/index terminology: parameters are
  -- uniform, indices vary inside an AST.

  let parseField retpars field = do
        let (core, args) = splitApps field
        case core of
          ConT n
            | n == astname ->
                if not (null args) && init args == map VarT retpars
                  then return (NBFRecur (last args))
                  else fail $ "Field\n " ++ pprint field ++ "\nis recursive, but with different type parameters than "
                              ++ "the return type of this constructor. All but the last type "
                              ++ "parameter of the GADT must be uniform over the entire AST."
          ListT
            | [arg] <- args
            -> NBFList <$> parseField retpars arg
          TupleT k
            | length args == k
            -> NBFTuple <$> traverse (parseField retpars) args
          _ -> do
            when (pprint astname `infixOf` pprint field) $
              reportWarning $ "Field\n " ++ pprint field ++ "\nseems to refer to " ++ pprint astname
                              ++ " in unsupported ways; ignoring those occurrences."
            return (NBFConst field)

  let splitConstr (ForallC vars ctx (GadtC names fields retty))
        | names'@(_:_) <- filter (`notElem` excludes) (map nameBase names)
        = return (Just (vars, ctx, names', fields, retty))
        | otherwise
        = return Nothing
      splitConstr c@GadtC{} = splitConstr (ForallC [] [] c)
      splitConstr c =
        let names = case c of
              NormalC n _ -> Just (show n)
              _ -> Just (show c)
        in fail $ "Unsupported constructors found" ++ maybe "" (\s -> ": " ++ show s) names

  let parseConstr (vars, ctx, names, fields, retty) = do
        (retpars, retindex) <- parseRetty astname (head names) retty
        fields' <- traverse (\(ba, t) -> (ba,) <$> parseField retpars t) fields
        return (NBCon names vars ctx fields' retpars retindex)

  constrs' <- traverse parseConstr =<< catMaybes <$> traverse splitConstr constrs

  return (NBAST (nameBase astname) params constrs')

-- | Define a new GADT that is a base-functor-like version of a given existing
-- GADT AST.
--
-- Remember to use 'lookupTypeName' or 'lookupValueName' instead of normal
-- quotes in case of punning of data types and constructors.
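--
-- For example (an illustrative sketch, not taken from a real client of this
-- module: @Expr@, @Typ@, @mkLam@, @mkLet@ and @mkVar@ are hypothetical user
-- definitions, and the @mk*@ functions must live in another module because of
-- the TH stage restriction), given
--
-- > data Expr t where
-- >   Var   :: Typ t -> String -> Expr t
-- >   Lam   :: Typ a -> String -> Expr b -> Expr (a -> b)
-- >   Const :: Int -> Expr Int
-- >   Add   :: Expr Int -> Expr Int -> Expr Int
-- >   Pair  :: Expr a -> Expr b -> Expr (a, b)
--
-- the splice
--
-- > defineBaseAST "ExprF" ''Expr ['Var, 'Lam] (++ "F") "fromBExpr" ''Typ mkLam mkLet mkVar
--
-- should generate roughly the following declarations (up to the names of the
-- type variables returned by 'reify' and those created with 'newName'):
--
-- > data ExprF r t where
-- >   ConstF :: forall r. Int -> ExprF r Int
-- >   AddF   :: forall r. r Int -> r Int -> ExprF r Int
-- >   PairF  :: forall a b r. r a -> r b -> ExprF r (a, b)
-- >
-- > fromBExpr :: BExpr Typ env ExprF t -> Expr t
-- > fromBExpr (BOp _ (ConstF x)) = Const x
-- > fromBExpr (BOp _ (AddF r1 r2)) = Add (fromBExpr r1) (fromBExpr r2)
-- > fromBExpr (BOp _ (PairF r1 r2)) = Pair (fromBExpr r1) (fromBExpr r2)
-- > fromBExpr (BLam ty1 ty2 ast) = (ty1, ty2, ast) `seq` ...  -- body built by mkLam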
defineBaseAST
  :: String                          -- ^ Name of the (base-functor-like) data type to define
  -> Name                            -- ^ Name of the GADT to process
  -> [Name]                          -- ^ Constructors to exclude (Var and Let, plus any other scoping construct)
  -> (String -> String)              -- ^ Constructor renaming function
  -> String                          -- ^ Name of base -> nonbase conversion function to define
  -> Name                            -- ^ Type of singleton types
  -> (ExpQ -> ExpQ -> ExpQ -> ExpQ)  -- ^ Lambda: typ (a -> b) -> typ a -> AST b -> AST (a -> b)
  -> (ExpQ -> ExpQ -> ExpQ -> ExpQ)  -- ^ Let: typ a -> String -> AST a -> AST b -> AST b
  -> (ExpQ -> ExpQ -> ExpQ -> ExpQ)  -- ^ Var: typ t -> String -> AST b -> AST b
  -> Q [Dec]
defineBaseAST basename astname excludes renameConstr bnConvName typName mkLam mkLet mkVar = do
  NBAST _ params constrs <- parseNBAST (map nameBase excludes) =<< reify astname

  -- Build the data type

  let basename' = mkName basename
      conNameMap = Map.fromList [(nbname, mkName (renameConstr nbname))
                                | NBCon ns _ _ _ _ _ <- constrs, nbname <- ns]

  let recvar = mkName "r"

  let processField (NBFRecur idx) = VarT recvar `AppT` idx
      processField (NBFConst t) = t
      processField (NBFList f) = AppT ListT (processField f)
      processField (NBFTuple fs) = foldl' AppT (TupleT (length fs)) (map processField fs)

  let processConstr (NBCon names vars ctx fields retparams retindex) = do
        let names' = map (conNameMap Map.!) names
        let fields' = map (\(ba, f) -> (ba, processField f)) fields
        let retty' = foldl' AppT (ConT basename') (map VarT retparams ++ [VarT recvar, retindex])
        return [ForallC (map cleanupBndr vars ++ [PlainTV recvar SpecifiedSpec]) ctx
                        (GadtC names' fields' retty')]

  let params' = map cleanupBndr (init params ++ [PlainTV recvar (), last params])
  constrs' <- concat <$> traverse processConstr constrs
  let datatype = DataD [] basename' params' Nothing constrs' []

  -- Build the B->N conversion function

  let bnConvName' = mkName bnConvName
      envparam = VarT (mkName "env")
      tparam = VarT (mkName "t")
      ftype = foldl' AppT (ConT basename') (map (VarT . bndrName) (init params))
      arrow a b = ArrowT `AppT` a `AppT` b

  let bnConvSig = SigD bnConvName' $
        (ConT ''BExpr `AppT` ConT typName `AppT` envparam `AppT` ftype `AppT` tparam)
          `arrow` foldl' AppT (ConT astname) (map (VarT . bndrName) (init params) ++ [tparam])

  let clause' pats ex = Clause pats (NormalB ex) []
      -- backConMap = Map.fromList [(renameConstr nbname, nbname)
      --                           | NBCon ns _ _ _ _ _ <- constrs, nbname <- ns]
      reconstructField (NBFRecur _) = do
        r <- newName "r"
        return (VarP r, VarE bnConvName' `AppE` VarE r)
      reconstructField (NBFConst _) = do
        x <- newName "x"
        return (VarP x, VarE x)
      reconstructField (NBFList f) = do
        (pat, ex) <- reconstructField f
        l <- newName "l"
        return (VarP l, VarE 'map `AppE` LamE [pat] ex `AppE` VarE l)
      reconstructField (NBFTuple fs) = do
        (pats, exps) <- unzip <$> traverse reconstructField fs
        return (TupP pats, TupE (map Just exps))

  let tyarg1 = mkName "ty1"
      tyarg2 = mkName "ty2"
      astarg = mkName "ast"
      mkseq a b = VarE 'seq `AppE` a `AppE` b
      infixr `mkseq`

  bnConvFun <- fmap (FunD bnConvName') . sequence $
    [do (pats, exps) <- unzip <$> traverse (reconstructField . snd) fields
        return $ clause' [ConP 'BOp [] [WildP, ConP (conNameMap Map.! nbname) [] pats]]
                         (foldl' AppE (ConE (mkName nbname)) exps)
    | NBCon names _ _ fields _ _ <- constrs
    , nbname <- names]
    ++
    [do body <- mkLam (varE tyarg1) (varE tyarg2) (varE astarg)
        return $ clause' [ConP 'BLam [] [VarP tyarg1, VarP tyarg2, VarP astarg]]
                   -- put them in a useless seq so that they're not unused
                   (TupE [Just (VarE tyarg1), Just (VarE tyarg2), Just (VarE astarg)] `mkseq` body)]

  return [datatype, bnConvSig, bnConvFun]

-- | Remove `:: Type` annotations because they unnecessarily require the user
-- to enable KindSignatures. Any other annotations we leave, in case the user
-- wrote them and they are necessary.
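--
-- For example (as evaluated in GHCi with this module and
-- "Language.Haskell.TH" loaded; the binder names are made up with 'mkName'):
--
-- >>> cleanupBndr (KindedTV (mkName "a") () StarT)
-- PlainTV a ()
-- >>> cleanupBndr (KindedTV (mkName "f") () (AppT (AppT ArrowT StarT) StarT))
-- KindedTV f () (AppT (AppT ArrowT StarT) StarT)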
cleanupBndr :: TyVarBndr a -> TyVarBndr a
cleanupBndr (KindedTV name x k)
  | isType k = PlainTV name x
  where
    isType StarT = True
    isType (ConT n) | n == ''Type = True
    isType _ = False
cleanupBndr b = b

bndrName :: TyVarBndr a -> Name
bndrName (PlainTV name _) = name
bndrName (KindedTV name _ _) = name

parseRetty :: Name -> String -> Type -> Q ([Name], Type)
parseRetty astname consname retty = do
  case splitApps retty of
    (ConT name, args)
      | name /= astname ->
          fail $ "Could not parse return type of constructor " ++ consname
      | null args ->
          fail "Expected GADT to have type parameters"
      | Just varnames <- traverse (\case VarT varname -> Just varname ; _ -> Nothing) (init args)
      , allDistinct varnames ->
          return (varnames, last args)
      | otherwise ->
          fail $ "All type parameters but the last one must be uniform over all constructors. "
                 ++ "(Return type of constructor " ++ consname ++ ")"
    _ -> fail $ "Could not parse return type of constructor " ++ consname

splitApps :: Type -> (Type, [Type])
splitApps = flip go []
  where
    go (ParensT t) tl = go t tl
    go (AppT t arg) tl = go t (arg : tl)
    go t tl = (t, tl)

allDistinct :: Ord a => [a] -> Bool
allDistinct l =
  let sorted = sort l
  in all (uncurry (/=)) (zip sorted (drop 1 sorted))

infixOf :: Eq a => [a] -> [a] -> Bool
short `infixOf` long = any (`startsWith` short) (tails long)
  where a `startsWith` b = take (length b) a == b
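
-- Illustrative evaluations of the pure helpers above, as one would see them in
-- GHCi with this module loaded (the type names are made up with 'mkName'):
--
-- >>> splitApps (AppT (ParensT (AppT (ConT (mkName "Either")) (VarT (mkName "a")))) (VarT (mkName "b")))
-- (ConT Either,[VarT a,VarT b])
-- >>> allDistinct [3, 1, 2 :: Int]
-- True
-- >>> allDistinct "abca"
-- False
-- >>> "ell" `infixOf` "hello"
-- True
-- >>> "elo" `infixOf` "hello"
-- False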