summaryrefslogtreecommitdiff
path: root/src/AST/Pretty.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/AST/Pretty.hs')
-rw-r--r--src/AST/Pretty.hs225
1 files changed, 146 insertions, 79 deletions
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index 24bacdb..4190f32 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -1,16 +1,26 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE EmptyCase #-}
+{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeOperators #-}
-module AST.Pretty (ppExpr, ppTy) where
+module AST.Pretty (ppExpr, ppTy, PrettyX(..)) where
import Control.Monad (ap)
import Data.List (intersperse)
import Data.Functor.Const
+import Data.String (fromString)
+import Prettyprinter
+import Prettyprinter.Render.String
+
+import qualified Data.Text.Lazy as TL
+import qualified Prettyprinter.Render.Terminal as PT
+import System.Console.ANSI (hSupportsANSI)
+import System.IO (stdout)
+import System.IO.Unsafe (unsafePerformIO)
import AST
import AST.Count
@@ -18,6 +28,17 @@ import CHAD.Types
import Data
+class PrettyX x where
+ prettyX :: x t -> String
+
+ prettyXsuffix :: x t -> String
+ prettyXsuffix x = "<" ++ prettyX x ++ ">"
+
+instance PrettyX (Const ()) where
+ prettyX _ = ""
+ prettyXsuffix _ = ""
+
+
type SVal = SList (Const String)
newtype M a = M { runM :: Int -> (a, Int) }
@@ -43,8 +64,8 @@ genNameIfUsedIn' prefix ty idx ex
genNameIfUsedIn :: STy a -> Idx env a -> Expr x env t -> M String
genNameIfUsedIn = genNameIfUsedIn' "x"
-ppExpr :: SList f env -> Expr x env t -> String
-ppExpr senv e = fst (runM (mkVal senv >>= \val -> ppExpr' 0 val e) 1) ""
+ppExpr :: PrettyX x => SList f env -> Expr x env t -> String
+ppExpr senv e = render $ fst (runM (mkVal senv >>= \val -> ppExpr' 0 val e) 1)
where
mkVal :: SList f env -> M (SVal env)
mkVal SNil = return SNil
@@ -53,34 +74,35 @@ ppExpr senv e = fst (runM (mkVal senv >>= \val -> ppExpr' 0 val e) 1) ""
name <- genName
return (Const name `SCons` val)
-ppExpr' :: Int -> SVal env -> Expr x env t -> M ShowS
-ppExpr' d val = \case
- EVar _ _ i -> return $ showString $ getConst $ slistIdx val i
+ppExpr' :: PrettyX x => Int -> SVal env -> Expr x env t -> M ADoc
+ppExpr' d val expr = case expr of
+ EVar _ _ i -> return $ ppString (getConst (slistIdx val i)) <> ppX expr
e@ELet{} -> ppExprLet d val e
EPair _ a b -> do
a' <- ppExpr' 0 val a
b' <- ppExpr' 0 val b
- return $ showString "(" . a' . showString ", " . b' . showString ")"
+ return $ group $ flatAlt (align $ ppString "(" <> a' <> hardline <> ppString "," <> b' <> ppString ")" <> ppX expr)
+ (ppString "(" <> a' <> ppString "," <+> b' <> ppString ")" <> ppX expr)
EFst _ e -> do
e' <- ppExpr' 11 val e
- return $ showParen (d > 10) $ showString "fst " . e'
+ return $ ppParen (d > 10) $ ppString "fst" <> ppX expr <+> e'
ESnd _ e -> do
e' <- ppExpr' 11 val e
- return $ showParen (d > 10) $ showString "snd " . e'
+ return $ ppParen (d > 10) $ ppString "snd" <> ppX expr <+> e'
- ENil _ -> return $ showString "()"
+ ENil _ -> return $ ppString "()"
EInl _ _ e -> do
e' <- ppExpr' 11 val e
- return $ showParen (d > 10) $ showString "Inl " . e'
+ return $ ppParen (d > 10) $ ppString "Inl" <> ppX expr <+> e'
EInr _ _ e -> do
e' <- ppExpr' 11 val e
- return $ showParen (d > 10) $ showString "Inr " . e'
+ return $ ppParen (d > 10) $ ppString "Inr" <> ppX expr <+> e'
ECase _ e a b -> do
e' <- ppExpr' 0 val e
@@ -89,15 +111,17 @@ ppExpr' d val = \case
a' <- ppExpr' 0 (Const name1 `SCons` val) a
name2 <- genNameIfUsedIn t2 IZ b
b' <- ppExpr' 0 (Const name2 `SCons` val) b
- return $ showParen (d > 0) $
- showString "case " . e' . showString (" of { Inl " ++ name1 ++ " -> ") . a'
- . showString (" ; Inr " ++ name2 ++ " -> ") . b' . showString " }"
+ return $ ppParen (d > 0) $
+ hang 2 $
+ annotate AKey (ppString "case") <> ppX expr <+> e' <+> annotate AKey (ppString "of")
+ <> hardline <> ppString "Inl" <+> ppString name1 <+> ppString "->" <+> a'
+ <> hardline <> ppString "Inr" <+> ppString name2 <+> ppString "->" <+> b'
- ENothing _ _ -> return $ showString "nothing"
+ ENothing _ _ -> return $ ppString "Nothing"
EJust _ e -> do
e' <- ppExpr' 11 val e
- return $ showParen (d > 10) $ showString "Just " . e'
+ return $ ppParen (d > 10) $ ppString "Just" <> ppX expr <+> e'
EMaybe _ a b e -> do
let STMaybe t = typeOf e
@@ -105,18 +129,23 @@ ppExpr' d val = \case
name <- genNameIfUsedIn t IZ b
b' <- ppExpr' 11 (Const name `SCons` val) b
e' <- ppExpr' 11 val e
- return $ showParen (d > 10) $
- showString "maybe " . a' . showString " " . b' . showString " " . e'
+ return $ ppParen (d > 10) $
+ ppApp (ppString "maybe" <> ppX expr) [a', b', e']
EConstArr _ _ ty v
- | Dict <- scalRepIsShow ty -> return $ showsPrec d v
+ | Dict <- scalRepIsShow ty -> return $ ppString (showsPrec d v "") <> ppX expr
EBuild _ n a b -> do
a' <- ppExpr' 11 val a
name <- genNameIfUsedIn' "i" (tTup (sreplicate n tIx)) IZ b
e' <- ppExpr' 0 (Const name `SCons` val) b
- return $ showParen (d > 10) $
- showString "build " . a' . showString (" (\\" ++ name ++ " -> ") . e' . showString ")"
+ return $ group $ flatAlt
+ (ppParen (d > 0) $
+ hang 2 $
+ annotate AHighlight (ppString "build") <> ppX expr <+> a'
+ <+> ppString "$" <+> ppString "\\" <> ppString name <+> ppString "->"
+ <> hardline <> e')
+ (ppApp (annotate AHighlight (ppString "build") <> ppX expr) [a', ppLam [ppString name] e'])
EFold1Inner _ a b c -> do
name1 <- genNameIfUsedIn (typeOf a) (IS IZ) a
@@ -124,65 +153,64 @@ ppExpr' d val = \case
a' <- ppExpr' 0 (Const name2 `SCons` Const name1 `SCons` val) a
b' <- ppExpr' 11 val b
c' <- ppExpr' 11 val c
- return $ showParen (d > 10) $
- showString ("fold1i (\\" ++ name1 ++ " " ++ name2 ++ " -> ") . a'
- . showString ") " . b' . showString " " . c'
+ return $ ppParen (d > 10) $
+ ppApp (annotate AHighlight (ppString "fold1i") <> ppX expr) [ppLam [ppString name1, ppString name2] a', b', c']
ESum1Inner _ e -> do
e' <- ppExpr' 11 val e
- return $ showParen (d > 10) $ showString "sum1i " . e'
+ return $ ppParen (d > 10) $ ppString "sum1i" <> ppX expr <+> e'
EUnit _ e -> do
e' <- ppExpr' 11 val e
- return $ showParen (d > 10) $ showString "unit " . e'
+ return $ ppParen (d > 10) $ ppString "unit" <> ppX expr <+> e'
EReplicate1Inner _ a b -> do
a' <- ppExpr' 11 val a
b' <- ppExpr' 11 val b
- return $ showParen (d > 10) $ showString "replicate1i " . a' . showString " " . b'
+ return $ ppParen (d > 10) $ ppApp (ppString "replicate1i" <> ppX expr) [a', b']
EMaximum1Inner _ e -> do
e' <- ppExpr' 11 val e
- return $ showParen (d > 10) $ showString "maximum1i " . e'
+ return $ ppParen (d > 10) $ ppString "maximum1i" <> ppX expr <+> e'
EMinimum1Inner _ e -> do
e' <- ppExpr' 11 val e
- return $ showParen (d > 10) $ showString "minimum1i " . e'
+ return $ ppParen (d > 10) $ ppString "minimum1i" <> ppX expr <+> e'
EConst _ ty v
- | Dict <- scalRepIsShow ty -> return $ showsPrec d v
+ | Dict <- scalRepIsShow ty -> return $ ppString (showsPrec d v "") <> ppX expr
EIdx0 _ e -> do
e' <- ppExpr' 11 val e
- return $ showParen (d > 10) $ showString "idx0 " . e'
+ return $ ppParen (d > 10) $ ppString "idx0" <> ppX expr <+> e'
EIdx1 _ a b -> do
a' <- ppExpr' 9 val a
b' <- ppExpr' 9 val b
- return $ showParen (d > 8) $ a' . showString " .! " . b'
+ return $ ppParen (d > 8) $ a' <+> ppString ".!" <> ppX expr <+> b'
EIdx _ a b -> do
a' <- ppExpr' 9 val a
b' <- ppExpr' 10 val b
- return $ showParen (d > 8) $
- a' . showString " ! " . b'
+ return $ ppParen (d > 8) $
+ a' <+> ppString "!" <+> b'
EShape _ e -> do
e' <- ppExpr' 11 val e
- return $ showParen (d > 10) $ showString "shape " . e'
+ return $ ppParen (d > 10) $ ppString "shape" <> ppX expr <+> e'
EOp _ op (EPair _ a b)
| (Infix, ops) <- operator op -> do
a' <- ppExpr' 9 val a
b' <- ppExpr' 9 val b
- return $ showParen (d > 8) $ a' . showString (" " ++ ops ++ " ") . b'
+ return $ ppParen (d > 8) $ a' <+> ppString ops <> ppX expr <+> b'
EOp _ op e -> do
e' <- ppExpr' 11 val e
let ops = case operator op of
(Infix, s) -> "(" ++ s ++ ")"
(Prefix, s) -> s
- return $ showParen (d > 10) $ showString (ops ++ " ") . e'
+ return $ ppParen (d > 10) $ ppString ops <> ppX expr <+> e'
ECustom _ t1 t2 t3 a b c e1 e2 -> do
en1 <- genNameIfUsedIn t1 (IS IZ) a
@@ -196,46 +224,53 @@ ppExpr' d val = \case
c' <- ppExpr' 11 (Const dn2 `SCons` Const dn1 `SCons` SNil) c
e1' <- ppExpr' 11 val e1
e2' <- ppExpr' 11 val e2
- return $ showParen (d > 10) $ showString "custom "
- . showString ("(\\" ++ en1 ++ " " ++ en2 ++ " -> ") . a' . showString ") "
- . showString ("(\\" ++ pn1 ++ " " ++ pn2 ++ " -> ") . b' . showString ") "
- . showString ("(\\" ++ dn1 ++ " " ++ dn2 ++ " -> ") . c' . showString ") "
- . e1' . showString " "
- . e2'
+ return $ ppParen (d > 10) $
+ ppApp (ppString "custom" <> ppX expr)
+ [ppLam [ppString en1, ppString en2] a'
+ ,ppLam [ppString pn1, ppString pn2] b'
+ ,ppLam [ppString dn1, ppString dn2] c'
+ ,e1'
+ ,e2']
EWith _ e1 e2 -> do
e1' <- ppExpr' 11 val e1
name <- genNameIfUsedIn' "ac" (STAccum (typeOf e1)) IZ e2
e2' <- ppExpr' 0 (Const name `SCons` val) e2
- return $ showParen (d > 10) $
- showString "with " . e1' . showString (" (\\" ++ name ++ " -> ")
- . e2' . showString ")"
+ return $ group $ flatAlt
+ (ppParen (d > 0) $
+ hang 2 $
+ annotate AWith (ppString "with") <> ppX expr <+> e1'
+ <+> ppString "$" <+> ppString "\\" <> ppString name <+> ppString "->"
+ <> hardline <> e2')
+ (ppApp (annotate AWith (ppString "with") <> ppX expr) [e1', ppLam [ppString name] e2'])
EAccum _ i e1 e2 e3 -> do
e1' <- ppExpr' 11 val e1
e2' <- ppExpr' 11 val e2
e3' <- ppExpr' 11 val e3
- return $ showParen (d > 10) $
- showString ("accum " ++ show (fromSNat i) ++ " ") . e1' . showString " " . e2' . showString " " . e3'
+ return $ ppParen (d > 10) $
+ ppApp (annotate AMonoid (ppString "accum") <> ppX expr) [ppString (show (fromSNat i)), e1', e2', e3']
- EZero _ t -> return $ showString ("(zero :: " ++ ppTy 0 t ++ ")")
+ EZero _ t -> return $ parens $
+ annotate AMonoid (ppString "zero") <> ppX expr <+> ppString "::" <+> ppTy' 0 t <> ppString ")"
EPlus _ _ a b -> do
a' <- ppExpr' 11 val a
b' <- ppExpr' 11 val b
- return $ showParen (d > 10) $ showString "plus " . a' . showString " " . b'
+ return $ ppParen (d > 10) $
+ ppApp (annotate AMonoid (ppString "plus") <> ppX expr) [a', b']
EOneHot _ _ i a b -> do
a' <- ppExpr' 11 val a
b' <- ppExpr' 11 val b
- return $ showParen (d > 10) $
- showString ("onehot " ++ show (fromSNat i) ++ " ") . a' . showString " " . b'
+ return $ ppParen (d > 10) $
+ ppApp (annotate AMonoid (ppString "onehot") <> ppX expr) [ppString (show (fromSNat i)), a', b']
- EError _ _ s -> return $ showParen (d > 10) $ showString ("error " ++ show s)
+ EError _ _ s -> return $ ppParen (d > 10) $ ppString "error" <> ppX expr <+> ppString (show s)
-ppExprLet :: Int -> SVal env -> Expr x env t -> M ShowS
+ppExprLet :: PrettyX x => Int -> SVal env -> Expr x env t -> M ADoc
ppExprLet d val etop = do
- let collect :: SVal env -> Expr x env t -> M ([(String, Occ, ShowS)], ShowS)
+ let collect :: PrettyX x => SVal env -> Expr x env t -> M ([(String, Occ, ADoc)], ADoc)
collect val' (ELet _ rhs body) = do
let occ = occCount IZ body
name <- genNameIfUsedIn (typeOf rhs) IZ body
@@ -246,18 +281,24 @@ ppExprLet d val etop = do
(binds, core) <- collect val etop
- let (open, close) = case binds of
- [_] -> ("", "")
- _ -> ("{ ", " }")
- return $ showParen (d > 0) $
- showString ("let " ++ open)
- . foldr (.) id
- (intersperse (showString " ; ")
- (map (\(name, _occ, rhs) ->
- showString (name ++ {- " (" ++ show _occ ++ ")" ++ -} " = ") . rhs)
- binds))
- . showString (close ++ " in ")
- . core
+ return $ ppParen (d > 0) $
+ align $
+ annotate AKey (ppString "let")
+ <+> align (mconcat $ intersperse hardline $
+ map (\(name, _occ, rhs) ->
+ ppString (name ++ {- " (" ++ show _occ ++ ")" ++ -} " = ") <> rhs)
+ binds)
+ <> hardline <> annotate AKey (ppString "in") <+> core
+
+ppApp :: ADoc -> [ADoc] -> ADoc
+ppApp fun args = group $ fun <+> align (sep args)
+
+ppLam :: [ADoc] -> ADoc -> ADoc
+ppLam args body = ppString "(" <> hang 2 (ppString "\\" <> sep (args ++ [ppString "->"])
+ <> softline <> body <> ppString ")")
+
+ppX :: PrettyX x => Expr x env t -> ADoc
+ppX expr = ppString $ prettyXsuffix (extOf expr)
data Fixity = Prefix | Infix
deriving (Show)
@@ -281,19 +322,45 @@ operator OLog{} = (Prefix, "log")
operator OIDiv{} = (Infix, "`div`")
ppTy :: Int -> STy t -> String
-ppTy d ty = ppTys d ty ""
-
-ppTys :: Int -> STy t -> ShowS
-ppTys _ STNil = showString "1"
-ppTys d (STPair a b) = showParen (d > 7) $ ppTys 8 a . showString " * " . ppTys 8 b
-ppTys d (STEither a b) = showParen (d > 6) $ ppTys 7 a . showString " + " . ppTys 7 b
-ppTys d (STMaybe t) = showParen (d > 10) $ showString "Maybe " . ppTys 11 t
-ppTys d (STArr n t) = showParen (d > 10) $
- showString "Arr " . shows (fromSNat n) . showString " " . ppTys 11 t
-ppTys _ (STScal sty) = showString $ case sty of
+ppTy d ty = render $ ppTy' d ty
+
+ppTy' :: Int -> STy t -> Doc q
+ppTy' _ STNil = ppString "1"
+ppTy' d (STPair a b) = ppParen (d > 7) $ ppTy' 8 a <> ppString " * " <> ppTy' 8 b
+ppTy' d (STEither a b) = ppParen (d > 6) $ ppTy' 7 a <> ppString " + " <> ppTy' 7 b
+ppTy' d (STMaybe t) = ppParen (d > 10) $ ppString "Maybe " <> ppTy' 11 t
+ppTy' d (STArr n t) = ppParen (d > 10) $
+ ppString "Arr " <> ppString (show (fromSNat n)) <> ppString " " <> ppTy' 11 t
+ppTy' _ (STScal sty) = ppString $ case sty of
STI32 -> "i32"
STI64 -> "i64"
STF32 -> "f32"
STF64 -> "f64"
STBool -> "bool"
-ppTys d (STAccum t) = showParen (d > 10) $ showString "Accum " . ppTys 11 t
+ppTy' d (STAccum t) = ppParen (d > 10) $ ppString "Accum " <> ppTy' 11 t
+
+ppString :: String -> Doc x
+ppString = fromString
+
+ppParen :: Bool -> Doc x -> Doc x
+ppParen True = parens
+ppParen False = id
+
+data Annot = AKey | AWith | AHighlight | AMonoid
+ deriving (Show)
+
+annotToANSI :: Annot -> PT.AnsiStyle
+annotToANSI AKey = PT.bold
+annotToANSI AWith = PT.color PT.Red <> PT.underlined
+annotToANSI AHighlight = PT.color PT.Blue
+annotToANSI AMonoid = PT.color PT.Green
+
+type ADoc = Doc Annot
+
+render :: Doc Annot -> String
+render =
+ (if stdoutTTY then TL.unpack . PT.renderLazy . reAnnotateS annotToANSI
+ else renderString)
+ . layoutPretty LayoutOptions { layoutPageWidth = AvailablePerLine 120 1.0 }
+ where
+ stdoutTTY = unsafePerformIO $ hSupportsANSI stdout