aboutsummaryrefslogtreecommitdiff
path: root/test/NonBaseTH.hs
blob: c6cf54b9897efe8b8b73ff6e9a2a567e1bab22ef (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TemplateHaskellQuotes #-}
module NonBaseTH where

import Control.Monad (when)
import qualified Data.Kind as Kind
import Data.List (sort, foldl', tails)
import Data.Maybe (fromMaybe)
import Language.Haskell.TH


-- | Non-base AST data type: the parsed description of a GADT AST, as built by
-- 'parseNBAST'. The fields are, in order: the base name of the data type, its
-- type variable binders, and its parsed constructor clauses.
data NBAST = NBAST String [TyVarBndr ()] [NBCon]

-- | A single GADT constructor clause of the parsed AST, as built by
-- 'parseNBAST'. One clause may define several constructors at once.
data NBCon = NBCon
    [String]           -- ^ Base names of the constructors defined with this clause
    [TyVarBndrSpec]    -- ^ Type variables in scope in the clause
    Cxt                -- ^ Constraints on the constructor type
    [(Bang, NBField)]  -- ^ Constructor fields, each with its strictness annotation
    [Name]             -- ^ All but the last type parameter in the result type
    Type               -- ^ The last type parameter in the result type

-- | A constructor field, classified by whether it contains recursive
-- positions, i.e. occurrences of the AST type itself.
data NBField
  = NBFRecur ([Type] -> Type)
                -- ^ Context: takes replacement recursive positions and wraps
                -- them in whatever structure they came in (e.g. a tuple).
             [Type]
                -- ^ The last type parameter of all the recursive positions in
                -- this field. (The other type parameters are fixed anyway.)
  | NBFConst Type
      -- ^ A field without recognised recursive positions; the type is kept
      -- as it was written.

-- | Put a field inside an outer type constructor (e.g. a list). For a constant
-- field the wrapper is applied to the type directly; for a recursive field it
-- is composed onto the reconstruction context, and the list of recursive
-- indices is passed through unchanged.
wrapNBField :: (Type -> Type) -> NBField -> NBField
wrapNBField wrap = \case
  NBFConst ty -> NBFConst (wrap ty)
  NBFRecur rebuild indices -> NBFRecur (wrap . rebuild) indices

-- | Parse the reified 'Info' of a GADT AST into an 'NBAST' description.
--
-- Only plain @data@ declarations without a datatype context or kind signature
-- are accepted, and every constructor must be in GADT syntax (optionally
-- under a forall); anything else fails in 'Q'.
--
-- Each constructor field is classified as 'NBFRecur' when it contains the AST
-- type itself (directly, or nested inside lists and tuples), and as 'NBFConst'
-- otherwise. A recursive occurrence must use exactly the uniform parameters of
-- the constructor's return type; other occurrences that are not understood
-- produce a warning and are treated as constant.
parseNBAST :: Info -> Q NBAST
parseNBAST info = do
  (astname, params, constrs) <- case info of
    TyConI (DataD [] astname params Nothing constrs _) -> return (astname, params, constrs)
    _ -> fail "Unsupported datatype"

  -- In this function we use parameter/index terminology: parameters are
  -- uniform, indices vary inside an AST.
  let parseField retpars field = do
        let (core, args) = splitApps field
        case core of
          ConT n
            | n == astname ->
                if not (null args) && init args == map VarT retpars
                  -- A direct recursive position: exactly one replacement is
                  -- expected, and it is used as-is.
                  then return (NBFRecur head [last args])
                  else fail $ "Field\n  " ++ pprint field
                              ++ "\nis recursive, but with different type parameters than "
                              ++ "the return type of this constructor. All but the last type "
                              ++ "parameter of the GADT must be uniform over the entire AST."

          ListT
            | [arg] <- args -> wrapNBField (ListT `AppT`) <$> parseField retpars arg

          TupleT k
            | length args == k -> do
                positions <- traverse (parseField retpars) args
                case traverse (\case NBFConst t -> Just t ; _ -> Nothing) positions of
                  -- No component is recursive: the whole tuple is constant.
                  Just l -> return (NBFConst (foldl' AppT (TupleT k) l))
                  Nothing -> do
                    let indices = concatMap (\case NBFRecur _ is -> is ; NBFConst _ -> []) positions
                    -- Rebuild the tuple components in order, handing each
                    -- recursive component as many replacements as it has
                    -- recursive positions.
                    let reconstruct [] [] = []
                        reconstruct [] _ = error "Invalid number of replacing recursive positions"
                        reconstruct (NBFRecur ctx pars : poss) repls =
                          let (pre, post) = splitAt (length pars) repls
                          in ctx pre : reconstruct poss post
                        -- Constant components consume no replacements, but
                        -- they still occupy a slot in the rebuilt tuple.
                        reconstruct (NBFConst t : poss) repls = t : reconstruct poss repls
                    return (NBFRecur (foldl' AppT (TupleT k) . reconstruct positions) indices)

          _ -> do
            -- Textual check only: warn when the AST name shows up somewhere
            -- in a field shape that is not handled above.
            when (pprint astname `infixOf` pprint field) $
              reportWarning $ "Field\n  " ++ pprint field ++ "\nseems to refer to "
                              ++ pprint astname ++ " in unsupported ways; ignoring those occurrences."
            return (NBFConst field)

  let parseConstr con = do
        (vars, ctx, names, fields, retty) <- case con of
          ForallC vars ctx (GadtC names fields retty) -> return (vars, ctx, names, fields, retty)
          GadtC names fields retty -> return ([], [], names, fields, retty)
          _ -> fail "Unsupported constructors found"
        -- The first constructor name is only used in error messages.
        (retpars, retindex) <- parseRetty astname (head names) retty
        fields' <- traverse (\(ba, t) -> (ba,) <$> parseField retpars t) fields
        return (NBCon (map nameBase names) vars ctx fields' retpars retindex)

  constrs' <- traverse parseConstr constrs
  return (NBAST (nameBase astname) params constrs')


-- | Define a new GADT that is a base-functor-like version of a given existing
-- GADT AST.
--
-- The generated type has the parameters of the original with one extra
-- parameter @r@ inserted just before the last one. Every recursive position
-- in a constructor field (directly, or nested inside lists and tuples) is
-- replaced by @r@ applied to the last type argument of that position.
-- Constructors listed for exclusion are dropped; the others are renamed with
-- the given function.
--
-- Fails in 'Q' if the type is not a plain @data@ declaration without a
-- datatype context or kind signature, if it has no type parameters, or if a
-- constructor is not in GADT syntax.
--
-- Remember to use 'lookupTypeName' or 'lookupValueName' instead of normal
-- quotes in case of punning of data types and constructors.
defineBaseAST
  :: String  -- ^ Name of the (base-functor-like) data type to define
  -> Name  -- ^ Name of the GADT to process
  -> [Name]  -- ^ Constructors to exclude (Var and Let, plus any other scoping construct)
  -> (String -> String)  -- ^ Constructor renaming function
  -> Q [Dec]
defineBaseAST basename astname excludes renameConstr = do
  info <- reify astname
  (params, constrs) <- case info of
    TyConI (DataD [] _ params Nothing constrs _) -> return (params, constrs)
    _ -> fail $ "Unsupported datatype: " ++ pprint astname

  -- The recursion parameter is inserted before the last type parameter, so
  -- there has to be a last one. Checking up front gives a proper error
  -- instead of a bare 'init'/'last' exception further down.
  when (null params) $
    fail "Expected GADT to have type parameters"

  let basename' = mkName basename
      recvar = mkName "r"

  -- Returns the rewritten field type if the field contains a recursive
  -- position, and Nothing if the field can be kept unchanged.
  let detectRec :: Type -> Q (Maybe Type)
      detectRec field = do
        let (core, args) = splitApps field
        case core of
          ConT n
            | n == astname, not (null args) ->
                return (Just (VarT recvar `AppT` last args))
          ListT
            | [arg] <- args -> fmap (ListT `AppT`) <$> detectRec arg
          TupleT k
            | length args == k -> do
                results <- traverse detectRec args
                -- Rewrite the tuple as soon as any component changed; the
                -- components without recursive positions keep their original
                -- type instead of preventing the whole rewrite.
                return $
                  if null [() | Just _ <- results]
                    then Nothing
                    else Just (foldl' AppT (TupleT k) (zipWith fromMaybe args results))
          _ -> do
            -- Textual check only: warn when the AST name shows up somewhere
            -- in a field shape that is not handled above.
            when (pprint astname `infixOf` pprint field) $
              reportWarning $ "Field\n  " ++ pprint field ++ "\nseems to refer to "
                              ++ pprint astname ++ " in unsupported ways; ignoring those occurrences."
            return Nothing

  let processConstr con = do
        (vars, ctx, names, fields, retty) <- case con of
          ForallC vars ctx (GadtC names fields retty) -> return (vars, ctx, names, fields, retty)
          GadtC names fields retty -> return ([], [], names, fields, retty)
          _ -> fail "Unsupported constructors found"
        case filter (`notElem` excludes) names of
          -- Every constructor of this clause is excluded: emit nothing.
          [] -> return []
          names' -> do
            let names'' = map (mkName . renameConstr . nameBase) names'
            -- The first constructor name is only used in error messages.
            retty' <- fixRetty basename' astname (head names) recvar vars retty
            fields' <- traverse (\(ba, t) -> (ba,) . fromMaybe t <$> detectRec t) fields
            return [ForallC (map cleanupBndr vars ++ [PlainTV recvar SpecifiedSpec]) ctx (GadtC names'' fields' retty')]

  let params' = map cleanupBndr (init params ++ [PlainTV recvar (), last params])
  constrs' <- concat <$> traverse processConstr constrs
  return [DataD [] basename' params' Nothing constrs' []]

-- | Remove `:: Type` annotations because they unnecessarily require the user
-- to enable KindSignatures. Any other annotations we leave, in case the user
-- wrote them and they are necessary.
--
-- Both spellings of the kind are recognised: 'StarT' and a 'ConT' naming
-- 'Data.Kind.Type'. Plain binders are returned unchanged.
cleanupBndr :: TyVarBndr a -> TyVarBndr a
cleanupBndr (KindedTV name x k) | isType k = PlainTV name x
  where isType StarT = True
        -- The name must be qualified: an unqualified ''Type would resolve to
        -- the Template Haskell syntax tree type exported by
        -- Language.Haskell.TH, not to the kind of types.
        isType (ConT n) | n == ''Kind.Type = True
        isType _ = False
cleanupBndr b = b

-- | Split the return type of a GADT constructor into its uniform parameters
-- and its index.
--
-- The arguments are the name of the AST type, the name of the constructor
-- (used in error messages only) and the constructor's return type. The return
-- type must be the AST type applied to at least one argument, and all
-- arguments but the last must be distinct type variables. On success the
-- result is the names of those variables (i.e. all but the last type
-- argument) together with the last type argument; otherwise this fails in 'Q'.
parseRetty :: Name -> Name -> Type -> Q ([Name], Type)
parseRetty astname consname retty = do
  case splitApps retty of
    (ConT name, args)
      | name /= astname -> fail $ "Could not parse return type of constructor " ++ pprint consname
      | null args -> fail "Expected GADT to have type parameters"

      | Just varnames <- traverse (\case VarT varname -> Just varname ; _ -> Nothing) (init args)
      , allDistinct varnames ->
          -- 'varnames' was built from 'init args', so it already excludes the
          -- last argument and must be returned whole.
          return (varnames, last args)

      | otherwise -> fail $ "All type parameters but the last one must be uniform over all constructors. "
                            ++ "(Return type of constructor " ++ pprint consname ++ ")"
    _ -> fail $ "Could not parse return type of constructor " ++ pprint consname

-- | Rewrite the return type of a GADT constructor so that it targets the
-- base-functor-like data type instead of the original AST type.
--
-- The arguments are: the name of the new data type, the name of the original
-- AST type, the constructor name (used in error messages only), the name of
-- the recursion variable, the type variables bound by the constructor, and
-- the original return type.
--
-- The return type must be the AST type applied to at least one argument; all
-- arguments but the last must be distinct type variables that are bound by
-- the constructor. The result is the new data type applied to those same
-- arguments, with the recursion variable inserted before the last one.
-- Anything else fails in 'Q'.
fixRetty :: Name -> Name -> Name -> Name -> [TyVarBndr a] -> Type -> Q Type
fixRetty basename astname consname recvar vars retty =
  case splitApps retty of
    (ConT headName, tyArgs)
      | headName /= astname -> unparsable
      | null tyArgs -> fail "Expected GADT to have type parameters"
      | Just uniform <- traverse asVar (init tyArgs)
      , allDistinct uniform
      , all (`elem` bound) uniform ->
          return (foldl' AppT (ConT basename) (init tyArgs ++ [VarT recvar, last tyArgs]))
      | otherwise ->
          fail $ "All type parameters but the last one must be uniform over all constructors. "
                 ++ "(Return type of constructor " ++ pprint consname ++ ")"
    _ -> unparsable
  where
    -- Shared failure for return types that are not an application of the AST.
    unparsable = fail $ "Could not parse return type of constructor " ++ pprint consname
    -- Names of the type variables bound by the constructor.
    bound = map bndrName vars
    asVar (VarT v) = Just v
    asVar _ = Nothing

-- | Decompose a (possibly parenthesised) chain of type applications into its
-- head and the list of arguments, in left-to-right order. A type that is not
-- an application is returned as the head with no arguments.
splitApps :: Type -> (Type, [Type])
splitApps ty = unwind ty []
  where
    -- Walk down the left spine, pushing each argument onto the front of the
    -- accumulator so that the outermost (last) argument ends up last.
    unwind t acc = case t of
      ParensT inner -> unwind inner acc
      AppT fun arg -> unwind fun (arg : acc)
      _ -> (t, acc)

-- | Whether a list contains no duplicate elements. The list is sorted first,
-- after which any duplicates are adjacent.
allDistinct :: Ord a => [a] -> Bool
allDistinct xs = and (zipWith (/=) ordered (drop 1 ordered))
  where ordered = sort xs

-- | The name of the type variable introduced by a binder, whether or not the
-- binder carries a kind annotation.
bndrName :: TyVarBndr a -> Name
bndrName bndr = case bndr of
  PlainTV name _ -> name
  KindedTV name _ _ -> name

-- | Whether the first list occurs as a contiguous sublist of the second.
-- Checks, for every suffix of the second list, whether it starts with the
-- first list.
infixOf :: Eq a => [a] -> [a] -> Bool
infixOf needle haystack = or [take width suffix == needle | suffix <- tails haystack]
  where width = length needle