diff options
Diffstat (limited to 'test')
| -rw-r--r-- | test/Arith/NonBase.hs | 5 | ||||
| -rw-r--r-- | test/NonBaseTH.hs | 68 | 
2 files changed, 58 insertions, 15 deletions
diff --git a/test/Arith/NonBase.hs b/test/Arith/NonBase.hs index 79c4428..ea27863 100644 --- a/test/Arith/NonBase.hs +++ b/test/Arith/NonBase.hs @@ -1,12 +1,13 @@  {-# LANGUAGE StandaloneDeriving #-}  {-# LANGUAGE GADTs #-}  {-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TemplateHaskell #-}  module Arith.NonBase where  import Data.Kind  import Data.Type.Equality --- import NonBaseTH +import NonBaseTH  data Typ t where @@ -43,3 +44,5 @@ data Arith t where    A_Pair :: Arith a -> Arith b -> Arith (a, b)    A_If :: Arith Bool -> Arith a -> Arith a -> Arith a    A_Mono :: Arith Bool -> Arith Bool + +defineBaseAST "ArithF" ''Arith ['A_Var, 'A_Let] (("AF_"++) . drop 2) diff --git a/test/NonBaseTH.hs b/test/NonBaseTH.hs index 712b680..f3c34f3 100644 --- a/test/NonBaseTH.hs +++ b/test/NonBaseTH.hs @@ -1,7 +1,11 @@  {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE TemplateHaskellQuotes #-}  module NonBaseTH where -import Data.List (sort) +import Control.Monad (when) +import Data.List (sort, foldl', tails) +import Data.Maybe (fromMaybe)  import Language.Haskell.TH @@ -11,34 +15,65 @@ import Language.Haskell.TH  -- Remember to use 'lookupTypeName' or 'lookupValueName' instead of normal  -- quotes in case of punning of data types and constructors.  defineBaseAST -  :: Name  -- ^ Name of the (base-functor-like) data type to define +  :: String  -- ^ Name of the (base-functor-like) data type to define    -> Name  -- ^ Name of the GADT to process -  -> [Name]  -- ^ Constructors to exclude (chiefly Var, Let, Lam) +  -> [Name]  -- ^ Constructors to exclude (Var and Let, plus any other scoping construct) +  -> (String -> String)  -- ^ Constructor renaming function    -> Q [Dec] -defineBaseAST basename astname excludes = do +defineBaseAST basename astname excludes renameConstr = do    info <- reify astname    (params, constrs) <- case info of      TyConI (DataD [] _ params Nothing constrs _) -> return (params, constrs)      _ -> fail $ "Unsupported datatype: " ++ pprint astname -  let recvar = mkName "r" +  let basename' = mkName basename +      recvar = mkName "r" -  let detectRec :: BangType -> Q (Maybe Type) -      detectRec (_, field) = _ +  let detectRec :: Type -> Q (Maybe Type) +      detectRec field = do +        let (core, args) = splitApps field +        case core of +          ConT n +            | n == astname -> return (Just (VarT recvar `AppT` last args)) +          ListT +            | [arg] <- args -> fmap (ListT `AppT`) <$> detectRec arg +          TupleT k +            | length args == k -> fmap (foldl' AppT (TupleT k)) . sequence <$> traverse detectRec args +          _ -> do +            when (pprint astname `infixOf` pprint field) $ +              reportWarning $ "Field\n  " ++ pprint field ++ "\nseems to refer to " +                              ++ pprint astname ++ " in unsupported ways; ignoring those occurrences." +            return Nothing    let processConstr con = do          (vars, ctx, names, fields, retty) <- case con of            ForallC vars ctx (GadtC names fields retty) -> return (vars, ctx, names, fields, retty)            GadtC names fields retty -> return ([], [], names, fields, retty)            _ -> fail "Unsupported constructors found" -        checkRetty astname (head names) vars retty -        _ +        case filter (`notElem` excludes) names of +          [] -> return [] +          names' -> do +            let names'' = map (mkName . renameConstr . nameBase) names' +            retty' <- fixRetty basename' astname (head names) recvar vars retty +            fields' <- traverse (\(ba, t) -> (ba,) . fromMaybe t <$> detectRec t) fields +            return [ForallC (map cleanupBndr vars ++ [PlainTV recvar SpecifiedSpec]) ctx (GadtC names'' fields' retty')] +  let params' = map cleanupBndr (init params ++ [PlainTV recvar (), last params])    constrs' <- concat <$> traverse processConstr constrs -  _ +  return [DataD [] (mkName basename) params' Nothing constrs' []] -checkRetty :: Name -> Name -> [TyVarBndr a] -> Type -> Q () -checkRetty astname consname vars retty = do +-- | Remove `:: Type` annotations because they unnecessarily require the user +-- to enable KindSignatures. Any other annotations we leave, in case the user +-- wrote them and they are necessary. +cleanupBndr :: TyVarBndr a -> TyVarBndr a +cleanupBndr (KindedTV name x k) | isType k = PlainTV name x +  where isType StarT = True +        isType (ConT n) | n == ''Type = True +        isType _ = False +cleanupBndr b = b + +fixRetty :: Name -> Name -> Name -> Name -> [TyVarBndr a] -> Type -> Q Type +fixRetty basename astname consname recvar vars retty = do    case splitApps retty of      (ConT name, args)        | name /= astname -> fail $ "Could not parse return type of constructor " ++ pprint consname @@ -47,7 +82,7 @@ checkRetty astname consname vars retty = do        | Just varnames <- traverse (\case VarT varname -> Just varname ; _ -> Nothing) (init args)        , allDistinct varnames        , all (`elem` map bndrName vars) varnames -> -          return () +          return (foldl' AppT (ConT basename) (init args ++ [VarT recvar, last args]))        | otherwise -> fail $ "All type parameters but the last one must be uniform over all constructors. "                              ++ "(Return type of constructor " ++ pprint consname ++ ")" @@ -55,7 +90,8 @@ checkRetty astname consname vars retty = do  splitApps :: Type -> (Type, [Type])  splitApps = flip go [] -  where go (AppT t arg) tl = go t (arg : tl) +  where go (ParensT t) tl = go t tl +        go (AppT t arg) tl = go t (arg : tl)          go t tl = (t, tl)  allDistinct :: Ord a => [a] -> Bool @@ -66,3 +102,7 @@ allDistinct l =  bndrName :: TyVarBndr a -> Name  bndrName (PlainTV n _) = n  bndrName (KindedTV n _ _) = n + +infixOf :: Eq a => [a] -> [a] -> Bool +short `infixOf` long = any (`startsWith` short) (tails long) +  where a `startsWith` b = take (length b) a == b  | 
