diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-08-29 11:08:50 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-08-29 11:08:50 +0200 |
commit | 1f441a57c55e9d038144f1ec92a54387dcb0ae6d (patch) | |
tree | 60a9eab5b91dc8f557362e77a22ac232daaf7994 /test | |
parent | 3a71d9c6c61afa3efb6bc190bf1ddae644ca0dff (diff) |
TH conversion works somewhat
Diffstat (limited to 'test')
-rw-r--r-- | test/Arith/NonBase.hs | 5 | ||||
-rw-r--r-- | test/NonBaseTH.hs | 68 |
2 files changed, 58 insertions, 15 deletions
diff --git a/test/Arith/NonBase.hs b/test/Arith/NonBase.hs index 79c4428..ea27863 100644 --- a/test/Arith/NonBase.hs +++ b/test/Arith/NonBase.hs @@ -1,12 +1,13 @@ {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TemplateHaskell #-} module Arith.NonBase where import Data.Kind import Data.Type.Equality --- import NonBaseTH +import NonBaseTH data Typ t where @@ -43,3 +44,5 @@ data Arith t where A_Pair :: Arith a -> Arith b -> Arith (a, b) A_If :: Arith Bool -> Arith a -> Arith a -> Arith a A_Mono :: Arith Bool -> Arith Bool + +defineBaseAST "ArithF" ''Arith ['A_Var, 'A_Let] (("AF_"++) . drop 2) diff --git a/test/NonBaseTH.hs b/test/NonBaseTH.hs index 712b680..f3c34f3 100644 --- a/test/NonBaseTH.hs +++ b/test/NonBaseTH.hs @@ -1,7 +1,11 @@ {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE TemplateHaskellQuotes #-} module NonBaseTH where -import Data.List (sort) +import Control.Monad (when) +import Data.List (sort, foldl', tails) +import Data.Maybe (fromMaybe) import Language.Haskell.TH @@ -11,34 +15,65 @@ import Language.Haskell.TH -- Remember to use 'lookupTypeName' or 'lookupValueName' instead of normal -- quotes in case of punning of data types and constructors. defineBaseAST - :: Name -- ^ Name of the (base-functor-like) data type to define + :: String -- ^ Name of the (base-functor-like) data type to define -> Name -- ^ Name of the GADT to process - -> [Name] -- ^ Constructors to exclude (chiefly Var, Let, Lam) + -> [Name] -- ^ Constructors to exclude (Var and Let, plus any other scoping construct) + -> (String -> String) -- ^ Constructor renaming function -> Q [Dec] -defineBaseAST basename astname excludes = do +defineBaseAST basename astname excludes renameConstr = do info <- reify astname (params, constrs) <- case info of TyConI (DataD [] _ params Nothing constrs _) -> return (params, constrs) _ -> fail $ "Unsupported datatype: " ++ pprint astname - let recvar = mkName "r" + let basename' = mkName basename + recvar = mkName "r" - let detectRec :: BangType -> Q (Maybe Type) - detectRec (_, field) = _ + let detectRec :: Type -> Q (Maybe Type) + detectRec field = do + let (core, args) = splitApps field + case core of + ConT n + | n == astname -> return (Just (VarT recvar `AppT` last args)) + ListT + | [arg] <- args -> fmap (ListT `AppT`) <$> detectRec arg + TupleT k + | length args == k -> fmap (foldl' AppT (TupleT k)) . sequence <$> traverse detectRec args + _ -> do + when (pprint astname `infixOf` pprint field) $ + reportWarning $ "Field\n " ++ pprint field ++ "\nseems to refer to " + ++ pprint astname ++ " in unsupported ways; ignoring those occurrences." + return Nothing let processConstr con = do (vars, ctx, names, fields, retty) <- case con of ForallC vars ctx (GadtC names fields retty) -> return (vars, ctx, names, fields, retty) GadtC names fields retty -> return ([], [], names, fields, retty) _ -> fail "Unsupported constructors found" - checkRetty astname (head names) vars retty - _ + case filter (`notElem` excludes) names of + [] -> return [] + names' -> do + let names'' = map (mkName . renameConstr . nameBase) names' + retty' <- fixRetty basename' astname (head names) recvar vars retty + fields' <- traverse (\(ba, t) -> (ba,) . fromMaybe t <$> detectRec t) fields + return [ForallC (map cleanupBndr vars ++ [PlainTV recvar SpecifiedSpec]) ctx (GadtC names'' fields' retty')] + let params' = map cleanupBndr (init params ++ [PlainTV recvar (), last params]) constrs' <- concat <$> traverse processConstr constrs - _ + return [DataD [] (mkName basename) params' Nothing constrs' []] -checkRetty :: Name -> Name -> [TyVarBndr a] -> Type -> Q () -checkRetty astname consname vars retty = do +-- | Remove `:: Type` annotations because they unnecessarily require the user +-- to enable KindSignatures. Any other annotations we leave, in case the user +-- wrote them and they are necessary. +cleanupBndr :: TyVarBndr a -> TyVarBndr a +cleanupBndr (KindedTV name x k) | isType k = PlainTV name x + where isType StarT = True + isType (ConT n) | n == ''Type = True + isType _ = False +cleanupBndr b = b + +fixRetty :: Name -> Name -> Name -> Name -> [TyVarBndr a] -> Type -> Q Type +fixRetty basename astname consname recvar vars retty = do case splitApps retty of (ConT name, args) | name /= astname -> fail $ "Could not parse return type of constructor " ++ pprint consname @@ -47,7 +82,7 @@ checkRetty astname consname vars retty = do | Just varnames <- traverse (\case VarT varname -> Just varname ; _ -> Nothing) (init args) , allDistinct varnames , all (`elem` map bndrName vars) varnames -> - return () + return (foldl' AppT (ConT basename) (init args ++ [VarT recvar, last args])) | otherwise -> fail $ "All type parameters but the last one must be uniform over all constructors. " ++ "(Return type of constructor " ++ pprint consname ++ ")" @@ -55,7 +90,8 @@ checkRetty astname consname vars retty = do splitApps :: Type -> (Type, [Type]) splitApps = flip go [] - where go (AppT t arg) tl = go t (arg : tl) + where go (ParensT t) tl = go t tl + go (AppT t arg) tl = go t (arg : tl) go t tl = (t, tl) allDistinct :: Ord a => [a] -> Bool @@ -66,3 +102,7 @@ allDistinct l = bndrName :: TyVarBndr a -> Name bndrName (PlainTV n _) = n bndrName (KindedTV n _ _) = n + +infixOf :: Eq a => [a] -> [a] -> Bool +short `infixOf` long = any (`startsWith` short) (tails long) + where a `startsWith` b = take (length b) a == b |