diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-09-02 12:05:11 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-09-02 12:05:11 +0200 |
commit | b4906aed78519919f556053c4ecc0b646df3cece (patch) | |
tree | c1a26161c4408d76bebeecf1d42957d10fdfb13a | |
parent | ecb5f11260e39be824b1b45c5bcb93bb9e137132 (diff) |
-rw-r--r-- | sharing-recovery.cabal | 1 | ||||
-rw-r--r-- | test/Arith/NonBase.hs | 4 | ||||
-rw-r--r-- | test/NonBaseTH.hs | 231 |
3 files changed, 130 insertions, 106 deletions
diff --git a/sharing-recovery.cabal b/sharing-recovery.cabal index 74cc7dc..9e34bab 100644 --- a/sharing-recovery.cabal +++ b/sharing-recovery.cabal @@ -32,6 +32,7 @@ test-suite test build-depends: sharing-recovery, base, + containers, template-haskell, default-language: Haskell2010 ghc-options: -Wall diff --git a/test/Arith/NonBase.hs b/test/Arith/NonBase.hs index ea27863..f5d458e 100644 --- a/test/Arith/NonBase.hs +++ b/test/Arith/NonBase.hs @@ -45,4 +45,6 @@ data Arith t where A_If :: Arith Bool -> Arith a -> Arith a -> Arith a A_Mono :: Arith Bool -> Arith Bool -defineBaseAST "ArithF" ''Arith ['A_Var, 'A_Let] (("AF_"++) . drop 2) +defineBaseAST + "ArithF" ''Arith ['A_Var, 'A_Let] (("AF_"++) . drop 2) + "arithConv" ''Typ (\_ _ _ -> [| error "Lambda impossible" |]) diff --git a/test/NonBaseTH.hs b/test/NonBaseTH.hs index c6cf54b..4741ea0 100644 --- a/test/NonBaseTH.hs +++ b/test/NonBaseTH.hs @@ -1,40 +1,39 @@ {-# LANGUAGE LambdaCase #-} -{-# LANGUAGE TupleSections #-} {-# LANGUAGE TemplateHaskellQuotes #-} -module NonBaseTH where +{-# LANGUAGE TupleSections #-} +module NonBaseTH ( + defineBaseAST, +) where import Control.Monad (when) import Data.List (sort, foldl', tails) -import Data.Maybe (fromMaybe) +import qualified Data.Map.Strict as Map +import Data.Maybe (catMaybes) import Language.Haskell.TH +import Data.Expr.SharingRecovery + -- | Non-base AST data type data NBAST = NBAST String [TyVarBndr ()] [NBCon] data NBCon = NBCon - [String] -- ^ Names of the constructors defined with this clause - [TyVarBndrSpec] -- ^ Type variables in scope in the clause - Cxt -- ^ Constraints on the constructor type - [(Bang, NBField)] -- ^ Constructor fields - [Name] -- ^ All but the last type parameter in the result type - Type -- ^ The last type parameter in the result type - -data NBField - = NBFRecur ([Type] -> Type) - -- ^ Context: takes replacement recursive positions and wraps - -- them in whatever structure they came in (e.g. a tuple). - [Type] - -- ^ The last type parameter of all the recursive positions in - -- this field. (The other type parameters are fixed anyway.) + [String] -- ^ Names of the constructors defined with this clause + [TyVarBndrSpec] -- ^ Type variables in scope in the clause + Cxt -- ^ Constraints on the constructor type + [(Bang, NBField Type)] -- ^ Constructor fields + [Name] -- ^ All but the last type parameter in the result type + Type -- ^ The last type parameter in the result type + +data NBField t + = NBFList (NBField t) + | NBFTuple [NBField t] + | NBFRecur t -- ^ The last type parameter of this recursive position. (The + -- other type parameters are fixed anyway.) | NBFConst Type -wrapNBField :: (Type -> Type) -> NBField -> NBField -wrapNBField f (NBFRecur ctx param) = NBFRecur (f . ctx) param -wrapNBField f (NBFConst t) = NBFConst (f t) - -parseNBAST :: Info -> Q NBAST -parseNBAST info = do +parseNBAST :: [String] -> Info -> Q NBAST +parseNBAST excludes info = do (astname, params, constrs) <- case info of TyConI (DataD [] astname params Nothing constrs _) -> return (astname, params, constrs) _ -> fail "Unsupported datatype" @@ -47,29 +46,17 @@ parseNBAST info = do ConT n | n == astname -> if not (null args) && init args == map VarT retpars - then return (NBFRecur head [last args]) + then return (NBFRecur (last args)) else fail $ "Field\n " ++ pprint field ++ "\nis recursive, but with different type parameters than " ++ "the return type of this constructor. All but the last type " ++ "parameter of the GADT must be uniform over the entire AST." ListT - | [arg] <- args -> wrapNBField (ListT `AppT`) <$> parseField retpars arg + | [arg] <- args -> NBFList <$> parseField retpars arg TupleT k - | length args == k -> do - positions <- traverse (parseField retpars) args - case traverse (\case NBFConst t -> Just t ; _ -> Nothing) positions of - Just l -> return (NBFConst (foldl' AppT (TupleT k) l)) - Nothing -> do - let indices = concatMap (\case NBFRecur _ is -> is ; NBFConst _ -> []) positions - let reconstruct [] [] = [] - reconstruct [] _ = error "Invalid number of replacing recursive positions" - reconstruct (NBFRecur ctx pars : poss) repls = - let (pre, post) = splitAt (length pars) repls - in ctx pre : reconstruct poss post - reconstruct (NBFConst _ : poss) repls = reconstruct poss repls - return (NBFRecur (foldl' AppT (TupleT k) . reconstruct positions) indices) + | length args == k -> NBFTuple <$> traverse (parseField retpars) args _ -> do when (pprint astname `infixOf` pprint field) $ @@ -77,16 +64,23 @@ parseNBAST info = do ++ pprint astname ++ " in unsupported ways; ignoring those occurrences." return (NBFConst field) - let parseConstr con = do - (vars, ctx, names, fields, retty) <- case con of - ForallC vars ctx (GadtC names fields retty) -> return (vars, ctx, names, fields, retty) - GadtC names fields retty -> return ([], [], names, fields, retty) - _ -> fail "Unsupported constructors found" + let splitConstr (ForallC vars ctx (GadtC names fields retty)) + | names'@(_:_) <- filter (`notElem` excludes) (map nameBase names) = + return (Just (vars, ctx, names', fields, retty)) + | otherwise = return Nothing + splitConstr c@GadtC{} = splitConstr (ForallC [] [] c) + splitConstr c = + let names = case c of + NormalC n _ -> Just (show n) + _ -> Just (show c) + in fail $ "Unsupported constructors found" ++ maybe "" (\s -> ": " ++ show s) names + + let parseConstr (vars, ctx, names, fields, retty) = do (retpars, retindex) <- parseRetty astname (head names) retty fields' <- traverse (\(ba, t) -> (ba,) <$> parseField retpars t) fields - return (NBCon (map nameBase names) vars ctx fields' retpars retindex) + return (NBCon names vars ctx fields' retpars retindex) - constrs' <- traverse parseConstr constrs + constrs' <- traverse parseConstr =<< catMaybes <$> traverse splitConstr constrs return (NBAST (nameBase astname) params constrs') @@ -100,48 +94,91 @@ defineBaseAST -> Name -- ^ Name of the GADT to process -> [Name] -- ^ Constructors to exclude (Var and Let, plus any other scoping construct) -> (String -> String) -- ^ Constructor renaming function + -> String -- ^ Name of base -> nonbase conversion function to define + -> Name -- ^ Type of singleton types + -> (ExpQ -> ExpQ -> ExpQ -> ExpQ) -- ^ Lambda: typ (a -> b) -> typ a -> AST b -> AST (a -> b) + -> (ExpQ -> ExpQ -> ExpQ -> ExpQ) -- ^ Let: typ a -> String -> AST a -> AST b -> AST b + -> (ExpQ -> ExpQ -> ExpQ -> ExpQ) -- ^ Var: typ t -> String -> AST b -> AST b -> Q [Dec] -defineBaseAST basename astname excludes renameConstr = do - info <- reify astname - (params, constrs) <- case info of - TyConI (DataD [] _ params Nothing constrs _) -> return (params, constrs) - _ -> fail $ "Unsupported datatype: " ++ pprint astname +defineBaseAST basename astname excludes renameConstr bnConvName typName mkLam mkLet mkVar = do + NBAST _ params constrs <- parseNBAST (map nameBase excludes) =<< reify astname + + -- Build the data type let basename' = mkName basename - recvar = mkName "r" + conNameMap = Map.fromList [(nbname, mkName (renameConstr nbname)) + | NBCon ns _ _ _ _ _ <- constrs, nbname <- ns] - let detectRec :: Type -> Q (Maybe Type) - detectRec field = do - let (core, args) = splitApps field - case core of - ConT n - | n == astname -> return (Just (VarT recvar `AppT` last args)) - ListT - | [arg] <- args -> fmap (ListT `AppT`) <$> detectRec arg - TupleT k - | length args == k -> fmap (foldl' AppT (TupleT k)) . sequence <$> traverse detectRec args - _ -> do - when (pprint astname `infixOf` pprint field) $ - reportWarning $ "Field\n " ++ pprint field ++ "\nseems to refer to " - ++ pprint astname ++ " in unsupported ways; ignoring those occurrences." - return Nothing - - let processConstr con = do - (vars, ctx, names, fields, retty) <- case con of - ForallC vars ctx (GadtC names fields retty) -> return (vars, ctx, names, fields, retty) - GadtC names fields retty -> return ([], [], names, fields, retty) - _ -> fail "Unsupported constructors found" - case filter (`notElem` excludes) names of - [] -> return [] - names' -> do - let names'' = map (mkName . renameConstr . nameBase) names' - retty' <- fixRetty basename' astname (head names) recvar vars retty - fields' <- traverse (\(ba, t) -> (ba,) . fromMaybe t <$> detectRec t) fields - return [ForallC (map cleanupBndr vars ++ [PlainTV recvar SpecifiedSpec]) ctx (GadtC names'' fields' retty')] + let recvar = mkName "r" + + let processField (NBFRecur idx) = VarT recvar `AppT` idx + processField (NBFConst t) = t + processField (NBFList f) = AppT ListT (processField f) + processField (NBFTuple fs) = foldl' AppT (TupleT (length fs)) (map processField fs) + + let processConstr (NBCon names vars ctx fields retparams retindex) = do + let names' = map (conNameMap Map.!) names + let fields' = map (\(ba, f) -> (ba, processField f)) fields + let retty' = foldl' AppT (ConT basename') (map VarT retparams ++ [VarT recvar, retindex]) + return [ForallC (map cleanupBndr vars ++ [PlainTV recvar SpecifiedSpec]) + ctx (GadtC names' fields' retty')] let params' = map cleanupBndr (init params ++ [PlainTV recvar (), last params]) constrs' <- concat <$> traverse processConstr constrs - return [DataD [] (mkName basename) params' Nothing constrs' []] + let datatype = DataD [] (mkName basename) params' Nothing constrs' [] + + -- Build the B->N conversion function + + let bnConvName' = mkName bnConvName + envparam = VarT (mkName "env") + tparam = VarT (mkName "t") + ftype = foldl' AppT (ConT basename') (map (VarT . bndrName) (init params)) + arrow a b = ArrowT `AppT` a `AppT` b + + let bnConvSig = + SigD bnConvName' $ + (ConT ''BExpr `AppT` ConT typName `AppT` envparam `AppT` ftype `AppT` tparam) + `arrow` + foldl' AppT (ConT astname) (map (VarT . bndrName) (init params) ++ [tparam]) + + let clause' pats ex = Clause pats (NormalB ex) [] + -- backConMap = Map.fromList [(renameConstr nbname, nbname) + -- | NBCon ns _ _ _ _ _ <- constrs, nbname <- ns] + reconstructField (NBFRecur _) = do + r <- newName "r" + return (VarP r, VarE bnConvName' `AppE` VarE r) + reconstructField (NBFConst _) = do + x <- newName "x" + return (VarP x, VarE x) + reconstructField (NBFList f) = do + (pat, ex) <- reconstructField f + l <- newName "l" + return (VarP l, VarE 'map `AppE` LamE [pat] ex `AppE` VarE l) + reconstructField (NBFTuple fs) = do + (pats, exps) <- unzip <$> traverse reconstructField fs + return (TupP pats, TupE (map Just exps)) + + let tyarg1 = mkName "ty1" + tyarg2 = mkName "ty2" + astarg = mkName "ast" + mkseq a b = VarE 'seq `AppE` a `AppE` b + infixr `mkseq` + + bnConvFun <- + fmap (FunD bnConvName') . sequence $ + [do (pats, exps) <- unzip <$> traverse (reconstructField . snd) fields + return $ + clause' [ConP 'BOp [] [WildP, ConP (conNameMap Map.! nbname) [] pats]] + (foldl' AppE (ConE (mkName nbname)) exps) + | NBCon names _ _ fields _ _ <- constrs + , nbname <- names] + ++ + [do body <- mkLam (varE tyarg1) (varE tyarg2) (varE astarg) + return $ clause' [ConP 'BLam [] [VarP tyarg1, VarP tyarg2, VarP astarg]] + -- put them in a useless seq so that they're not unused + (TupE [Just (VarE tyarg1), Just (VarE tyarg2), Just (VarE astarg)] `mkseq` body)] + + return [datatype, bnConvSig, bnConvFun] -- | Remove `:: Type` annotations because they unnecessarily require the user -- to enable KindSignatures. Any other annotations we leave, in case the user @@ -153,36 +190,24 @@ cleanupBndr (KindedTV name x k) | isType k = PlainTV name x isType _ = False cleanupBndr b = b -parseRetty :: Name -> Name -> Type -> Q ([Name], Type) +bndrName :: TyVarBndr a -> Name +bndrName (PlainTV name _) = name +bndrName (KindedTV name _ _) = name + +parseRetty :: Name -> String -> Type -> Q ([Name], Type) parseRetty astname consname retty = do case splitApps retty of (ConT name, args) - | name /= astname -> fail $ "Could not parse return type of constructor " ++ pprint consname + | name /= astname -> fail $ "Could not parse return type of constructor " ++ consname | null args -> fail "Expected GADT to have type parameters" | Just varnames <- traverse (\case VarT varname -> Just varname ; _ -> Nothing) (init args) , allDistinct varnames -> - return (init varnames, last args) - - | otherwise -> fail $ "All type parameters but the last one must be uniform over all constructors. " - ++ "(Return type of constructor " ++ pprint consname ++ ")" - _ -> fail $ "Could not parse return type of constructor " ++ pprint consname - -fixRetty :: Name -> Name -> Name -> Name -> [TyVarBndr a] -> Type -> Q Type -fixRetty basename astname consname recvar vars retty = do - case splitApps retty of - (ConT name, args) - | name /= astname -> fail $ "Could not parse return type of constructor " ++ pprint consname - | null args -> fail "Expected GADT to have type parameters" - - | Just varnames <- traverse (\case VarT varname -> Just varname ; _ -> Nothing) (init args) - , allDistinct varnames - , all (`elem` map bndrName vars) varnames -> - return (foldl' AppT (ConT basename) (init args ++ [VarT recvar, last args])) + return (varnames, last args) | otherwise -> fail $ "All type parameters but the last one must be uniform over all constructors. " - ++ "(Return type of constructor " ++ pprint consname ++ ")" - _ -> fail $ "Could not parse return type of constructor " ++ pprint consname + ++ "(Return type of constructor " ++ consname ++ ")" + _ -> fail $ "Could not parse return type of constructor " ++ consname splitApps :: Type -> (Type, [Type]) splitApps = flip go [] @@ -195,10 +220,6 @@ allDistinct l = let sorted = sort l in all (uncurry (/=)) (zip sorted (drop 1 sorted)) -bndrName :: TyVarBndr a -> Name -bndrName (PlainTV n _) = n -bndrName (KindedTV n _ _) = n - infixOf :: Eq a => [a] -> [a] -> Bool short `infixOf` long = any (`startsWith` short) (tails long) where a `startsWith` b = take (length b) a == b |