summaryrefslogtreecommitdiff
path: root/src/Language/Haskell/TH/HashableGADT.hs
blob: 6d638d0983c04b0ab5cc7684c3a22dc51acb6c04 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
{-# LANGUAGE TemplateHaskellQuotes #-}
module Language.Haskell.TH.HashableGADT (
  deriveHashable,
) where

import Control.Category ((>>>))
import Control.Monad (forM)
import Data.Hashable
import Language.Haskell.TH


-- | Derive a 'Hashable' instance for the data type with the given name.
--
-- The predicate indicates whether a particular field is to be included. If
-- all fields are to be hashed, return @True@. It gets passed the name of the
-- constructor, the index of the field, and the type of the field.
--
-- Only plain @data@ declarations without a datatype context are accepted;
-- anything else makes the splice fail. The generated instance has an empty
-- context, so the types of all hashed fields must already be 'Hashable'
-- without further assumptions on the type parameters.
deriveHashable :: (Name -> Int -> Type -> Bool) -> Name -> Q [Dec]
deriveHashable includepred dataname = do
  info <- reify dataname
  (params, cons) <- case info of
    TyConI (DataD [] _ params _ cons _) -> return (params, cons)
    _ -> fail "deriveHashable: only data types supported"

  saltVar <- newName "s"
  conClauses <- concat <$> mapM (processCon saltVar includepred) cons

  -- A 'FunD' without any clause is rejected when the declaration is spliced
  -- ("has no equations"), which is what a constructor-less data type would
  -- produce. Such a type only contains bottom, so emit a single clause that
  -- forces the argument and otherwise hands back the salt.
  clauses <-
    if null conClauses
      then do
        valVar <- newName "v"
        return
          [ Clause
              [VarP saltVar, VarP valVar]
              (NormalB (VarE 'seq `AppE` VarE valVar `AppE` VarE saltVar))
              []
          ]
      else return conClauses

  -- Instance head: the type constructor applied to all of its parameters.
  let paramVars = map (VarT . bndrName) params
      headType = foldl' AppT (ConT dataname) paramVars
  return [InstanceD Nothing [] (ConT ''Hashable `AppT` headType)
            [FunD 'hashWithSalt clauses]]

-- | Build the 'hashWithSalt' clauses for a single constructor declaration.
--
-- One 'Con' can declare several constructors at once (GADT syntax permits
-- @A, B :: ...@), hence the list result: one clause per constructor name.
--
-- Arguments: the name of the salt variable to bind in every clause, the field
-- inclusion predicate (constructor name, zero-based field index, field type),
-- and the constructor declaration itself.
--
-- Each generated clause has the shape
--
-- > hashWithSalt s (C x0 _ x2) =
-- >   s `hashWithSalt` ("C" :: String) `hashWithSalt` x0 `hashWithSalt` x2
--
-- Fields rejected by the predicate are matched with a wildcard and do not
-- contribute to the hash. The constructor's base name is mixed in first so
-- that different constructors with equal (or no) hashed fields do not all
-- collapse to the same hash value.
processCon :: Name -> (Name -> Int -> Type -> Bool) -> Con -> Q [Clause]
processCon saltVar includepred constr =
  forM (conFields constr) $ \(name, fields) -> do
    -- 'Just' a fresh variable for every hashed field, 'Nothing' for the rest.
    binders <- forM (zip [0 ..] fields) $ \(idx, ty) ->
      if includepred name idx ty
        then Just <$> newName "x"
        else return Nothing

    -- 'conP' is used instead of the raw 'ConP' constructor because the
    -- latter's arity differs between template-haskell versions.
    conPat <- conP name (map (return . maybe WildP VarP) binders)

    let mix acc e = VarE 'hashWithSalt `AppE` acc `AppE` e
        -- Explicit signature so the literal stays unambiguous even if the
        -- splice site enables OverloadedStrings.
        tag = SigE (LitE (StringL (nameBase name))) (ConT ''String)
        hashed = [VarE v | Just v <- binders]
        body = foldl' mix (VarE saltVar) (tag : hashed)
    return (Clause [VarP saltVar, conPat] (NormalB body) [])
  where
    thd (_, _, c) = c

    -- Constructor names paired with their field types, looking through any
    -- enclosing @forall@.
    conFields :: Con -> [(Name, [Type])]
    conFields (NormalC name fields) = [(name, map snd fields)]
    conFields (RecC name fields) = [(name, map thd fields)]
    conFields (InfixC t1 name t2) = [(name, [snd t1, snd t2])]
    conFields (GadtC names fields _) = [(name, map snd fields) | name <- names]
    conFields (RecGadtC names fields _) = [(name, map thd fields) | name <- names]
    conFields (ForallC _ _ con) = conFields con

-- | Extract the variable name from a type variable binder, discarding its
-- flag and, if present, its kind annotation.
bndrName :: TyVarBndr flag -> Name
bndrName bndr = case bndr of
  PlainTV nm _ -> nm
  KindedTV nm _ _ -> nm