{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE StandaloneDeriving #-}

-- | A small typed arithmetic language built on top of
-- "Data.Expr.SharingRecovery": type representations ('Typ'), primitive
-- operations ('PrimOp'), the expression base functor ('ArithF'), and a
-- precedence-aware pretty-printer for one layer of it ('prettyArithF').
module Arith where

import Data.Type.Equality

import Data.Expr.SharingRecovery

-- | Runtime representation of the object language's types.
data Typ t where
  TInt  :: Typ Int
  TBool :: Typ Bool
  TPair :: Typ a -> Typ b -> Typ (a, b)
  TFun  :: Typ a -> Typ b -> Typ (a -> b)

deriving instance Show (Typ t)

-- | Structural comparison of type representations: yields 'Refl' exactly
-- when both sides are built from the same constructors, recursively.
instance TestEquality Typ where
  testEquality TInt TInt = Just Refl
  testEquality TBool TBool = Just Refl
  testEquality (TPair a b) (TPair a' b')
    | Just Refl <- testEquality a a'
    , Just Refl <- testEquality b b'
    = Just Refl
  testEquality (TFun a b) (TFun a' b')
    | Just Refl <- testEquality a a'
    , Just Refl <- testEquality b b'
    = Just Refl
  testEquality _ _ = Nothing

-- | Types whose 'Typ' representation can be produced implicitly.
class KnownType t where
  τ :: Typ t

instance KnownType Int where
  τ = TInt

instance KnownType Bool where
  τ = TBool

instance (KnownType a, KnownType b) => KnownType (a, b) where
  τ = TPair τ τ

instance (KnownType a, KnownType b) => KnownType (a -> b) where
  τ = TFun τ τ

-- | Primitive operations. A @'PrimOp' a b@ consumes a value of type @a@
-- (for every operator here, a pair of 'Int' operands) and yields a @b@.
data PrimOp a b where
  POAddI :: PrimOp (Int, Int) Int
  POMulI :: PrimOp (Int, Int) Int
  POEqI  :: PrimOp (Int, Int) Bool

deriving instance Show (PrimOp a b)

-- | The result type of a primitive operation.
opType2 :: PrimOp a b -> Typ b
opType2 = \case
  POAddI -> TInt
  POMulI -> TInt
  POEqI  -> TBool

-- | How an operator symbol is rendered: bare between its operands (@+@),
-- or parenthesised in prefix position (@(+)@).
data Fixity = Infix | Prefix
  deriving (Show)

-- | Precedence table for infix printing, as @(p, (l, r))@: the operator
-- binds at precedence @p@ (it is parenthesised when the surrounding
-- precedence exceeds @p@), and its left and right operands are printed at
-- precedences @l@ and @r@. So @(p, (p, p + 1))@ is left-associative and
-- @(p, (p + 1, p + 1))@ is non-associative.
primOpPrec :: PrimOp a b -> (Int, (Int, Int))
primOpPrec POAddI = (6, (6, 7))
primOpPrec POMulI = (7, (7, 8))
primOpPrec POEqI  = (4, (5, 5))

-- | Render the symbol of a primitive operation in the given 'Fixity'.
prettyPrimOp :: Fixity -> PrimOp a b -> ShowS
prettyPrimOp fix op = showString $ case fix of
    Infix  -> sym
    Prefix -> "(" ++ sym ++ ")"
  where
    sym :: String
    sym = case op of
      POAddI -> "+"
      POMulI -> "*"
      POEqI  -> "=="

-- | One layer of an arithmetic expression; @r@ is the type of the
-- recursive children.
data ArithF r t where
  -- Apply a primitive operation to its (single) argument.
  A_Prim :: PrimOp a b -> r a -> ArithF r b
  -- Build a pair.
  A_Pair :: r a -> r b -> ArithF r (a, b)
  -- Conditional: condition, then-branch, else-branch.
  A_If   :: r Bool -> r a -> r a -> ArithF r a

deriving instance (forall a. Show (r a)) => Show (ArithF r t)

-- No methods are given here; NOTE(review): this presumably relies on
-- 'Functor1' default method(s) derived from 'Traversable1' — confirm
-- against "Data.Expr.SharingRecovery".
instance Functor1 ArithF

-- | Visits the children of each node from left to right.
instance Traversable1 ArithF where
  traverse1 f (A_Prim op x) = A_Prim op <$> f x
  traverse1 f (A_Pair x y) = A_Pair <$> f x <*> f y
  traverse1 f (A_If x y z) = A_If <$> f x <*> f y <*> f z

-- | Pretty-print one 'ArithF' layer at ambient precedence @d@, rendering
-- sub-expressions with @pr@ (which receives the precedence to print at).
--
-- * A primitive whose argument is literally a pair node is printed infix,
--   with parentheses and operand precedences taken from 'primOpPrec'.
-- * A 'BLet' directly under a primitive is floated outward: the let is
--   handed back to @pr@ with the primitive pushed into its body.
-- * Any other primitive application is printed prefix, e.g. @(+) x@.
-- * Pairs are printed as @(a, b)@; conditionals as
--   @if c then t else e@, parenthesised whenever @d > 0@.
--
-- Only 'Applicative' is required: no rendered sub-expression influences
-- which other sub-expressions are rendered. Since every 'Monad' is an
-- 'Applicative', this accepts every caller a 'Monad' constraint would.
prettyArithF :: Applicative m
             => (forall a. Int -> BExpr Typ env ArithF a -> m ShowS)
             -> Int -> ArithF (BExpr Typ env ArithF) t -> m ShowS
prettyArithF pr d = \case
  A_Prim op (BOp _ (A_Pair a b)) ->
    let (dop, (dopL, dopR)) = primOpPrec op
        render :: ShowS -> ShowS -> ShowS
        render a' b' =
          showParen (d > dop) $
            a' . showString " " . prettyPrimOp Infix op . showString " " . b'
    in render <$> pr dopL a <*> pr dopR b
  A_Prim op (BLet ty rhs e) ->
    -- op (let x = rhs in e)  ==>  let x = rhs in op e
    pr d (BLet ty rhs (BOp (opType2 op) (A_Prim op e)))
  A_Prim op arg ->
    (\arg' -> showParen (d > 10) $
        prettyPrimOp Prefix op . showString " " . arg')
      <$> pr 11 arg
  A_Pair a b ->
    (\a' b' -> showString "(" . a' . showString ", " . b' . showString ")")
      <$> pr 0 a <*> pr 0 b
  A_If a b c ->
    (\a' b' c' -> showParen (d > 0) $
        showString "if " . a' . showString " then " . b'
          . showString " else " . c')
      <$> pr 0 a <*> pr 0 b <*> pr 0 c