From 3a71d9c6c61afa3efb6bc190bf1ddae644ca0dff Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Wed, 28 Aug 2024 23:28:57 +0200 Subject: WIP TH for non-base ASTs --- sharing-recovery.cabal | 3 +++ test/Arith/NonBase.hs | 45 +++++++++++++++++++++++++++++++++ test/Main.hs | 1 + test/NonBaseTH.hs | 68 ++++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 117 insertions(+) create mode 100644 test/Arith/NonBase.hs create mode 100644 test/NonBaseTH.hs diff --git a/sharing-recovery.cabal b/sharing-recovery.cabal index 1a7228c..74cc7dc 100644 --- a/sharing-recovery.cabal +++ b/sharing-recovery.cabal @@ -26,9 +26,12 @@ test-suite test main-is: Main.hs other-modules: Arith + Arith.NonBase + NonBaseTH hs-source-dirs: test build-depends: sharing-recovery, base, + template-haskell, default-language: Haskell2010 ghc-options: -Wall diff --git a/test/Arith/NonBase.hs b/test/Arith/NonBase.hs new file mode 100644 index 0000000..79c4428 --- /dev/null +++ b/test/Arith/NonBase.hs @@ -0,0 +1,45 @@ +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE StandaloneKindSignatures #-} +module Arith.NonBase where + +import Data.Kind +import Data.Type.Equality + +-- import NonBaseTH + + +data Typ t where + TInt :: Typ Int + TBool :: Typ Bool + TPair :: Typ a -> Typ b -> Typ (a, b) + TFun :: Typ a -> Typ b -> Typ (a -> b) +deriving instance Show (Typ t) + +instance TestEquality Typ where + testEquality TInt TInt = Just Refl + testEquality TBool TBool = Just Refl + testEquality (TPair a b) (TPair a' b') + | Just Refl <- testEquality a a' + , Just Refl <- testEquality b b' + = Just Refl + testEquality (TFun a b) (TFun a' b') + | Just Refl <- testEquality a a' + , Just Refl <- testEquality b b' + = Just Refl + testEquality _ _ = Nothing + +data PrimOp a b where + POAddI :: PrimOp (Int, Int) Int + POMulI :: PrimOp (Int, Int) Int + POEqI :: PrimOp (Int, Int) Bool +deriving instance Show (PrimOp a b) + +type Arith :: Type -> Type +data Arith t where + A_Var :: Typ t -> String -> Arith t + A_Let :: String -> Typ a -> Arith a -> Arith b -> Arith b + A_Prim :: PrimOp a b -> Arith a -> Arith b + A_Pair :: Arith a -> Arith b -> Arith (a, b) + A_If :: Arith Bool -> Arith a -> Arith a -> Arith a + A_Mono :: Arith Bool -> Arith Bool diff --git a/test/Main.hs b/test/Main.hs index 1a8d8e1..5a4d335 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -6,6 +6,7 @@ module Main where import Data.Expr.SharingRecovery +import Data.Expr.SharingRecovery.Internal import Arith diff --git a/test/NonBaseTH.hs b/test/NonBaseTH.hs new file mode 100644 index 0000000..712b680 --- /dev/null +++ b/test/NonBaseTH.hs @@ -0,0 +1,68 @@ +{-# LANGUAGE LambdaCase #-} +module NonBaseTH where + +import Data.List (sort) +import Language.Haskell.TH + + +-- | Define a new GADT that is a base-functor-like version of a given existing +-- GADT AST. +-- +-- Remember to use 'lookupTypeName' or 'lookupValueName' instead of normal +-- quotes in case of punning of data types and constructors. +defineBaseAST + :: Name -- ^ Name of the (base-functor-like) data type to define + -> Name -- ^ Name of the GADT to process + -> [Name] -- ^ Constructors to exclude (chiefly Var, Let, Lam) + -> Q [Dec] +defineBaseAST basename astname excludes = do + info <- reify astname + (params, constrs) <- case info of + TyConI (DataD [] _ params Nothing constrs _) -> return (params, constrs) + _ -> fail $ "Unsupported datatype: " ++ pprint astname + + let recvar = mkName "r" + + let detectRec :: BangType -> Q (Maybe Type) + detectRec (_, field) = _ + + let processConstr con = do + (vars, ctx, names, fields, retty) <- case con of + ForallC vars ctx (GadtC names fields retty) -> return (vars, ctx, names, fields, retty) + GadtC names fields retty -> return ([], [], names, fields, retty) + _ -> fail "Unsupported constructors found" + checkRetty astname (head names) vars retty + _ + + constrs' <- concat <$> traverse processConstr constrs + _ + +checkRetty :: Name -> Name -> [TyVarBndr a] -> Type -> Q () +checkRetty astname consname vars retty = do + case splitApps retty of + (ConT name, args) + | name /= astname -> fail $ "Could not parse return type of constructor " ++ pprint consname + | null args -> fail "Expected GADT to have type parameters" + + | Just varnames <- traverse (\case VarT varname -> Just varname ; _ -> Nothing) (init args) + , allDistinct varnames + , all (`elem` map bndrName vars) varnames -> + return () + + | otherwise -> fail $ "All type parameters but the last one must be uniform over all constructors. " + ++ "(Return type of constructor " ++ pprint consname ++ ")" + _ -> fail $ "Could not parse return type of constructor " ++ pprint consname + +splitApps :: Type -> (Type, [Type]) +splitApps = flip go [] + where go (AppT t arg) tl = go t (arg : tl) + go t tl = (t, tl) + +allDistinct :: Ord a => [a] -> Bool +allDistinct l = + let sorted = sort l + in all (uncurry (/=)) (zip sorted (drop 1 sorted)) + +bndrName :: TyVarBndr a -> Name +bndrName (PlainTV n _) = n +bndrName (KindedTV n _ _) = n -- cgit v1.2.3-70-g09d2