diff options
Diffstat (limited to 'src/AST/Pretty.hs')
-rw-r--r-- | src/AST/Pretty.hs | 225 |
1 files changed, 146 insertions, 79 deletions
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index 24bacdb..4190f32 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -1,16 +1,26 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE EmptyCase #-} +{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeOperators #-} -module AST.Pretty (ppExpr, ppTy) where +module AST.Pretty (ppExpr, ppTy, PrettyX(..)) where import Control.Monad (ap) import Data.List (intersperse) import Data.Functor.Const +import Data.String (fromString) +import Prettyprinter +import Prettyprinter.Render.String + +import qualified Data.Text.Lazy as TL +import qualified Prettyprinter.Render.Terminal as PT +import System.Console.ANSI (hSupportsANSI) +import System.IO (stdout) +import System.IO.Unsafe (unsafePerformIO) import AST import AST.Count @@ -18,6 +28,17 @@ import CHAD.Types import Data +class PrettyX x where + prettyX :: x t -> String + + prettyXsuffix :: x t -> String + prettyXsuffix x = "<" ++ prettyX x ++ ">" + +instance PrettyX (Const ()) where + prettyX _ = "" + prettyXsuffix _ = "" + + type SVal = SList (Const String) newtype M a = M { runM :: Int -> (a, Int) } @@ -43,8 +64,8 @@ genNameIfUsedIn' prefix ty idx ex genNameIfUsedIn :: STy a -> Idx env a -> Expr x env t -> M String genNameIfUsedIn = genNameIfUsedIn' "x" -ppExpr :: SList f env -> Expr x env t -> String -ppExpr senv e = fst (runM (mkVal senv >>= \val -> ppExpr' 0 val e) 1) "" +ppExpr :: PrettyX x => SList f env -> Expr x env t -> String +ppExpr senv e = render $ fst (runM (mkVal senv >>= \val -> ppExpr' 0 val e) 1) where mkVal :: SList f env -> M (SVal env) mkVal SNil = return SNil @@ -53,34 +74,35 @@ ppExpr senv e = fst (runM (mkVal senv >>= \val -> ppExpr' 0 val e) 1) "" name <- genName return (Const name `SCons` val) -ppExpr' :: Int -> SVal env -> Expr x env t -> M ShowS -ppExpr' d val = \case - EVar _ _ i -> return $ showString $ getConst $ slistIdx val i +ppExpr' :: PrettyX x => Int -> SVal env -> Expr x env t -> M ADoc +ppExpr' d val expr = case expr of + EVar _ _ i -> return $ ppString (getConst (slistIdx val i)) <> ppX expr e@ELet{} -> ppExprLet d val e EPair _ a b -> do a' <- ppExpr' 0 val a b' <- ppExpr' 0 val b - return $ showString "(" . a' . showString ", " . b' . showString ")" + return $ group $ flatAlt (align $ ppString "(" <> a' <> hardline <> ppString "," <> b' <> ppString ")" <> ppX expr) + (ppString "(" <> a' <> ppString "," <+> b' <> ppString ")" <> ppX expr) EFst _ e -> do e' <- ppExpr' 11 val e - return $ showParen (d > 10) $ showString "fst " . e' + return $ ppParen (d > 10) $ ppString "fst" <> ppX expr <+> e' ESnd _ e -> do e' <- ppExpr' 11 val e - return $ showParen (d > 10) $ showString "snd " . e' + return $ ppParen (d > 10) $ ppString "snd" <> ppX expr <+> e' - ENil _ -> return $ showString "()" + ENil _ -> return $ ppString "()" EInl _ _ e -> do e' <- ppExpr' 11 val e - return $ showParen (d > 10) $ showString "Inl " . e' + return $ ppParen (d > 10) $ ppString "Inl" <> ppX expr <+> e' EInr _ _ e -> do e' <- ppExpr' 11 val e - return $ showParen (d > 10) $ showString "Inr " . e' + return $ ppParen (d > 10) $ ppString "Inr" <> ppX expr <+> e' ECase _ e a b -> do e' <- ppExpr' 0 val e @@ -89,15 +111,17 @@ ppExpr' d val = \case a' <- ppExpr' 0 (Const name1 `SCons` val) a name2 <- genNameIfUsedIn t2 IZ b b' <- ppExpr' 0 (Const name2 `SCons` val) b - return $ showParen (d > 0) $ - showString "case " . e' . showString (" of { Inl " ++ name1 ++ " -> ") . a' - . showString (" ; Inr " ++ name2 ++ " -> ") . b' . showString " }" + return $ ppParen (d > 0) $ + hang 2 $ + annotate AKey (ppString "case") <> ppX expr <+> e' <+> annotate AKey (ppString "of") + <> hardline <> ppString "Inl" <+> ppString name1 <+> ppString "->" <+> a' + <> hardline <> ppString "Inr" <+> ppString name2 <+> ppString "->" <+> b' - ENothing _ _ -> return $ showString "nothing" + ENothing _ _ -> return $ ppString "Nothing" EJust _ e -> do e' <- ppExpr' 11 val e - return $ showParen (d > 10) $ showString "Just " . e' + return $ ppParen (d > 10) $ ppString "Just" <> ppX expr <+> e' EMaybe _ a b e -> do let STMaybe t = typeOf e @@ -105,18 +129,23 @@ ppExpr' d val = \case name <- genNameIfUsedIn t IZ b b' <- ppExpr' 11 (Const name `SCons` val) b e' <- ppExpr' 11 val e - return $ showParen (d > 10) $ - showString "maybe " . a' . showString " " . b' . showString " " . e' + return $ ppParen (d > 10) $ + ppApp (ppString "maybe" <> ppX expr) [a', b', e'] EConstArr _ _ ty v - | Dict <- scalRepIsShow ty -> return $ showsPrec d v + | Dict <- scalRepIsShow ty -> return $ ppString (showsPrec d v "") <> ppX expr EBuild _ n a b -> do a' <- ppExpr' 11 val a name <- genNameIfUsedIn' "i" (tTup (sreplicate n tIx)) IZ b e' <- ppExpr' 0 (Const name `SCons` val) b - return $ showParen (d > 10) $ - showString "build " . a' . showString (" (\\" ++ name ++ " -> ") . e' . showString ")" + return $ group $ flatAlt + (ppParen (d > 0) $ + hang 2 $ + annotate AHighlight (ppString "build") <> ppX expr <+> a' + <+> ppString "$" <+> ppString "\\" <> ppString name <+> ppString "->" + <> hardline <> e') + (ppApp (annotate AHighlight (ppString "build") <> ppX expr) [a', ppLam [ppString name] e']) EFold1Inner _ a b c -> do name1 <- genNameIfUsedIn (typeOf a) (IS IZ) a @@ -124,65 +153,64 @@ ppExpr' d val = \case a' <- ppExpr' 0 (Const name2 `SCons` Const name1 `SCons` val) a b' <- ppExpr' 11 val b c' <- ppExpr' 11 val c - return $ showParen (d > 10) $ - showString ("fold1i (\\" ++ name1 ++ " " ++ name2 ++ " -> ") . a' - . showString ") " . b' . showString " " . c' + return $ ppParen (d > 10) $ + ppApp (annotate AHighlight (ppString "fold1i") <> ppX expr) [ppLam [ppString name1, ppString name2] a', b', c'] ESum1Inner _ e -> do e' <- ppExpr' 11 val e - return $ showParen (d > 10) $ showString "sum1i " . e' + return $ ppParen (d > 10) $ ppString "sum1i" <> ppX expr <+> e' EUnit _ e -> do e' <- ppExpr' 11 val e - return $ showParen (d > 10) $ showString "unit " . e' + return $ ppParen (d > 10) $ ppString "unit" <> ppX expr <+> e' EReplicate1Inner _ a b -> do a' <- ppExpr' 11 val a b' <- ppExpr' 11 val b - return $ showParen (d > 10) $ showString "replicate1i " . a' . showString " " . b' + return $ ppParen (d > 10) $ ppApp (ppString "replicate1i" <> ppX expr) [a', b'] EMaximum1Inner _ e -> do e' <- ppExpr' 11 val e - return $ showParen (d > 10) $ showString "maximum1i " . e' + return $ ppParen (d > 10) $ ppString "maximum1i" <> ppX expr <+> e' EMinimum1Inner _ e -> do e' <- ppExpr' 11 val e - return $ showParen (d > 10) $ showString "minimum1i " . e' + return $ ppParen (d > 10) $ ppString "minimum1i" <> ppX expr <+> e' EConst _ ty v - | Dict <- scalRepIsShow ty -> return $ showsPrec d v + | Dict <- scalRepIsShow ty -> return $ ppString (showsPrec d v "") <> ppX expr EIdx0 _ e -> do e' <- ppExpr' 11 val e - return $ showParen (d > 10) $ showString "idx0 " . e' + return $ ppParen (d > 10) $ ppString "idx0" <> ppX expr <+> e' EIdx1 _ a b -> do a' <- ppExpr' 9 val a b' <- ppExpr' 9 val b - return $ showParen (d > 8) $ a' . showString " .! " . b' + return $ ppParen (d > 8) $ a' <+> ppString ".!" <> ppX expr <+> b' EIdx _ a b -> do a' <- ppExpr' 9 val a b' <- ppExpr' 10 val b - return $ showParen (d > 8) $ - a' . showString " ! " . b' + return $ ppParen (d > 8) $ + a' <+> ppString "!" <+> b' EShape _ e -> do e' <- ppExpr' 11 val e - return $ showParen (d > 10) $ showString "shape " . e' + return $ ppParen (d > 10) $ ppString "shape" <> ppX expr <+> e' EOp _ op (EPair _ a b) | (Infix, ops) <- operator op -> do a' <- ppExpr' 9 val a b' <- ppExpr' 9 val b - return $ showParen (d > 8) $ a' . showString (" " ++ ops ++ " ") . b' + return $ ppParen (d > 8) $ a' <+> ppString ops <> ppX expr <+> b' EOp _ op e -> do e' <- ppExpr' 11 val e let ops = case operator op of (Infix, s) -> "(" ++ s ++ ")" (Prefix, s) -> s - return $ showParen (d > 10) $ showString (ops ++ " ") . e' + return $ ppParen (d > 10) $ ppString ops <> ppX expr <+> e' ECustom _ t1 t2 t3 a b c e1 e2 -> do en1 <- genNameIfUsedIn t1 (IS IZ) a @@ -196,46 +224,53 @@ ppExpr' d val = \case c' <- ppExpr' 11 (Const dn2 `SCons` Const dn1 `SCons` SNil) c e1' <- ppExpr' 11 val e1 e2' <- ppExpr' 11 val e2 - return $ showParen (d > 10) $ showString "custom " - . showString ("(\\" ++ en1 ++ " " ++ en2 ++ " -> ") . a' . showString ") " - . showString ("(\\" ++ pn1 ++ " " ++ pn2 ++ " -> ") . b' . showString ") " - . showString ("(\\" ++ dn1 ++ " " ++ dn2 ++ " -> ") . c' . showString ") " - . e1' . showString " " - . e2' + return $ ppParen (d > 10) $ + ppApp (ppString "custom" <> ppX expr) + [ppLam [ppString en1, ppString en2] a' + ,ppLam [ppString pn1, ppString pn2] b' + ,ppLam [ppString dn1, ppString dn2] c' + ,e1' + ,e2'] EWith _ e1 e2 -> do e1' <- ppExpr' 11 val e1 name <- genNameIfUsedIn' "ac" (STAccum (typeOf e1)) IZ e2 e2' <- ppExpr' 0 (Const name `SCons` val) e2 - return $ showParen (d > 10) $ - showString "with " . e1' . showString (" (\\" ++ name ++ " -> ") - . e2' . showString ")" + return $ group $ flatAlt + (ppParen (d > 0) $ + hang 2 $ + annotate AWith (ppString "with") <> ppX expr <+> e1' + <+> ppString "$" <+> ppString "\\" <> ppString name <+> ppString "->" + <> hardline <> e2') + (ppApp (annotate AWith (ppString "with") <> ppX expr) [e1', ppLam [ppString name] e2']) EAccum _ i e1 e2 e3 -> do e1' <- ppExpr' 11 val e1 e2' <- ppExpr' 11 val e2 e3' <- ppExpr' 11 val e3 - return $ showParen (d > 10) $ - showString ("accum " ++ show (fromSNat i) ++ " ") . e1' . showString " " . e2' . showString " " . e3' + return $ ppParen (d > 10) $ + ppApp (annotate AMonoid (ppString "accum") <> ppX expr) [ppString (show (fromSNat i)), e1', e2', e3'] - EZero _ t -> return $ showString ("(zero :: " ++ ppTy 0 t ++ ")") + EZero _ t -> return $ parens $ + annotate AMonoid (ppString "zero") <> ppX expr <+> ppString "::" <+> ppTy' 0 t <> ppString ")" EPlus _ _ a b -> do a' <- ppExpr' 11 val a b' <- ppExpr' 11 val b - return $ showParen (d > 10) $ showString "plus " . a' . showString " " . b' + return $ ppParen (d > 10) $ + ppApp (annotate AMonoid (ppString "plus") <> ppX expr) [a', b'] EOneHot _ _ i a b -> do a' <- ppExpr' 11 val a b' <- ppExpr' 11 val b - return $ showParen (d > 10) $ - showString ("onehot " ++ show (fromSNat i) ++ " ") . a' . showString " " . b' + return $ ppParen (d > 10) $ + ppApp (annotate AMonoid (ppString "onehot") <> ppX expr) [ppString (show (fromSNat i)), a', b'] - EError _ _ s -> return $ showParen (d > 10) $ showString ("error " ++ show s) + EError _ _ s -> return $ ppParen (d > 10) $ ppString "error" <> ppX expr <+> ppString (show s) -ppExprLet :: Int -> SVal env -> Expr x env t -> M ShowS +ppExprLet :: PrettyX x => Int -> SVal env -> Expr x env t -> M ADoc ppExprLet d val etop = do - let collect :: SVal env -> Expr x env t -> M ([(String, Occ, ShowS)], ShowS) + let collect :: PrettyX x => SVal env -> Expr x env t -> M ([(String, Occ, ADoc)], ADoc) collect val' (ELet _ rhs body) = do let occ = occCount IZ body name <- genNameIfUsedIn (typeOf rhs) IZ body @@ -246,18 +281,24 @@ ppExprLet d val etop = do (binds, core) <- collect val etop - let (open, close) = case binds of - [_] -> ("", "") - _ -> ("{ ", " }") - return $ showParen (d > 0) $ - showString ("let " ++ open) - . foldr (.) id - (intersperse (showString " ; ") - (map (\(name, _occ, rhs) -> - showString (name ++ {- " (" ++ show _occ ++ ")" ++ -} " = ") . rhs) - binds)) - . showString (close ++ " in ") - . core + return $ ppParen (d > 0) $ + align $ + annotate AKey (ppString "let") + <+> align (mconcat $ intersperse hardline $ + map (\(name, _occ, rhs) -> + ppString (name ++ {- " (" ++ show _occ ++ ")" ++ -} " = ") <> rhs) + binds) + <> hardline <> annotate AKey (ppString "in") <+> core + +ppApp :: ADoc -> [ADoc] -> ADoc +ppApp fun args = group $ fun <+> align (sep args) + +ppLam :: [ADoc] -> ADoc -> ADoc +ppLam args body = ppString "(" <> hang 2 (ppString "\\" <> sep (args ++ [ppString "->"]) + <> softline <> body <> ppString ")") + +ppX :: PrettyX x => Expr x env t -> ADoc +ppX expr = ppString $ prettyXsuffix (extOf expr) data Fixity = Prefix | Infix deriving (Show) @@ -281,19 +322,45 @@ operator OLog{} = (Prefix, "log") operator OIDiv{} = (Infix, "`div`") ppTy :: Int -> STy t -> String -ppTy d ty = ppTys d ty "" - -ppTys :: Int -> STy t -> ShowS -ppTys _ STNil = showString "1" -ppTys d (STPair a b) = showParen (d > 7) $ ppTys 8 a . showString " * " . ppTys 8 b -ppTys d (STEither a b) = showParen (d > 6) $ ppTys 7 a . showString " + " . ppTys 7 b -ppTys d (STMaybe t) = showParen (d > 10) $ showString "Maybe " . ppTys 11 t -ppTys d (STArr n t) = showParen (d > 10) $ - showString "Arr " . shows (fromSNat n) . showString " " . ppTys 11 t -ppTys _ (STScal sty) = showString $ case sty of +ppTy d ty = render $ ppTy' d ty + +ppTy' :: Int -> STy t -> Doc q +ppTy' _ STNil = ppString "1" +ppTy' d (STPair a b) = ppParen (d > 7) $ ppTy' 8 a <> ppString " * " <> ppTy' 8 b +ppTy' d (STEither a b) = ppParen (d > 6) $ ppTy' 7 a <> ppString " + " <> ppTy' 7 b +ppTy' d (STMaybe t) = ppParen (d > 10) $ ppString "Maybe " <> ppTy' 11 t +ppTy' d (STArr n t) = ppParen (d > 10) $ + ppString "Arr " <> ppString (show (fromSNat n)) <> ppString " " <> ppTy' 11 t +ppTy' _ (STScal sty) = ppString $ case sty of STI32 -> "i32" STI64 -> "i64" STF32 -> "f32" STF64 -> "f64" STBool -> "bool" -ppTys d (STAccum t) = showParen (d > 10) $ showString "Accum " . ppTys 11 t +ppTy' d (STAccum t) = ppParen (d > 10) $ ppString "Accum " <> ppTy' 11 t + +ppString :: String -> Doc x +ppString = fromString + +ppParen :: Bool -> Doc x -> Doc x +ppParen True = parens +ppParen False = id + +data Annot = AKey | AWith | AHighlight | AMonoid + deriving (Show) + +annotToANSI :: Annot -> PT.AnsiStyle +annotToANSI AKey = PT.bold +annotToANSI AWith = PT.color PT.Red <> PT.underlined +annotToANSI AHighlight = PT.color PT.Blue +annotToANSI AMonoid = PT.color PT.Green + +type ADoc = Doc Annot + +render :: Doc Annot -> String +render = + (if stdoutTTY then TL.unpack . PT.renderLazy . reAnnotateS annotToANSI + else renderString) + . layoutPretty LayoutOptions { layoutPageWidth = AvailablePerLine 120 1.0 } + where + stdoutTTY = unsafePerformIO $ hSupportsANSI stdout |