{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeOperators #-}
-- | Pretty-printing of 'Expr' terms and of types to 'String'.
--
-- Output is laid out at a page width of 120 columns; ANSI styling is applied
-- when 'stdout' reports ANSI support (see 'render').
module AST.Pretty (ppExpr, ppSTy, ppTy, PrettyX(..)) where

import Control.Monad (ap)
import Data.List (intersperse)
import Data.Functor.Const
import Data.String (fromString)
import Prettyprinter
import Prettyprinter.Render.String
import qualified Data.Text.Lazy as TL
import qualified Prettyprinter.Render.Terminal as PT
import System.Console.ANSI (hSupportsANSI)
import System.IO (stdout)
import System.IO.Unsafe (unsafePerformIO)

import AST
import AST.Count
import CHAD.Types
import Data


-- | Textual rendering of the per-node extension field of an 'Expr'.
class PrettyX x where
  -- | Render the extension value itself.
  prettyX :: x t -> String

  -- | Render the extension value as a suffix that is glued onto the name of
  -- the construct it annotates. Defaults to the 'prettyX' text wrapped in
  -- angle brackets.
  prettyXsuffix :: x t -> String
  prettyXsuffix x = "<" ++ prettyX x ++ ">"

-- | The trivial extension prints nothing at all (not even empty brackets).
instance PrettyX (Const ()) where
  prettyX _ = ""
  prettyXsuffix _ = ""

-- | Environment of variable names, one 'String' per entry of the typed
-- environment.
type SVal = SList (Const String)

-- | Name-supply monad: threads an 'Int' counter through the computation.
newtype M a = M { runM :: Int -> (a, Int) }
  deriving (Functor)

instance Applicative M where
  pure x = M (\i -> (x, i))
  (<*>) = ap

instance Monad M where
  M f >>= g = M (\i -> let (x, j) = f i in runM (g x) j)

-- | Return the current counter value and increment the counter.
genId :: M Int
genId = M (\i -> (i, i + 1))

-- | Generate a fresh name consisting of the given prefix followed by a fresh
-- number from 'genId'.
genName' :: String -> M String
genName' prefix = (prefix ++) . show <$> genId

-- | 'genName'' with the prefix @"x"@.
genName :: M String
genName = genName' "x"

-- | Choose a binder name for the variable at index @idx@ in the scope of
-- @ex@. If 'occCount' for that index equals 'mempty' (presumably: the
-- variable is not used in @ex@), no fresh name is consumed; the binder is
-- then @"()"@ for a variable of type 'STNil' and @"_"@ otherwise. In all other
-- cases a fresh name with the given prefix is generated.
genNameIfUsedIn' :: String -> STy a -> Idx env a -> Expr x env t -> M String
genNameIfUsedIn' prefix ty idx ex
  | occCount idx ex == mempty = case ty of
      STNil -> return "()"
      _ -> return "_"
  | otherwise = genName' prefix

-- | 'genNameIfUsedIn'' with the prefix @"x"@.
genNameIfUsedIn :: STy a -> Idx env a -> Expr x env t -> M String
genNameIfUsedIn = genNameIfUsedIn' "x"

-- | Pretty-print an expression to a 'String'.
--
-- The first argument is used only for its shape: every entry of the
-- environment gets a freshly generated name (@x1@, @x2@, ...; the name
-- counter starts at 1), and free variables of the expression are printed
-- using those names.
ppExpr :: PrettyX x => SList f env -> Expr x env t -> String
ppExpr senv e = render $ fst (runM (mkVal senv >>= \val -> ppExpr' 0 val e) 1)
  where
    -- Names are assigned to the tail of the environment first, then the head.
    mkVal :: SList f env -> M (SVal env)
    mkVal SNil = return SNil
    mkVal (SCons _ v) = do
      val <- mkVal v
      name <- genName
      return (Const name `SCons` val)

-- | Worker for 'ppExpr'.
--
-- @d@ is the precedence of the surrounding context, in the style of
-- 'showsPrec': 0 means no parentheses are required, 11 means the result is in
-- function-argument position. @val@ supplies the names of the variables in
-- scope.
ppExpr' :: PrettyX x => Int -> SVal env -> Expr x env t -> M ADoc
ppExpr' d val expr = case expr of
  EVar _ _ i -> return $ ppString (getConst (slistIdx val i)) <> ppX expr

  e@ELet{} -> ppExprLet d val e

  EPair _ a b -> do
    a' <- ppExpr' 0 val a
    b' <- ppExpr' 0 val b
    -- Multi-line layout puts the comma at the start of the second line.
    return $ group $ flatAlt
      (align $ ppString "(" <> a' <> hardline <> ppString "," <> b' <> ppString ")" <> ppX expr)
      (ppString "(" <> a' <> ppString "," <+> b' <> ppString ")" <> ppX expr)

  EFst _ e -> do
    e' <- ppExpr' 11 val e
    return $ ppParen (d > 10) $ ppString "fst" <> ppX expr <+> e'

  ESnd _ e -> do
    e' <- ppExpr' 11 val e
    return $ ppParen (d > 10) $ ppString "snd" <> ppX expr <+> e'

  ENil _ -> return $ ppString "()"

  EInl _ _ e -> do
    e' <- ppExpr' 11 val e
    return $ ppParen (d > 10) $ ppString "Inl" <> ppX expr <+> e'

  EInr _ _ e -> do
    e' <- ppExpr' 11 val e
    return $ ppParen (d > 10) $ ppString "Inr" <> ppX expr <+> e'

  ECase _ e a b -> do
    e' <- ppExpr' 0 val e
    let STEither t1 t2 = typeOf e
    name1 <- genNameIfUsedIn t1 IZ a
    a' <- ppExpr' 0 (Const name1 `SCons` val) a
    name2 <- genNameIfUsedIn t2 IZ b
    b' <- ppExpr' 0 (Const name2 `SCons` val) b
    return $ ppParen (d > 0) $
      hang 2 $
        annotate AKey (ppString "case") <> ppX expr <+> e' <+> annotate AKey (ppString "of")
        <> hardline <> ppString "Inl" <+> ppString name1 <+> ppString "->" <+> a'
        <> hardline <> ppString "Inr" <+> ppString name2 <+> ppString "->" <+> b'

  ENothing _ _ -> return $ ppString "Nothing"

  EJust _ e -> do
    e' <- ppExpr' 11 val e
    return $ ppParen (d > 10) $ ppString "Just" <> ppX expr <+> e'

  EMaybe _ a b e -> do
    let STMaybe t = typeOf e
    a' <- ppExpr' 11 val a
    name <- genNameIfUsedIn t IZ b
    -- The handler binds a variable, so it is printed as a lambda; otherwise
    -- the name used inside its body would appear unbound in the output.
    b' <- ppExpr' 0 (Const name `SCons` val) b
    e' <- ppExpr' 11 val e
    return $ ppParen (d > 10) $
      ppApp (ppString "maybe" <> ppX expr) [a', ppLam [ppString name] b', e']

  EConstArr _ _ ty v
    | Dict <- scalRepIsShow ty -> return $ ppString (showsPrec d v "") <> ppX expr

  EBuild _ n a b -> do
    a' <- ppExpr' 11 val a
    name <- genNameIfUsedIn' "i" (tTup (sreplicate n tIx)) IZ b
    e' <- ppExpr' 0 (Const name `SCons` val) b
    return $ group $ flatAlt
      -- Multi-line form: "build n $ \i ->" followed by the body.
      (ppParen (d > 0) $
         hang 2 $
           annotate AHighlight (ppString "build") <> ppX expr <+> a'
           <+> ppString "$" <+> ppString "\\" <> ppString name <+> ppString "->"
           <> hardline <> e')
      -- Single-line form: an ordinary application, which needs parentheses
      -- when it appears in argument position.
      (ppParen (d > 10) $
         ppApp (annotate AHighlight (ppString "build") <> ppX expr) [a', ppLam [ppString name] e'])

  EFold1Inner _ a b c -> do
    name1 <- genNameIfUsedIn (typeOf a) (IS IZ) a
    name2 <- genNameIfUsedIn (typeOf a) IZ a
    a' <- ppExpr' 0 (Const name2 `SCons` Const name1 `SCons` val) a
    b' <- ppExpr' 11 val b
    c' <- ppExpr' 11 val c
    return $ ppParen (d > 10) $
      ppApp (annotate AHighlight (ppString "fold1i") <> ppX expr)
            [ppLam [ppString name1, ppString name2] a', b', c']

  ESum1Inner _ e -> do
    e' <- ppExpr' 11 val e
    return $ ppParen (d > 10) $ ppString "sum1i" <> ppX expr <+> e'

  EUnit _ e -> do
    e' <- ppExpr' 11 val e
    return $ ppParen (d > 10) $ ppString "unit" <> ppX expr <+> e'

  EReplicate1Inner _ a b -> do
    a' <- ppExpr' 11 val a
    b' <- ppExpr' 11 val b
    return $ ppParen (d > 10) $ ppApp (ppString "replicate1i" <> ppX expr) [a', b']

  EMaximum1Inner _ e -> do
    e' <- ppExpr' 11 val e
    return $ ppParen (d > 10) $ ppString "maximum1i" <> ppX expr <+> e'

  EMinimum1Inner _ e -> do
    e' <- ppExpr' 11 val e
    return $ ppParen (d > 10) $ ppString "minimum1i" <> ppX expr <+> e'

  EConst _ ty v
    | Dict <- scalRepIsShow ty -> return $ ppString (showsPrec d v "") <> ppX expr

  EIdx0 _ e -> do
    e' <- ppExpr' 11 val e
    return $ ppParen (d > 10) $ ppString "idx0" <> ppX expr <+> e'

  EIdx1 _ a b -> do
    a' <- ppExpr' 9 val a
    b' <- ppExpr' 9 val b
    return $ ppParen (d > 8) $ a' <+> ppString ".!" <> ppX expr <+> b'

  EIdx _ a b -> do
    a' <- ppExpr' 9 val a
    b' <- ppExpr' 10 val b
    -- The extension suffix is attached to the operator, as for EIdx1.
    return $ ppParen (d > 8) $ a' <+> ppString "!" <> ppX expr <+> b'

  EShape _ e -> do
    e' <- ppExpr' 11 val e
    return $ ppParen (d > 10) $ ppString "shape" <> ppX expr <+> e'

  -- An infix operator applied to a literal pair is printed infix.
  EOp _ op (EPair _ a b)
    | (Infix, ops) <- operator op -> do
        a' <- ppExpr' 9 val a
        b' <- ppExpr' 9 val b
        return $ ppParen (d > 8) $ a' <+> ppString ops <> ppX expr <+> b'

  -- Anything else is printed as a prefix application; infix operators are
  -- put in parentheses to form a section-style name.
  EOp _ op e -> do
    e' <- ppExpr' 11 val e
    let ops = case operator op of
          (Infix, s) -> "(" ++ s ++ ")"
          (Prefix, s) -> s
    return $ ppParen (d > 10) $ ppString ops <> ppX expr <+> e'

  ECustom _ t1 t2 t3 a b c e1 e2 -> do
    en1 <- genNameIfUsedIn t1 (IS IZ) a
    en2 <- genNameIfUsedIn t2 IZ a
    pn1 <- genNameIfUsedIn (d1 t1) (IS IZ) b
    pn2 <- genNameIfUsedIn (d1 t2) IZ b
    dn1 <- genNameIfUsedIn' "tape" t3 (IS IZ) c
    dn2 <- genNameIfUsedIn' "d" (d2 (typeOf a)) IZ c
    -- Each of the three function bodies is closed: it is printed in an
    -- environment containing only the two names bound by its own lambda.
    a' <- ppExpr' 11 (Const en2 `SCons` Const en1 `SCons` SNil) a
    b' <- ppExpr' 11 (Const pn2 `SCons` Const pn1 `SCons` SNil) b
    c' <- ppExpr' 11 (Const dn2 `SCons` Const dn1 `SCons` SNil) c
    e1' <- ppExpr' 11 val e1
    e2' <- ppExpr' 11 val e2
    return $ ppParen (d > 10) $
      ppApp (ppString "custom" <> ppX expr)
            [ppLam [ppString en1, ppString en2] a'
            ,ppLam [ppString pn1, ppString pn2] b'
            ,ppLam [ppString dn1, ppString dn2] c'
            ,e1'
            ,e2']

  EWith _ e1 e2 -> do
    e1' <- ppExpr' 11 val e1
    name <- genNameIfUsedIn' "ac" (STAccum (typeOf e1)) IZ e2
    e2' <- ppExpr' 0 (Const name `SCons` val) e2
    return $ group $ flatAlt
      -- Multi-line form: "with e1 $ \ac ->" followed by the body.
      (ppParen (d > 0) $
         hang 2 $
           annotate AWith (ppString "with") <> ppX expr <+> e1'
           <+> ppString "$" <+> ppString "\\" <> ppString name <+> ppString "->"
           <> hardline <> e2')
      -- Single-line form: an ordinary application, which needs parentheses
      -- when it appears in argument position.
      (ppParen (d > 10) $
         ppApp (annotate AWith (ppString "with") <> ppX expr) [e1', ppLam [ppString name] e2'])

  EAccum _ i e1 e2 e3 -> do
    e1' <- ppExpr' 11 val e1
    e2' <- ppExpr' 11 val e2
    e3' <- ppExpr' 11 val e3
    return $ ppParen (d > 10) $
      ppApp (annotate AMonoid (ppString "accum") <> ppX expr)
            [ppString (show (fromSNat i)), e1', e2', e3']

  -- Always parenthesised because of the type annotation: "(zero :: T)".
  EZero _ t ->
    return $ parens $
      annotate AMonoid (ppString "zero") <> ppX expr <+> ppString "::" <+> ppSTy' 0 t

  EPlus _ _ a b -> do
    a' <- ppExpr' 11 val a
    b' <- ppExpr' 11 val b
    return $ ppParen (d > 10) $
      ppApp (annotate AMonoid (ppString "plus") <> ppX expr) [a', b']

  EOneHot _ _ i a b -> do
    a' <- ppExpr' 11 val a
    b' <- ppExpr' 11 val b
    return $ ppParen (d > 10) $
      ppApp (annotate AMonoid (ppString "onehot") <> ppX expr)
            [ppString (show (fromSNat i)), a', b']

  EError _ _ s -> return $ ppParen (d > 10) $ ppString "error" <> ppX expr <+> ppString (show s)

-- | Pretty-print a chain of nested 'ELet's as a single @let@ block with one
-- binding per line, followed by @in@ and the innermost body. A non-'ELet'
-- expression yields a @let@ block with no bindings.
ppExprLet :: PrettyX x => Int -> SVal env -> Expr x env t -> M ADoc
ppExprLet d val etop = do
  -- Walk down the spine of ELet nodes, collecting (name, occurrence count,
  -- rendered right-hand side) per binding plus the rendered final body.
  let collect :: PrettyX x => SVal env -> Expr x env t -> M ([(String, Occ, ADoc)], ADoc)
      collect val' (ELet _ rhs body) = do
        let occ = occCount IZ body
        name <- genNameIfUsedIn (typeOf rhs) IZ body
        rhs' <- ppExpr' 0 val' rhs
        (binds, core) <- collect (Const name `SCons` val') body
        return ((name, occ, rhs') : binds, core)
      collect val' e = ([],) <$> ppExpr' 0 val' e

  (binds, core) <- collect val etop

  return $ ppParen (d > 0) $ align $
    annotate AKey (ppString "let")
    <+> align (mconcat $ intersperse hardline $
                 map (\(name, _occ, rhs) -> ppString (name ++ {- " (" ++ show _occ ++ ")" ++ -} " = ") <> rhs) binds)
    <> hardline
    <> annotate AKey (ppString "in") <+> core

-- | Function application: the function followed by its arguments, which are
-- laid out either all on one line or aligned one per line.
ppApp :: ADoc -> [ADoc] -> ADoc
ppApp fun args = group $ fun <+> align (sep args)

-- | A parenthesised lambda with the given binder names and body.
ppLam :: [ADoc] -> ADoc -> ADoc
ppLam args body =
  ppString "("
  <> hang 2 (ppString "\\" <> sep (args ++ [ppString "->"]) <> softline <> body <> ppString ")")

-- | The extension suffix ('prettyXsuffix') of the top node of an expression,
-- annotated with 'AExt'.
ppX :: PrettyX x => Expr x env t -> ADoc
ppX expr = annotate AExt $ ppString $ prettyXsuffix (extOf expr)

-- | Whether an operator is written before its argument or between its two
-- arguments.
data Fixity = Prefix | Infix
  deriving (Show)

-- | Fixity and printed name of a primitive operator.
operator :: SOp a t -> (Fixity, String)
operator OAdd{} = (Infix, "+")
operator OMul{} = (Infix, "*")
operator ONeg{} = (Prefix, "negate")
operator OLt{} = (Infix, "<")
operator OLe{} = (Infix, "<=")
operator OEq{} = (Infix, "==")
operator ONot = (Prefix, "not")
operator OAnd = (Infix, "&&")
operator OOr = (Infix, "||")
operator OIf = (Prefix, "ifB")
operator ORound64 = (Prefix, "round")
operator OToFl64 = (Prefix, "toFl64")
operator ORecip{} = (Prefix, "recip")
operator OExp{} = (Prefix, "exp")
operator OLog{} = (Prefix, "log")
operator OIDiv{} = (Infix, "`div`")

-- | Pretty-print a singleton type at the given context precedence.
ppSTy :: Int -> STy t -> String
ppSTy d ty = ppTy d (unSTy ty)

-- | 'ppSTy' producing a 'Doc' instead of a 'String'.
ppSTy' :: Int -> STy t -> Doc q
ppSTy' d ty = ppTy' d (unSTy ty)

-- | Pretty-print a type at the given context precedence.
ppTy :: Int -> Ty -> String
ppTy d ty = render $ ppTy' d ty

-- | 'ppTy' producing a 'Doc' instead of a 'String'. Pairs print as @a * b@,
-- sums as @a + b@, and the unit type as @1@.
ppTy' :: Int -> Ty -> Doc q
ppTy' _ TNil = ppString "1"
ppTy' d (TPair a b) = ppParen (d > 7) $ ppTy' 8 a <> ppString " * " <> ppTy' 8 b
ppTy' d (TEither a b) = ppParen (d > 6) $ ppTy' 7 a <> ppString " + " <> ppTy' 7 b
ppTy' d (TMaybe t) = ppParen (d > 10) $ ppString "Maybe " <> ppTy' 11 t
ppTy' d (TArr n t) = ppParen (d > 10) $ ppString "Arr " <> ppString (show (fromNat n)) <> ppString " " <> ppTy' 11 t
ppTy' _ (TScal sty) = ppString $ case sty of
  TI32 -> "i32"
  TI64 -> "i64"
  TF32 -> "f32"
  TF64 -> "f64"
  TBool -> "bool"
ppTy' d (TAccum t) = ppParen (d > 10) $ ppString "Accum " <> ppTy' 11 t

-- | Turn a 'String' into a document.
ppString :: String -> Doc x
ppString = fromString

-- | Wrap the document in parentheses if the flag is set.
ppParen :: Bool -> Doc x -> Doc x
ppParen True = parens
ppParen False = id

-- | Semantic annotations on the output; see 'annotToANSI' for how each one is
-- styled on a terminal.
data Annot = AKey | AWith | AHighlight | AMonoid | AExt
  deriving (Show)

-- | Terminal style for each annotation.
annotToANSI :: Annot -> PT.AnsiStyle
annotToANSI AKey = PT.bold
annotToANSI AWith = PT.color PT.Red <> PT.underlined
annotToANSI AHighlight = PT.color PT.Blue
annotToANSI AMonoid = PT.color PT.Green
annotToANSI AExt = mempty

-- | Document carrying 'Annot' annotations.
type ADoc = Doc Annot

-- | Lay the document out at 120 columns and render it to a 'String', with
-- ANSI escape sequences if 'stdout' supports them and as plain text otherwise.
render :: Doc Annot -> String
render =
  (if stdoutTTY then TL.unpack . PT.renderLazy . reAnnotateS annotToANSI
                else renderString)
  . layoutPretty LayoutOptions { layoutPageWidth = AvailablePerLine 120 1.0 }
  where
    -- NOTE(review): unsafePerformIO is used to query the terminal from pure
    -- code; this relies on the answer not changing during the program's
    -- lifetime. Confirm that assumption holds for all callers.
    stdoutTTY = unsafePerformIO $ hSupportsANSI stdout