aboutsummaryrefslogtreecommitdiff
path: root/test/NonBaseTH.hs
blob: f3c34f38c570cd210f258f8fab55a7addf892cba (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TemplateHaskellQuotes #-}
module NonBaseTH where

import Control.Monad (when)
import qualified Data.Kind
import Data.List (sort, foldl', tails)
import Data.Maybe (fromMaybe)
import Language.Haskell.TH


-- | Define a new GADT that is a base-functor-like version of a given existing
-- GADT AST.
--
-- Every constructor of the GADT that is not excluded is copied under a new
-- name (produced by the renaming function). The new data type gets one extra
-- type parameter @r@, inserted just before the last type parameter. In the
-- constructor fields, every occurrence of the original GADT (also inside
-- lists and tuples) is replaced by @r@ applied to that occurrence's last type
-- argument. Other occurrences of the GADT are left alone with a warning.
--
-- Remember to use 'lookupTypeName' or 'lookupValueName' instead of normal
-- quotes in case of punning of data types and constructors.
defineBaseAST
  :: String  -- ^ Name of the (base-functor-like) data type to define
  -> Name  -- ^ Name of the GADT to process
  -> [Name]  -- ^ Constructors to exclude (Var and Let, plus any other scoping construct)
  -> (String -> String)  -- ^ Constructor renaming function
  -> Q [Dec]
defineBaseAST basename astname excludes renameConstr = do
  info <- reify astname
  (params, constrs) <- case info of
    TyConI (DataD [] _ params Nothing constrs _) -> return (params, constrs)
    _ -> fail $ "Unsupported datatype: " ++ pprint astname

  -- The new parameter is inserted before the last existing one, so there must
  -- be at least one. Checking here also covers the case where every
  -- constructor is excluded: then 'fixRetty' never runs, and 'init'/'last'
  -- on the parameter list below would otherwise crash.
  when (null params) $
    fail "Expected GADT to have type parameters"

  let basename' = mkName basename
      recvar = mkName "r"

  -- Replace recursive occurrences of the GADT in a field type by the new
  -- variable @r@. Returns Nothing when the field contains no supported
  -- occurrence, in which case the caller keeps the field unchanged.
  let detectRec :: Type -> Q (Maybe Type)
      detectRec field = do
        let (core, args) = splitApps field
        case core of
          ConT n
            | n == astname, not (null args) -> return (Just (VarT recvar `AppT` last args))
          ListT
            | [arg] <- args -> fmap (ListT `AppT`) <$> detectRec arg
          TupleT k
            | length args == k -> do
                comps <- traverse detectRec args
                -- Rewrite the tuple as soon as any component is recursive;
                -- the non-recursive components are kept as they were.
                return $
                  if all null comps
                    then Nothing
                    else Just (foldl' AppT (TupleT k) (zipWith fromMaybe args comps))
          _ -> do
            when (pprint astname `infixOf` pprint field) $
              reportWarning $ "Field\n  " ++ pprint field ++ "\nseems to refer to "
                              ++ pprint astname ++ " in unsupported ways; ignoring those occurrences."
            return Nothing

  -- Translate one constructor declaration; excluded constructors vanish.
  let processConstr con = do
        (vars, ctx, names, fields, retty) <- case con of
          ForallC vars ctx (GadtC names fields retty) -> return (vars, ctx, names, fields, retty)
          GadtC names fields retty -> return ([], [], names, fields, retty)
          _ -> fail "Unsupported constructors found"
        case filter (`notElem` excludes) names of
          [] -> return []
          names'@(firstName : _) -> do
            let names'' = map (mkName . renameConstr . nameBase) names'
            retty' <- fixRetty basename' astname firstName recvar vars retty
            fields' <- traverse (\(ba, t) -> (ba,) . fromMaybe t <$> detectRec t) fields
            return [ForallC (map cleanupBndr vars ++ [PlainTV recvar SpecifiedSpec]) ctx (GadtC names'' fields' retty')]

  -- Safe: 'params' was checked to be non-empty above.
  let params' = map cleanupBndr (init params ++ [PlainTV recvar (), last params])
  constrs' <- concat <$> traverse processConstr constrs
  return [DataD [] basename' params' Nothing constrs' []]

-- | Remove @:: Type@ kind annotations from a binder, because they
-- unnecessarily require the user to enable KindSignatures. Any other kind
-- annotation is kept, in case the user wrote it and it is necessary.
--
-- The kind of ordinary types can show up either as 'StarT' or as a
-- constructor naming 'Data.Kind.Type'. The unqualified name @Type@ in this
-- module is Template Haskell's own AST type, which is never a kind, so the
-- kind has to be referred to through the qualified "Data.Kind" import.
cleanupBndr :: TyVarBndr a -> TyVarBndr a
cleanupBndr (KindedTV name flag kind)
  | isTypeKind kind = PlainTV name flag
  where
    isTypeKind StarT = True
    isTypeKind (ConT n) = n == ''Data.Kind.Type
    isTypeKind _ = False
cleanupBndr bndr = bndr

-- | Rewrite the return type @T a1 ... an t@ of a constructor of the GADT @T@
-- (named by @astname@) into @B a1 ... an r t@, where @B@ is @basename@ and
-- @r@ is @recvar@.
--
-- The leading arguments @a1 ... an@ must be distinct type variables that are
-- all bound by @vars@. Otherwise this fails in 'Q'. It also fails if the
-- return type is not headed by @astname@ or has no arguments at all.
-- @consname@ is only used to make the error messages informative.
fixRetty :: Name -> Name -> Name -> Name -> [TyVarBndr a] -> Type -> Q Type
fixRetty basename astname consname recvar vars retty =
  case splitApps retty of
    (ConT headName, tyArgs)
      | headName == astname -> rebuild tyArgs
    _ -> fail unparseable
  where
    unparseable = "Could not parse return type of constructor " ++ pprint consname

    boundNames = map bndrName vars

    asVar (VarT v) = Just v
    asVar _ = Nothing

    rebuild [] = fail "Expected GADT to have type parameters"
    rebuild tyArgs =
      let uniformArgs = init tyArgs
          lastArg = last tyArgs
      in case traverse asVar uniformArgs of
           Just vs
             | allDistinct vs
             , all (`elem` boundNames) vs ->
                 return (foldl' AppT (ConT basename) (uniformArgs ++ [VarT recvar, lastArg]))
           _ -> fail $ "All type parameters but the last one must be uniform over all constructors. "
                       ++ "(Return type of constructor " ++ pprint consname ++ ")"

-- | Decompose a chain of type applications into its head and its arguments in
-- left-to-right order, e.g. @T a b@ becomes @(T, [a, b])@. Parentheses are
-- looked through along the application spine. The arguments themselves are
-- returned untouched.
splitApps :: Type -> (Type, [Type])
splitApps ty0 = walk ty0 []
  where
    walk ty acc = case ty of
      ParensT inner -> walk inner acc
      AppT fun arg  -> walk fun (arg : acc)
      _             -> (ty, acc)

-- | True when no value occurs more than once in the list. Sorting first puts
-- any duplicates next to each other, so only neighbouring elements need to be
-- compared.
allDistinct :: Ord a => [a] -> Bool
allDistinct xs = and (zipWith (/=) ordered (drop 1 ordered))
  where ordered = sort xs

-- | The type variable bound by a binder, regardless of whether it carries a
-- kind annotation.
bndrName :: TyVarBndr a -> Name
bndrName bndr = case bndr of
  PlainTV name _    -> name
  KindedTV name _ _ -> name

-- | @needle `infixOf` haystack@ holds when @needle@ appears as a contiguous
-- run of elements somewhere in @haystack@. The empty list is an infix of
-- every list.
infixOf :: Eq a => [a] -> [a] -> Bool
infixOf needle haystack =
  or [ take width suffix == needle | suffix <- tails haystack ]
  where width = length needle