diff options
Diffstat (limited to 'test')
| -rw-r--r-- | test/Arith/NonBase.hs | 45 | ||||
| -rw-r--r-- | test/Main.hs | 1 | ||||
| -rw-r--r-- | test/NonBaseTH.hs | 68 | 
3 files changed, 114 insertions, 0 deletions
diff --git a/test/Arith/NonBase.hs b/test/Arith/NonBase.hs new file mode 100644 index 0000000..79c4428 --- /dev/null +++ b/test/Arith/NonBase.hs @@ -0,0 +1,45 @@ +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE StandaloneKindSignatures #-} +module Arith.NonBase where + +import Data.Kind +import Data.Type.Equality + +-- import NonBaseTH + + +data Typ t where +  TInt :: Typ Int +  TBool :: Typ Bool +  TPair :: Typ a -> Typ b -> Typ (a, b) +  TFun :: Typ a -> Typ b -> Typ (a -> b) +deriving instance Show (Typ t) + +instance TestEquality Typ where +  testEquality TInt TInt = Just Refl +  testEquality TBool TBool = Just Refl +  testEquality (TPair a b) (TPair a' b') +    | Just Refl <- testEquality a a' +    , Just Refl <- testEquality b b' +    = Just Refl +  testEquality (TFun a b) (TFun a' b') +    | Just Refl <- testEquality a a' +    , Just Refl <- testEquality b b' +    = Just Refl +  testEquality _ _ = Nothing + +data PrimOp a b where +  POAddI :: PrimOp (Int, Int) Int +  POMulI :: PrimOp (Int, Int) Int +  POEqI :: PrimOp (Int, Int) Bool +deriving instance Show (PrimOp a b) + +type Arith :: Type -> Type +data Arith t where +  A_Var :: Typ t -> String -> Arith t +  A_Let :: String -> Typ a -> Arith a -> Arith b -> Arith b +  A_Prim :: PrimOp a b -> Arith a -> Arith b +  A_Pair :: Arith a -> Arith b -> Arith (a, b) +  A_If :: Arith Bool -> Arith a -> Arith a -> Arith a +  A_Mono :: Arith Bool -> Arith Bool diff --git a/test/Main.hs b/test/Main.hs index 1a8d8e1..5a4d335 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -6,6 +6,7 @@  module Main where  import Data.Expr.SharingRecovery +import Data.Expr.SharingRecovery.Internal  import Arith diff --git a/test/NonBaseTH.hs b/test/NonBaseTH.hs new file mode 100644 index 0000000..712b680 --- /dev/null +++ b/test/NonBaseTH.hs @@ -0,0 +1,68 @@ +{-# LANGUAGE LambdaCase #-} +module NonBaseTH where + +import Data.List (sort) +import Language.Haskell.TH + + +-- | Define a new GADT that is a base-functor-like version of a given existing +-- GADT AST. +-- +-- Remember to use 'lookupTypeName' or 'lookupValueName' instead of normal +-- quotes in case of punning of data types and constructors. +defineBaseAST +  :: Name  -- ^ Name of the (base-functor-like) data type to define +  -> Name  -- ^ Name of the GADT to process +  -> [Name]  -- ^ Constructors to exclude (chiefly Var, Let, Lam) +  -> Q [Dec] +defineBaseAST basename astname excludes = do +  info <- reify astname +  (params, constrs) <- case info of +    TyConI (DataD [] _ params Nothing constrs _) -> return (params, constrs) +    _ -> fail $ "Unsupported datatype: " ++ pprint astname + +  let recvar = mkName "r" + +  let detectRec :: BangType -> Q (Maybe Type) +      detectRec (_, field) = _ + +  let processConstr con = do +        (vars, ctx, names, fields, retty) <- case con of +          ForallC vars ctx (GadtC names fields retty) -> return (vars, ctx, names, fields, retty) +          GadtC names fields retty -> return ([], [], names, fields, retty) +          _ -> fail "Unsupported constructors found" +        checkRetty astname (head names) vars retty +        _ + +  constrs' <- concat <$> traverse processConstr constrs +  _ + +checkRetty :: Name -> Name -> [TyVarBndr a] -> Type -> Q () +checkRetty astname consname vars retty = do +  case splitApps retty of +    (ConT name, args) +      | name /= astname -> fail $ "Could not parse return type of constructor " ++ pprint consname +      | null args -> fail "Expected GADT to have type parameters" + +      | Just varnames <- traverse (\case VarT varname -> Just varname ; _ -> Nothing) (init args) +      , allDistinct varnames +      , all (`elem` map bndrName vars) varnames -> +          return () + +      | otherwise -> fail $ "All type parameters but the last one must be uniform over all constructors. " +                            ++ "(Return type of constructor " ++ pprint consname ++ ")" +    _ -> fail $ "Could not parse return type of constructor " ++ pprint consname + +splitApps :: Type -> (Type, [Type]) +splitApps = flip go [] +  where go (AppT t arg) tl = go t (arg : tl) +        go t tl = (t, tl) + +allDistinct :: Ord a => [a] -> Bool +allDistinct l = +  let sorted = sort l +  in all (uncurry (/=)) (zip sorted (drop 1 sorted)) + +bndrName :: TyVarBndr a -> Name +bndrName (PlainTV n _) = n +bndrName (KindedTV n _ _) = n  | 
