aboutsummaryrefslogtreecommitdiff
path: root/test-th/NonBaseTH.hs
diff options
context:
space:
mode:
Diffstat (limited to 'test-th/NonBaseTH.hs')
-rw-r--r--test-th/NonBaseTH.hs225
1 files changed, 225 insertions, 0 deletions
diff --git a/test-th/NonBaseTH.hs b/test-th/NonBaseTH.hs
new file mode 100644
index 0000000..4741ea0
--- /dev/null
+++ b/test-th/NonBaseTH.hs
@@ -0,0 +1,225 @@
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE TemplateHaskellQuotes #-}
+{-# LANGUAGE TupleSections #-}
+module NonBaseTH (
+ defineBaseAST,
+) where
+
+import Control.Monad (when)
+import Data.List (sort, foldl', tails)
+import qualified Data.Map.Strict as Map
+import Data.Maybe (catMaybes)
+import Language.Haskell.TH
+
+import Data.Expr.SharingRecovery
+
+
+-- | Non-base AST data type
+data NBAST = NBAST String [TyVarBndr ()] [NBCon]
+
+data NBCon = NBCon
+ [String] -- ^ Names of the constructors defined with this clause
+ [TyVarBndrSpec] -- ^ Type variables in scope in the clause
+ Cxt -- ^ Constraints on the constructor type
+ [(Bang, NBField Type)] -- ^ Constructor fields
+ [Name] -- ^ All but the last type parameter in the result type
+ Type -- ^ The last type parameter in the result type
+
+data NBField t
+ = NBFList (NBField t)
+ | NBFTuple [NBField t]
+ | NBFRecur t -- ^ The last type parameter of this recursive position. (The
+ -- other type parameters are fixed anyway.)
+ | NBFConst Type
+
+parseNBAST :: [String] -> Info -> Q NBAST
+parseNBAST excludes info = do
+ (astname, params, constrs) <- case info of
+ TyConI (DataD [] astname params Nothing constrs _) -> return (astname, params, constrs)
+ _ -> fail "Unsupported datatype"
+
+ -- In this function we use parameter/index terminology: parameters are
+ -- uniform, indices vary inside an AST.
+ let parseField retpars field = do
+ let (core, args) = splitApps field
+ case core of
+ ConT n
+ | n == astname ->
+ if not (null args) && init args == map VarT retpars
+ then return (NBFRecur (last args))
+ else fail $ "Field\n " ++ pprint field
+ ++ "\nis recursive, but with different type parameters than "
+ ++ "the return type of this constructor. All but the last type "
+ ++ "parameter of the GADT must be uniform over the entire AST."
+
+ ListT
+ | [arg] <- args -> NBFList <$> parseField retpars arg
+
+ TupleT k
+ | length args == k -> NBFTuple <$> traverse (parseField retpars) args
+
+ _ -> do
+ when (pprint astname `infixOf` pprint field) $
+ reportWarning $ "Field\n " ++ pprint field ++ "\nseems to refer to "
+ ++ pprint astname ++ " in unsupported ways; ignoring those occurrences."
+ return (NBFConst field)
+
+ let splitConstr (ForallC vars ctx (GadtC names fields retty))
+ | names'@(_:_) <- filter (`notElem` excludes) (map nameBase names) =
+ return (Just (vars, ctx, names', fields, retty))
+ | otherwise = return Nothing
+ splitConstr c@GadtC{} = splitConstr (ForallC [] [] c)
+ splitConstr c =
+ let names = case c of
+ NormalC n _ -> Just (show n)
+ _ -> Just (show c)
+ in fail $ "Unsupported constructors found" ++ maybe "" (\s -> ": " ++ show s) names
+
+ let parseConstr (vars, ctx, names, fields, retty) = do
+ (retpars, retindex) <- parseRetty astname (head names) retty
+ fields' <- traverse (\(ba, t) -> (ba,) <$> parseField retpars t) fields
+ return (NBCon names vars ctx fields' retpars retindex)
+
+ constrs' <- traverse parseConstr =<< catMaybes <$> traverse splitConstr constrs
+ return (NBAST (nameBase astname) params constrs')
+
+
+-- | Define a new GADT that is a base-functor-like version of a given existing
+-- GADT AST.
+--
+-- Remember to use 'lookupTypeName' or 'lookupValueName' instead of normal
+-- quotes in case of punning of data types and constructors.
+defineBaseAST
+ :: String -- ^ Name of the (base-functor-like) data type to define
+ -> Name -- ^ Name of the GADT to process
+ -> [Name] -- ^ Constructors to exclude (Var and Let, plus any other scoping construct)
+ -> (String -> String) -- ^ Constructor renaming function
+ -> String -- ^ Name of base -> nonbase conversion function to define
+ -> Name -- ^ Type of singleton types
+ -> (ExpQ -> ExpQ -> ExpQ -> ExpQ) -- ^ Lambda: typ (a -> b) -> typ a -> AST b -> AST (a -> b)
+ -> (ExpQ -> ExpQ -> ExpQ -> ExpQ) -- ^ Let: typ a -> String -> AST a -> AST b -> AST b
+ -> (ExpQ -> ExpQ -> ExpQ -> ExpQ) -- ^ Var: typ t -> String -> AST b -> AST b
+ -> Q [Dec]
+defineBaseAST basename astname excludes renameConstr bnConvName typName mkLam mkLet mkVar = do
+ NBAST _ params constrs <- parseNBAST (map nameBase excludes) =<< reify astname
+
+ -- Build the data type
+
+ let basename' = mkName basename
+ conNameMap = Map.fromList [(nbname, mkName (renameConstr nbname))
+ | NBCon ns _ _ _ _ _ <- constrs, nbname <- ns]
+
+ let recvar = mkName "r"
+
+ let processField (NBFRecur idx) = VarT recvar `AppT` idx
+ processField (NBFConst t) = t
+ processField (NBFList f) = AppT ListT (processField f)
+ processField (NBFTuple fs) = foldl' AppT (TupleT (length fs)) (map processField fs)
+
+ let processConstr (NBCon names vars ctx fields retparams retindex) = do
+ let names' = map (conNameMap Map.!) names
+ let fields' = map (\(ba, f) -> (ba, processField f)) fields
+ let retty' = foldl' AppT (ConT basename') (map VarT retparams ++ [VarT recvar, retindex])
+ return [ForallC (map cleanupBndr vars ++ [PlainTV recvar SpecifiedSpec])
+ ctx (GadtC names' fields' retty')]
+
+ let params' = map cleanupBndr (init params ++ [PlainTV recvar (), last params])
+ constrs' <- concat <$> traverse processConstr constrs
+ let datatype = DataD [] (mkName basename) params' Nothing constrs' []
+
+ -- Build the B->N conversion function
+
+ let bnConvName' = mkName bnConvName
+ envparam = VarT (mkName "env")
+ tparam = VarT (mkName "t")
+ ftype = foldl' AppT (ConT basename') (map (VarT . bndrName) (init params))
+ arrow a b = ArrowT `AppT` a `AppT` b
+
+ let bnConvSig =
+ SigD bnConvName' $
+ (ConT ''BExpr `AppT` ConT typName `AppT` envparam `AppT` ftype `AppT` tparam)
+ `arrow`
+ foldl' AppT (ConT astname) (map (VarT . bndrName) (init params) ++ [tparam])
+
+ let clause' pats ex = Clause pats (NormalB ex) []
+ -- backConMap = Map.fromList [(renameConstr nbname, nbname)
+ -- | NBCon ns _ _ _ _ _ <- constrs, nbname <- ns]
+ reconstructField (NBFRecur _) = do
+ r <- newName "r"
+ return (VarP r, VarE bnConvName' `AppE` VarE r)
+ reconstructField (NBFConst _) = do
+ x <- newName "x"
+ return (VarP x, VarE x)
+ reconstructField (NBFList f) = do
+ (pat, ex) <- reconstructField f
+ l <- newName "l"
+ return (VarP l, VarE 'map `AppE` LamE [pat] ex `AppE` VarE l)
+ reconstructField (NBFTuple fs) = do
+ (pats, exps) <- unzip <$> traverse reconstructField fs
+ return (TupP pats, TupE (map Just exps))
+
+ let tyarg1 = mkName "ty1"
+ tyarg2 = mkName "ty2"
+ astarg = mkName "ast"
+ mkseq a b = VarE 'seq `AppE` a `AppE` b
+ infixr `mkseq`
+
+ bnConvFun <-
+ fmap (FunD bnConvName') . sequence $
+ [do (pats, exps) <- unzip <$> traverse (reconstructField . snd) fields
+ return $
+ clause' [ConP 'BOp [] [WildP, ConP (conNameMap Map.! nbname) [] pats]]
+ (foldl' AppE (ConE (mkName nbname)) exps)
+ | NBCon names _ _ fields _ _ <- constrs
+ , nbname <- names]
+ ++
+ [do body <- mkLam (varE tyarg1) (varE tyarg2) (varE astarg)
+ return $ clause' [ConP 'BLam [] [VarP tyarg1, VarP tyarg2, VarP astarg]]
+ -- put them in a useless seq so that they're not unused
+ (TupE [Just (VarE tyarg1), Just (VarE tyarg2), Just (VarE astarg)] `mkseq` body)]
+
+ return [datatype, bnConvSig, bnConvFun]
+
+-- | Remove `:: Type` annotations because they unnecessarily require the user
+-- to enable KindSignatures. Any other annotations we leave, in case the user
+-- wrote them and they are necessary.
+cleanupBndr :: TyVarBndr a -> TyVarBndr a
+cleanupBndr (KindedTV name x k) | isType k = PlainTV name x
+ where isType StarT = True
+ isType (ConT n) | n == ''Type = True
+ isType _ = False
+cleanupBndr b = b
+
+bndrName :: TyVarBndr a -> Name
+bndrName (PlainTV name _) = name
+bndrName (KindedTV name _ _) = name
+
+parseRetty :: Name -> String -> Type -> Q ([Name], Type)
+parseRetty astname consname retty = do
+ case splitApps retty of
+ (ConT name, args)
+ | name /= astname -> fail $ "Could not parse return type of constructor " ++ consname
+ | null args -> fail "Expected GADT to have type parameters"
+
+ | Just varnames <- traverse (\case VarT varname -> Just varname ; _ -> Nothing) (init args)
+ , allDistinct varnames ->
+ return (varnames, last args)
+
+ | otherwise -> fail $ "All type parameters but the last one must be uniform over all constructors. "
+ ++ "(Return type of constructor " ++ consname ++ ")"
+ _ -> fail $ "Could not parse return type of constructor " ++ consname
+
+splitApps :: Type -> (Type, [Type])
+splitApps = flip go []
+ where go (ParensT t) tl = go t tl
+ go (AppT t arg) tl = go t (arg : tl)
+ go t tl = (t, tl)
+
+allDistinct :: Ord a => [a] -> Bool
+allDistinct l =
+ let sorted = sort l
+ in all (uncurry (/=)) (zip sorted (drop 1 sorted))
+
+infixOf :: Eq a => [a] -> [a] -> Bool
+short `infixOf` long = any (`startsWith` short) (tails long)
+ where a `startsWith` b = take (length b) a == b