diff options
Diffstat (limited to 'src/Language/Haskell/TH/HashableGADT.hs')
-rw-r--r-- | src/Language/Haskell/TH/HashableGADT.hs | 46 |
1 files changed, 46 insertions, 0 deletions
diff --git a/src/Language/Haskell/TH/HashableGADT.hs b/src/Language/Haskell/TH/HashableGADT.hs new file mode 100644 index 0000000..6d638d0 --- /dev/null +++ b/src/Language/Haskell/TH/HashableGADT.hs @@ -0,0 +1,46 @@ +{-# LANGUAGE TemplateHaskellQuotes #-} +module Language.Haskell.TH.HashableGADT ( + deriveHashable, +) where + +import Control.Category ((>>>)) +import Control.Monad (forM) +import Data.Hashable +import Language.Haskell.TH + + +-- | The predicate indicates whether a particular field is to be included. If +-- all fields are to be hashed, return @True@. It gets passed the name of the +-- constructor, the index of the field, and the type of the field. +deriveHashable :: (Name -> Int -> Type -> Bool) -> Name -> Q [Dec] +deriveHashable includepred dataname = do + info <- reify dataname + (params, cons) <- case info of + TyConI (DataD [] _ params _ cons _) -> return (params, cons) + _ -> fail "deriveHashable: only data types supported" + + saltVar <- newName "s" + clauses <- concat <$> mapM (processCon saltVar includepred) cons + + let paramVars = map (VarT . bndrName) params + headType = foldl' AppT (ConT dataname) paramVars + return [InstanceD Nothing [] (ConT ''Hashable `AppT` headType) + [FunD 'hashWithSalt clauses]] + +processCon :: Name -> (Name -> Int -> Type -> Bool) -> Con -> Q [Clause] +processCon saltVar includepred constr = do + let thd (_,_,c) = c + let getFields (NormalC name fields) = return [(name, map snd fields)] + getFields (RecC name fields) = return [(name, map thd fields)] + getFields (InfixC t1 name t2) = return [(name, [snd t1, snd t2])] + getFields (GadtC names fields _) = return [(name, map snd fields) | name <- names] + getFields (RecGadtC names fields _) = return [(name, map thd fields) | name <- names] + getFields (ForallC _ _ con) = getFields con + + actualcons <- getFields constr + forM actualcons $ \(name, fields) -> do + _ + +bndrName :: TyVarBndr flag -> Name +bndrName (PlainTV n _) = n +bndrName (KindedTV n _ _) = n |