aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-08-28 23:28:57 +0200
committerTom Smeding <tom@tomsmeding.com>2024-08-28 23:28:57 +0200
commit3a71d9c6c61afa3efb6bc190bf1ddae644ca0dff (patch)
tree847bf9d34f5c55d218a51e5ec2290b173fcdafef
parent912d262c8aef92657b8991d05b7fe39dcb5b5fd4 (diff)
WIP TH for non-base ASTs
-rw-r--r--sharing-recovery.cabal3
-rw-r--r--test/Arith/NonBase.hs45
-rw-r--r--test/Main.hs1
-rw-r--r--test/NonBaseTH.hs68
4 files changed, 117 insertions, 0 deletions
diff --git a/sharing-recovery.cabal b/sharing-recovery.cabal
index 1a7228c..74cc7dc 100644
--- a/sharing-recovery.cabal
+++ b/sharing-recovery.cabal
@@ -26,9 +26,12 @@ test-suite test
main-is: Main.hs
other-modules:
Arith
+ Arith.NonBase
+ NonBaseTH
hs-source-dirs: test
build-depends:
sharing-recovery,
base,
+ template-haskell,
default-language: Haskell2010
ghc-options: -Wall
diff --git a/test/Arith/NonBase.hs b/test/Arith/NonBase.hs
new file mode 100644
index 0000000..79c4428
--- /dev/null
+++ b/test/Arith/NonBase.hs
@@ -0,0 +1,45 @@
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+module Arith.NonBase where
+
+import Data.Kind
+import Data.Type.Equality
+
+-- import NonBaseTH
+
+
+data Typ t where
+ TInt :: Typ Int
+ TBool :: Typ Bool
+ TPair :: Typ a -> Typ b -> Typ (a, b)
+ TFun :: Typ a -> Typ b -> Typ (a -> b)
+deriving instance Show (Typ t)
+
+instance TestEquality Typ where
+ testEquality TInt TInt = Just Refl
+ testEquality TBool TBool = Just Refl
+ testEquality (TPair a b) (TPair a' b')
+ | Just Refl <- testEquality a a'
+ , Just Refl <- testEquality b b'
+ = Just Refl
+ testEquality (TFun a b) (TFun a' b')
+ | Just Refl <- testEquality a a'
+ , Just Refl <- testEquality b b'
+ = Just Refl
+ testEquality _ _ = Nothing
+
+data PrimOp a b where
+ POAddI :: PrimOp (Int, Int) Int
+ POMulI :: PrimOp (Int, Int) Int
+ POEqI :: PrimOp (Int, Int) Bool
+deriving instance Show (PrimOp a b)
+
+type Arith :: Type -> Type
+data Arith t where
+ A_Var :: Typ t -> String -> Arith t
+ A_Let :: String -> Typ a -> Arith a -> Arith b -> Arith b
+ A_Prim :: PrimOp a b -> Arith a -> Arith b
+ A_Pair :: Arith a -> Arith b -> Arith (a, b)
+ A_If :: Arith Bool -> Arith a -> Arith a -> Arith a
+ A_Mono :: Arith Bool -> Arith Bool
diff --git a/test/Main.hs b/test/Main.hs
index 1a8d8e1..5a4d335 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -6,6 +6,7 @@
module Main where
import Data.Expr.SharingRecovery
+import Data.Expr.SharingRecovery.Internal
import Arith
diff --git a/test/NonBaseTH.hs b/test/NonBaseTH.hs
new file mode 100644
index 0000000..712b680
--- /dev/null
+++ b/test/NonBaseTH.hs
@@ -0,0 +1,68 @@
+{-# LANGUAGE LambdaCase #-}
+module NonBaseTH where
+
+import Data.List (sort)
+import Language.Haskell.TH
+
+
+-- | Define a new GADT that is a base-functor-like version of a given existing
+-- GADT AST.
+--
+-- Remember to use 'lookupTypeName' or 'lookupValueName' instead of normal
+-- quotes in case of punning of data types and constructors.
+defineBaseAST
+ :: Name -- ^ Name of the (base-functor-like) data type to define
+ -> Name -- ^ Name of the GADT to process
+ -> [Name] -- ^ Constructors to exclude (chiefly Var, Let, Lam)
+ -> Q [Dec]
+defineBaseAST basename astname excludes = do
+ info <- reify astname
+ (params, constrs) <- case info of
+ TyConI (DataD [] _ params Nothing constrs _) -> return (params, constrs)
+ _ -> fail $ "Unsupported datatype: " ++ pprint astname
+
+ let recvar = mkName "r"
+
+ let detectRec :: BangType -> Q (Maybe Type)
+ detectRec (_, field) = _
+
+ let processConstr con = do
+ (vars, ctx, names, fields, retty) <- case con of
+ ForallC vars ctx (GadtC names fields retty) -> return (vars, ctx, names, fields, retty)
+ GadtC names fields retty -> return ([], [], names, fields, retty)
+ _ -> fail "Unsupported constructors found"
+ checkRetty astname (head names) vars retty
+ _
+
+ constrs' <- concat <$> traverse processConstr constrs
+ _
+
+checkRetty :: Name -> Name -> [TyVarBndr a] -> Type -> Q ()
+checkRetty astname consname vars retty = do
+ case splitApps retty of
+ (ConT name, args)
+ | name /= astname -> fail $ "Could not parse return type of constructor " ++ pprint consname
+ | null args -> fail "Expected GADT to have type parameters"
+
+ | Just varnames <- traverse (\case VarT varname -> Just varname ; _ -> Nothing) (init args)
+ , allDistinct varnames
+ , all (`elem` map bndrName vars) varnames ->
+ return ()
+
+ | otherwise -> fail $ "All type parameters but the last one must be uniform over all constructors. "
+ ++ "(Return type of constructor " ++ pprint consname ++ ")"
+ _ -> fail $ "Could not parse return type of constructor " ++ pprint consname
+
+splitApps :: Type -> (Type, [Type])
+splitApps = flip go []
+ where go (AppT t arg) tl = go t (arg : tl)
+ go t tl = (t, tl)
+
+allDistinct :: Ord a => [a] -> Bool
+allDistinct l =
+ let sorted = sort l
+ in all (uncurry (/=)) (zip sorted (drop 1 sorted))
+
+bndrName :: TyVarBndr a -> Name
+bndrName (PlainTV n _) = n
+bndrName (KindedTV n _ _) = n