aboutsummaryrefslogtreecommitdiff
path: root/Pretty.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2021-06-27 18:34:35 +0200
committerTom Smeding <tom@tomsmeding.com>2021-06-27 18:34:35 +0200
commitd4abcc3b2dfefbbcb7cd4a182eec64f1da42d951 (patch)
tree1ab301617043ac6df228ef617afa22633a01a671 /Pretty.hs
parent0fefe4822c65bde81ec4c0da1b5b32a9b411ca79 (diff)
Diffstat (limited to 'Pretty.hs')
-rw-r--r--Pretty.hs229
1 files changed, 229 insertions, 0 deletions
diff --git a/Pretty.hs b/Pretty.hs
new file mode 100644
index 0000000..d63e3ce
--- /dev/null
+++ b/Pretty.hs
@@ -0,0 +1,229 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TupleSections #-}
+{-# LANGUAGE TypeOperators #-}
+module Pretty (
+ prettyExp,
+ pprintExp,
+) where
+
+import Data.Bifunctor
+import Prettyprinter
+import Prettyprinter.Render.String
+
+import AST
+
+
+newtype IdGen a = IdGen { runIdGen :: Int -> (a, Int) }
+instance Functor IdGen where
+ fmap f (IdGen g) = IdGen (first f . g)
+instance Applicative IdGen where
+ pure x = IdGen (x,)
+ IdGen f <*> IdGen g = IdGen (\i -> let (f', j) = f i in first f' (g j))
+instance Monad IdGen where
+ IdGen f >>= g = IdGen (\i -> let (x, j) = f i in runIdGen (g x) j)
+
+evalIdGen :: Int -> IdGen a -> a
+evalIdGen i = fst . ($ i) . runIdGen
+
+genId :: IdGen Int
+genId = IdGen (\i -> (i, i + 1))
+
+genName :: IdGen String
+genName = ('x' :) . show <$> genId
+
+data PEnv env where
+ Top :: PEnv env
+ PCons :: String -> PEnv env -> PEnv (a ': env)
+
+prettyExp :: Exp env a -> String
+prettyExp e = renderString (layoutSmart opts (evalIdGen 1 (pExp definfo 0 Top e)))
+ where opts = LayoutOptions (AvailablePerLine 120 0.7)
+
+pprintExp :: Exp env a -> IO ()
+pprintExp = putStrLn . prettyExp
+
+data Info = Info
+ { infoLamTypeSig :: Bool }
+ deriving (Show)
+
+definfo :: Info
+definfo = Info True
+
+pExp :: forall env a x. Info -> Int -> PEnv env -> Exp env a -> IdGen (Doc x)
+pExp thisinfo d env = \case
+ App (Const CAddF) (Pair a b) -> do
+ a' <- pExp definfo 7 env a
+ b' <- pExp definfo 7 env b
+ return (flatAlt (pParen (d > 10) $ pretty "AddF" <+> align (vsep [a', b']))
+ (pParen (d > 6) $ hsep [a', pretty "+", b']))
+
+ App (Const CMulF) (Pair a b) -> do
+ a' <- pExp definfo 8 env a
+ b' <- pExp definfo 8 env b
+ return (flatAlt (pParen (d > 10) $ pretty "MulF" <+> align (vsep [a', b']))
+ (pParen (d > 7) $ hsep [a', pretty "*", b']))
+
+ e@(App _ _) -> do
+ let collectAppsRev :: Exp env t -> IdGen (Doc x', [Doc x'])
+ collectAppsRev (App f a) = do
+ a' <- pExp definfo 11 env a
+ rest <- collectAppsRev f
+ return (fmap (a' :) rest)
+ collectAppsRev f = (,[]) <$> pExp definfo 11 env f
+ (func, rhss) <- collectAppsRev e
+ return (pParen (d > 10) $ func <+> align (sep (reverse rhss)))
+
+ Lam t e -> do
+ name <- genName
+ let prefix | infoLamTypeSig thisinfo =
+ pretty ("\\(" ++ name ++ " :: " ++ showType 0 t ") ->")
+ | otherwise =
+ pretty ("\\" ++ name ++ " ->")
+ body <- pExp definfo 0 (PCons name env) e
+ return (pParen (d > 0) $ nest 2 (sep [prefix, body]))
+
+ Var t i ->
+ case (env, i) of
+ (Top, _) -> return (pretty ("xUP_" ++ show (idxToInt i)))
+ (PCons name _, Zero) -> return (pretty name)
+ (PCons _ env', Succ i') -> pExp definfo d env' (Var t i')
+
+ e@(Let _ _) -> do
+ let collectLets :: PEnv env' -> Exp env' t -> IdGen (Doc x', [Doc x'])
+ collectLets env' (Let rhs body) = do
+ name <- genName
+ rhs' <- (pretty (name ++ " = ") <>) . group <$> pExp definfo 0 env' rhs
+ rest <- collectLets (PCons name env') body
+ return (fmap (rhs' :) rest)
+ collectLets env' f = (,[]) <$> pExp definfo 0 env' f
+ (core, rhss) <- collectLets env e
+ return (pParen (d > 0) $
+ align (vsep [pretty "let" <+> align (vsep rhss)
+ ,pretty "in" <+> group core]))
+
+ Lit l -> return (pretty (showLit d l ""))
+
+ Cond e1 e2 e3 -> do
+ e1' <- pExp definfo 11 env e1
+ e2' <- pExp definfo 11 env e2
+ e3' <- pExp definfo 11 env e3
+ return (flatAlt (pParen (d > 10) $ pretty "cond" <+> align (vsep [e1', e2', e3']))
+ (pParen (d > 0) $ hsep [e1', pretty "?", e2', pretty ":", e3']))
+
+ Const c -> return (pretty (showConst c))
+
+ Pair e1 e2 -> do
+ e1' <- pExp definfo 0 env e1
+ e2' <- pExp definfo 0 env e2
+ return (tupled [e1', e2'])
+
+ Fst e -> do
+ e' <- pExp definfo 11 env e
+ return (pParen (d > 10) $ pretty "fst" <+> e')
+
+ Snd e -> do
+ e' <- pExp definfo 11 env e
+ return (pParen (d > 10) $ pretty "snd" <+> e')
+
+ Build sht e1 e2 -> do
+ e1' <- pExp definfo 11 env e1
+ e2' <- pExp definfo{infoLamTypeSig=False} 11 env e2
+ return (pParen (d > 10) $
+ pretty "build" <+> align (sep
+ [pretty ("DIM" <> show (shtToInt sht)), e1', e2']))
+
+ Ifold sht e1 e2 e3 -> do
+ e1' <- pExp (definfo{infoLamTypeSig=False}) 11 env e1
+ e2' <- pExp definfo 11 env e2
+ e3' <- pExp definfo 11 env e3
+ return (pParen (d > 10) $
+ pretty "ifold" <+> align (sep
+ [pretty ("DIM" <> show (shtToInt sht)), e1', e2', e3']))
+
+ Index e1 e2 -> do
+ e1' <- pExp definfo 11 env e1
+ e2' <- pExp definfo 11 env e2
+ return (pParen (d > 10) $
+ flatAlt (pretty "index" <+> align (sep [e1', e2']))
+ (hsep [e1', pretty "!", e2']))
+
+ Shape e -> do
+ e' <- pExp definfo 11 env e
+ return (pParen (d > 10) $ pretty "shape" <+> e')
+
+ Undef t -> return (pParen (d > 0) $ pretty ("UNDEF :: " ++ showType 0 t ""))
+
+pParen :: Bool -> Doc x -> Doc x
+pParen True = parens
+pParen False = id
+
+showLit :: Int -> Literal a -> ShowS
+showLit _ (LInt i) = shows i
+showLit _ (LBool b) = shows b
+showLit _ (LDouble d) = shows d
+showLit d (LArray (Array sh t v))
+ | Just Has <- typeHasShow t
+ = showParen (d > 0) $
+ shows v . showString " :: Array " . showShape 11 sh . showString " " . showType 11 t
+ | otherwise
+ = showParen (d > 0) $
+ showString "[{noshow}] :: Array " . showShape 11 sh . showString " " . showType 11 t
+showLit d (LShape sh) = showShape d sh
+showLit _ LNil = showString "()"
+showLit _ (LPair a b) =
+ showString "(" . showLit 0 a . showString ", " . showLit 0 b . showString ")"
+
+showConst :: Constant a -> String
+showConst CAddI = "AddI"
+showConst CSubI = "SubI"
+showConst CMulI = "MulI"
+showConst CDivI = "DivI"
+showConst CAddF = "AddF"
+showConst CSubF = "SubF"
+showConst CMulF = "MulF"
+showConst CDivF = "DivF"
+showConst CLog = "Log"
+showConst CExp = "Exp"
+showConst CtoF = "ToF"
+showConst CRound = "Round"
+showConst CLtI = "LtI"
+showConst CLeI = "LeI"
+showConst CLtF = "LtF"
+showConst (CEq _) = "Eq"
+showConst CAnd = "And"
+showConst COr = "Or"
+showConst CNot = "Not"
+
+showShape :: Int -> Shape sh -> ShowS
+showShape _ Z = showString "Z"
+showShape d (sh :. n) = showParen (d > 10) $
+ showShape 10 sh . showString ":" . shows n
+
+showType :: Int -> Type a -> ShowS
+showType _ TInt = showString "Int"
+showType _ TBool = showString "Bool"
+showType _ TDouble = showString "Double"
+showType _ (TArray sht t) =
+ let n = shtToInt sht
+ in showString (replicate n '[') . showType 0 t . showString (replicate n ']')
+showType _ TNil = showString "()"
+showType _ (TPair a b) =
+ showString "(" . showType 0 a . showString ", " . showType 0 b . showString ")"
+showType d (TFun a b) = showParen (d > 10) $
+ showType 11 a . showString " -> " . showType 10 b
+
+-- showTypeShort :: Int -> Type a -> ShowS
+-- showTypeShort _ TInt = showString "i"
+-- showTypeShort _ TBool = showString "b"
+-- showTypeShort _ TDouble = showString "d"
+-- showTypeShort _ (TArray sht t) =
+-- let n = shtToInt sht
+-- in showString (replicate n '[') . showTypeShort 0 t . showString (replicate n ']')
+-- showTypeShort _ TNil = showString "."
+-- showTypeShort _ (TPair a b) =
+-- showString "(" . showTypeShort 11 a . showTypeShort 11 b . showString ")"
+-- showTypeShort d (TFun a b) = showParen (d > 10) $
+-- showTypeShort 11 a . showString " -> " . showTypeShort 10 b