diff options
Diffstat (limited to 'test')
| -rw-r--r-- | test/NonBaseTH.hs | 96 | 
1 files changed, 96 insertions, 0 deletions
diff --git a/test/NonBaseTH.hs b/test/NonBaseTH.hs index f3c34f3..c6cf54b 100644 --- a/test/NonBaseTH.hs +++ b/test/NonBaseTH.hs @@ -9,6 +9,87 @@ import Data.Maybe (fromMaybe)  import Language.Haskell.TH +-- | Non-base AST data type +data NBAST = NBAST String [TyVarBndr ()] [NBCon] + +data NBCon = NBCon +    [String]           -- ^ Names of the constructors defined with this clause +    [TyVarBndrSpec]    -- ^ Type variables in scope in the clause +    Cxt                -- ^ Constraints on the constructor type +    [(Bang, NBField)]  -- ^ Constructor fields +    [Name]             -- ^ All but the last type parameter in the result type +    Type               -- ^ The last type parameter in the result type + +data NBField +  = NBFRecur ([Type] -> Type) +                -- ^ Context: takes replacement recursive positions and wraps +                -- them in whatever structure they came in (e.g. a tuple). +             [Type] +                -- ^ The last type parameter of all the recursive positions in +                -- this field. (The other type parameters are fixed anyway.) +  | NBFConst Type + +wrapNBField :: (Type -> Type) -> NBField -> NBField +wrapNBField f (NBFRecur ctx param) = NBFRecur (f . ctx) param +wrapNBField f (NBFConst t) = NBFConst (f t) + +parseNBAST :: Info -> Q NBAST +parseNBAST info = do +  (astname, params, constrs) <- case info of +    TyConI (DataD [] astname params Nothing constrs _) -> return (astname, params, constrs) +    _ -> fail "Unsupported datatype" + +  -- In this function we use parameter/index terminology: parameters are +  -- uniform, indices vary inside an AST. +  let parseField retpars field = do +        let (core, args) = splitApps field +        case core of +          ConT n +            | n == astname -> +                if not (null args) && init args == map VarT retpars +                  then return (NBFRecur head [last args]) +                  else fail $ "Field\n  " ++ pprint field +                              ++ "\nis recursive, but with different type parameters than " +                              ++ "the return type of this constructor. All but the last type " +                              ++ "parameter of the GADT must be uniform over the entire AST." + +          ListT +            | [arg] <- args -> wrapNBField (ListT `AppT`) <$> parseField retpars arg + +          TupleT k +            | length args == k -> do +                positions <- traverse (parseField retpars) args +                case traverse (\case NBFConst t -> Just t ; _ -> Nothing) positions of +                  Just l -> return (NBFConst (foldl' AppT (TupleT k) l)) +                  Nothing -> do +                    let indices = concatMap (\case NBFRecur _ is -> is ; NBFConst _ -> []) positions +                    let reconstruct [] [] = [] +                        reconstruct [] _ = error "Invalid number of replacing recursive positions" +                        reconstruct (NBFRecur ctx pars : poss) repls = +                          let (pre, post) = splitAt (length pars) repls +                          in ctx pre : reconstruct poss post +                        reconstruct (NBFConst _ : poss) repls = reconstruct poss repls +                    return (NBFRecur (foldl' AppT (TupleT k) . reconstruct positions) indices) + +          _ -> do +            when (pprint astname `infixOf` pprint field) $ +              reportWarning $ "Field\n  " ++ pprint field ++ "\nseems to refer to " +                              ++ pprint astname ++ " in unsupported ways; ignoring those occurrences." +            return (NBFConst field) + +  let parseConstr con = do +        (vars, ctx, names, fields, retty) <- case con of +          ForallC vars ctx (GadtC names fields retty) -> return (vars, ctx, names, fields, retty) +          GadtC names fields retty -> return ([], [], names, fields, retty) +          _ -> fail "Unsupported constructors found" +        (retpars, retindex) <- parseRetty astname (head names) retty +        fields' <- traverse (\(ba, t) -> (ba,) <$> parseField retpars t) fields +        return (NBCon (map nameBase names) vars ctx fields' retpars retindex) + +  constrs' <- traverse parseConstr constrs +  return (NBAST (nameBase astname) params constrs') + +  -- | Define a new GADT that is a base-functor-like version of a given existing  -- GADT AST.  -- @@ -72,6 +153,21 @@ cleanupBndr (KindedTV name x k) | isType k = PlainTV name x          isType _ = False  cleanupBndr b = b +parseRetty :: Name -> Name -> Type -> Q ([Name], Type) +parseRetty astname consname retty = do +  case splitApps retty of +    (ConT name, args) +      | name /= astname -> fail $ "Could not parse return type of constructor " ++ pprint consname +      | null args -> fail "Expected GADT to have type parameters" + +      | Just varnames <- traverse (\case VarT varname -> Just varname ; _ -> Nothing) (init args) +      , allDistinct varnames -> +          return (init varnames, last args) + +      | otherwise -> fail $ "All type parameters but the last one must be uniform over all constructors. " +                            ++ "(Return type of constructor " ++ pprint consname ++ ")" +    _ -> fail $ "Could not parse return type of constructor " ++ pprint consname +  fixRetty :: Name -> Name -> Name -> Name -> [TyVarBndr a] -> Type -> Q Type  fixRetty basename astname consname recvar vars retty = do    case splitApps retty of  | 
