aboutsummaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
Diffstat (limited to 'test')
-rw-r--r--test/Arith/NonBase.hs5
-rw-r--r--test/NonBaseTH.hs68
2 files changed, 58 insertions, 15 deletions
diff --git a/test/Arith/NonBase.hs b/test/Arith/NonBase.hs
index 79c4428..ea27863 100644
--- a/test/Arith/NonBase.hs
+++ b/test/Arith/NonBase.hs
@@ -1,12 +1,13 @@
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TemplateHaskell #-}
module Arith.NonBase where
import Data.Kind
import Data.Type.Equality
--- import NonBaseTH
+import NonBaseTH
data Typ t where
@@ -43,3 +44,5 @@ data Arith t where
A_Pair :: Arith a -> Arith b -> Arith (a, b)
A_If :: Arith Bool -> Arith a -> Arith a -> Arith a
A_Mono :: Arith Bool -> Arith Bool
+
+defineBaseAST "ArithF" ''Arith ['A_Var, 'A_Let] (("AF_"++) . drop 2)
diff --git a/test/NonBaseTH.hs b/test/NonBaseTH.hs
index 712b680..f3c34f3 100644
--- a/test/NonBaseTH.hs
+++ b/test/NonBaseTH.hs
@@ -1,7 +1,11 @@
{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE TupleSections #-}
+{-# LANGUAGE TemplateHaskellQuotes #-}
module NonBaseTH where
-import Data.List (sort)
+import Control.Monad (when)
+import Data.List (sort, foldl', tails)
+import Data.Maybe (fromMaybe)
import Language.Haskell.TH
@@ -11,34 +15,65 @@ import Language.Haskell.TH
-- Remember to use 'lookupTypeName' or 'lookupValueName' instead of normal
-- quotes in case of punning of data types and constructors.
defineBaseAST
- :: Name -- ^ Name of the (base-functor-like) data type to define
+ :: String -- ^ Name of the (base-functor-like) data type to define
-> Name -- ^ Name of the GADT to process
- -> [Name] -- ^ Constructors to exclude (chiefly Var, Let, Lam)
+ -> [Name] -- ^ Constructors to exclude (Var and Let, plus any other scoping construct)
+ -> (String -> String) -- ^ Constructor renaming function
-> Q [Dec]
-defineBaseAST basename astname excludes = do
+defineBaseAST basename astname excludes renameConstr = do
info <- reify astname
(params, constrs) <- case info of
TyConI (DataD [] _ params Nothing constrs _) -> return (params, constrs)
_ -> fail $ "Unsupported datatype: " ++ pprint astname
- let recvar = mkName "r"
+ let basename' = mkName basename
+ recvar = mkName "r"
- let detectRec :: BangType -> Q (Maybe Type)
- detectRec (_, field) = _
+ let detectRec :: Type -> Q (Maybe Type)
+ detectRec field = do
+ let (core, args) = splitApps field
+ case core of
+ ConT n
+ | n == astname -> return (Just (VarT recvar `AppT` last args))
+ ListT
+ | [arg] <- args -> fmap (ListT `AppT`) <$> detectRec arg
+ TupleT k
+ | length args == k -> fmap (foldl' AppT (TupleT k)) . sequence <$> traverse detectRec args
+ _ -> do
+ when (pprint astname `infixOf` pprint field) $
+ reportWarning $ "Field\n " ++ pprint field ++ "\nseems to refer to "
+ ++ pprint astname ++ " in unsupported ways; ignoring those occurrences."
+ return Nothing
let processConstr con = do
(vars, ctx, names, fields, retty) <- case con of
ForallC vars ctx (GadtC names fields retty) -> return (vars, ctx, names, fields, retty)
GadtC names fields retty -> return ([], [], names, fields, retty)
_ -> fail "Unsupported constructors found"
- checkRetty astname (head names) vars retty
- _
+ case filter (`notElem` excludes) names of
+ [] -> return []
+ names' -> do
+ let names'' = map (mkName . renameConstr . nameBase) names'
+ retty' <- fixRetty basename' astname (head names) recvar vars retty
+ fields' <- traverse (\(ba, t) -> (ba,) . fromMaybe t <$> detectRec t) fields
+ return [ForallC (map cleanupBndr vars ++ [PlainTV recvar SpecifiedSpec]) ctx (GadtC names'' fields' retty')]
+ let params' = map cleanupBndr (init params ++ [PlainTV recvar (), last params])
constrs' <- concat <$> traverse processConstr constrs
- _
+ return [DataD [] (mkName basename) params' Nothing constrs' []]
-checkRetty :: Name -> Name -> [TyVarBndr a] -> Type -> Q ()
-checkRetty astname consname vars retty = do
+-- | Remove `:: Type` annotations because they unnecessarily require the user
+-- to enable KindSignatures. Any other annotations we leave, in case the user
+-- wrote them and they are necessary.
+cleanupBndr :: TyVarBndr a -> TyVarBndr a
+cleanupBndr (KindedTV name x k) | isType k = PlainTV name x
+ where isType StarT = True
+ isType (ConT n) | n == ''Type = True
+ isType _ = False
+cleanupBndr b = b
+
+fixRetty :: Name -> Name -> Name -> Name -> [TyVarBndr a] -> Type -> Q Type
+fixRetty basename astname consname recvar vars retty = do
case splitApps retty of
(ConT name, args)
| name /= astname -> fail $ "Could not parse return type of constructor " ++ pprint consname
@@ -47,7 +82,7 @@ checkRetty astname consname vars retty = do
| Just varnames <- traverse (\case VarT varname -> Just varname ; _ -> Nothing) (init args)
, allDistinct varnames
, all (`elem` map bndrName vars) varnames ->
- return ()
+ return (foldl' AppT (ConT basename) (init args ++ [VarT recvar, last args]))
| otherwise -> fail $ "All type parameters but the last one must be uniform over all constructors. "
++ "(Return type of constructor " ++ pprint consname ++ ")"
@@ -55,7 +90,8 @@ checkRetty astname consname vars retty = do
splitApps :: Type -> (Type, [Type])
splitApps = flip go []
- where go (AppT t arg) tl = go t (arg : tl)
+ where go (ParensT t) tl = go t tl
+ go (AppT t arg) tl = go t (arg : tl)
go t tl = (t, tl)
allDistinct :: Ord a => [a] -> Bool
@@ -66,3 +102,7 @@ allDistinct l =
bndrName :: TyVarBndr a -> Name
bndrName (PlainTV n _) = n
bndrName (KindedTV n _ _) = n
+
+infixOf :: Eq a => [a] -> [a] -> Bool
+short `infixOf` long = any (`startsWith` short) (tails long)
+ where a `startsWith` b = take (length b) a == b