aboutsummaryrefslogtreecommitdiff
path: root/src/AST/Pretty.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/AST/Pretty.hs')
-rw-r--r--src/AST/Pretty.hs186
1 files changed, 136 insertions, 50 deletions
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index 604133b..fef9686 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -7,11 +7,12 @@
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeOperators #-}
-module AST.Pretty (pprintExpr, ppExpr, ppSTy, ppTy, PrettyX(..)) where
+module AST.Pretty (pprintExpr, ppExpr, ppSTy, ppSMTy, PrettyX(..)) where
import Control.Monad (ap)
-import Data.List (intersperse)
+import Data.List (intersperse, intercalate)
import Data.Functor.Const
+import qualified Data.Functor.Product as Product
import Data.String (fromString)
import Prettyprinter
import Prettyprinter.Render.String
@@ -24,6 +25,7 @@ import System.IO.Unsafe (unsafePerformIO)
import AST
import AST.Count
+import AST.Sparse.Types
import CHAD.Types
import Data
@@ -49,12 +51,20 @@ instance Monad M where { M f >>= g = M (\i -> let (x, j) = f i in runM (g x) j)
genId :: M Int
genId = M (\i -> (i, i + 1))
+nameBaseForType :: STy t -> String
+nameBaseForType STNil = "nil"
+nameBaseForType (STPair{}) = "p"
+nameBaseForType (STEither{}) = "e"
+nameBaseForType (STMaybe{}) = "m"
+nameBaseForType (STScal STI32) = "n"
+nameBaseForType (STScal STI64) = "n"
+nameBaseForType (STArr{}) = "a"
+nameBaseForType (STAccum{}) = "ac"
+nameBaseForType _ = "x"
+
genName' :: String -> M String
genName' prefix = (prefix ++) . show <$> genId
-genName :: M String
-genName = genName' "x"
-
genNameIfUsedIn' :: String -> STy a -> Idx env a -> Expr x env t -> M String
genNameIfUsedIn' prefix ty idx ex
| occCount idx ex == mempty = case ty of STNil -> return "()"
@@ -62,19 +72,27 @@ genNameIfUsedIn' prefix ty idx ex
| otherwise = genName' prefix
genNameIfUsedIn :: STy a -> Idx env a -> Expr x env t -> M String
-genNameIfUsedIn = genNameIfUsedIn' "x"
+genNameIfUsedIn = \t -> genNameIfUsedIn' (nameBaseForType t) t
pprintExpr :: (KnownEnv env, PrettyX x) => Expr x env t -> IO ()
pprintExpr = putStrLn . ppExpr knownEnv
-ppExpr :: PrettyX x => SList f env -> Expr x env t -> String
-ppExpr senv e = render $ fst (runM (mkVal senv >>= \val -> ppExpr' 0 val e) 1)
+ppExpr :: PrettyX x => SList STy env -> Expr x env t -> String
+ppExpr senv e = render $ fst . flip runM 1 $ do
+ val <- mkVal senv
+ e' <- ppExpr' 0 val e
+ let lam = "λ" ++ intercalate " " (reverse (unSList (\(Product.Pair (Const name) ty) -> "(" ++ name ++ " : " ++ ppSTy 0 ty ++ ")") (slistZip val senv))) ++ "."
+ return $ group $ flatAlt
+ (hang 2 $
+ ppString lam
+ <> hardline <> e')
+ (ppString lam <+> e')
where
mkVal :: SList f env -> M (SVal env)
mkVal SNil = return SNil
mkVal (SCons _ v) = do
val <- mkVal v
- name <- genName
+ name <- genName' "arg"
return (Const name `SCons` val)
ppExpr' :: PrettyX x => Int -> SVal env -> Expr x env t -> M ADoc
@@ -128,12 +146,45 @@ ppExpr' d val expr = case expr of
EMaybe _ a b e -> do
let STMaybe t = typeOf e
- a' <- ppExpr' 11 val a
+ e' <- ppExpr' 0 val e
+ a' <- ppExpr' 0 val a
name <- genNameIfUsedIn t IZ b
b' <- ppExpr' 0 (Const name `SCons` val) b
+ return $ ppParen (d > 0) $
+ align $
+ group (flatAlt
+ (annotate AKey (ppString "case") <> ppX expr <+> e'
+ <> hardline <> annotate AKey (ppString "of"))
+ (annotate AKey (ppString "case") <> ppX expr <+> e' <+> annotate AKey (ppString "of")))
+ <> hardline
+ <> indent 2
+ (ppString "Nothing" <+> ppString "->" <+> a'
+ <> hardline <> ppString "Just" <+> ppString name <+> ppString "->" <+> b')
+
+ ELNil _ _ _ -> return (ppString "LNil")
+
+ ELInl _ _ e -> do
e' <- ppExpr' 11 val e
- return $ ppParen (d > 10) $
- ppApp (ppString "maybe" <> ppX expr) [a', ppLam [ppString name] b', e']
+ return $ ppParen (d > 10) $ ppString "LInl" <> ppX expr <+> e'
+
+ ELInr _ _ e -> do
+ e' <- ppExpr' 11 val e
+ return $ ppParen (d > 10) $ ppString "LInr" <> ppX expr <+> e'
+
+ ELCase _ e a b c -> do
+ e' <- ppExpr' 0 val e
+ let STLEither t1 t2 = typeOf e
+ a' <- ppExpr' 11 val a
+ name1 <- genNameIfUsedIn t1 IZ b
+ b' <- ppExpr' 0 (Const name1 `SCons` val) b
+ name2 <- genNameIfUsedIn t2 IZ c
+ c' <- ppExpr' 0 (Const name2 `SCons` val) c
+ return $ ppParen (d > 0) $
+ hang 2 $
+ annotate AKey (ppString "lcase") <> ppX expr <+> e' <+> annotate AKey (ppString "of")
+ <> hardline <> ppString "LNil" <+> ppString "->" <+> a'
+ <> hardline <> ppString "LInl" <+> ppString name1 <+> ppString "->" <+> b'
+ <> hardline <> ppString "LInr" <+> ppString name2 <+> ppString "->" <+> c'
EConstArr _ _ ty v
| Dict <- scalRepIsShow ty -> return $ ppString (showsPrec d v "") <> ppX expr
@@ -142,13 +193,14 @@ ppExpr' d val expr = case expr of
a' <- ppExpr' 11 val a
name <- genNameIfUsedIn' "i" (tTup (sreplicate n tIx)) IZ b
e' <- ppExpr' 0 (Const name `SCons` val) b
+ let primName = ppString ("build" ++ intSubscript (fromSNat n))
return $ ppParen (d > 0) $
group $ flatAlt
(hang 2 $
- annotate AHighlight (ppString "build") <> ppX expr <+> a'
+ annotate AHighlight primName <> ppX expr <+> a'
<+> ppString "$" <+> ppString "\\" <> ppString name <+> ppString "->"
<> hardline <> e')
- (ppApp (annotate AHighlight (ppString "build") <> ppX expr) [a', ppLam [ppString name] e'])
+ (ppApp (annotate AHighlight primName <> ppX expr) [a', ppLam [ppString name] e'])
EFold1Inner _ cm a b c -> do
name1 <- genNameIfUsedIn (typeOf a) (IS IZ) a
@@ -237,6 +289,10 @@ ppExpr' d val expr = case expr of
,e1'
,e2']
+ ERecompute _ e -> do
+ e' <- ppExpr' 11 val e
+ return $ ppParen (d > 10) $ ppApp (ppString "recompute" <> ppX expr) [e']
+
EWith _ t e1 e2 -> do
e1' <- ppExpr' 11 val e1
name <- genNameIfUsedIn' "ac" (STAccum t) IZ e2
@@ -249,27 +305,35 @@ ppExpr' d val expr = case expr of
<> hardline <> e2')
(ppApp (annotate AWith (ppString "with") <> ppX expr) [e1', ppLam [ppString name] e2'])
- EAccum _ _ prj e1 e2 e3 -> do
+ EAccum _ t prj e1 sp e2 e3 -> do
e1' <- ppExpr' 11 val e1
e2' <- ppExpr' 11 val e2
e3' <- ppExpr' 11 val e3
return $ ppParen (d > 10) $
- ppApp (annotate AMonoid (ppString "accum") <> ppX expr) [ppString (ppAcPrj prj), e1', e2', e3']
+ ppApp (annotate AMonoid (ppString "accum") <> ppX expr <+> ppString "@" <> ppSMTy' 11 (applySparse sp (acPrjTy prj t)))
+ [ppString (ppAcPrj t prj), ppString (ppSparse (acPrjTy prj t) sp), e1', e2', e3']
- EZero _ t -> return $ ppParen (d > 0) $
- annotate AMonoid (ppString "zero") <> ppX expr <+> ppString "@" <> ppSTy' 11 t
+ EZero _ t e1 -> do
+ e1' <- ppExpr' 11 val e1
+ return $ ppParen (d > 0) $
+ annotate AMonoid (ppString "zero") <> ppX expr <+> ppString "@" <> ppSMTy' 11 t <+> e1'
+
+ EDeepZero _ t e1 -> do
+ e1' <- ppExpr' 11 val e1
+ return $ ppParen (d > 0) $
+ annotate AMonoid (ppString "deepzero") <> ppX expr <+> ppString "@" <> ppSMTy' 11 t <+> e1'
- EPlus _ _ a b -> do
+ EPlus _ t a b -> do
a' <- ppExpr' 11 val a
b' <- ppExpr' 11 val b
return $ ppParen (d > 10) $
- ppApp (annotate AMonoid (ppString "plus") <> ppX expr) [a', b']
+ ppApp (annotate AMonoid (ppString "plus") <> ppX expr <+> ppString "@" <> ppSMTy' 11 t) [a', b']
- EOneHot _ _ prj a b -> do
+ EOneHot _ t prj a b -> do
a' <- ppExpr' 11 val a
b' <- ppExpr' 11 val b
return $ ppParen (d > 10) $
- ppApp (annotate AMonoid (ppString "onehot") <> ppX expr) [ppString (ppAcPrj prj), a', b']
+ ppApp (annotate AMonoid (ppString "onehot") <> ppX expr <+> ppString "@" <> ppSMTy' 11 (acPrjTy prj t)) [ppString (ppAcPrj t prj), a', b']
EError _ _ s -> return $ ppParen (d > 10) $ ppString "error" <> ppX expr <+> ppString (show s)
@@ -302,14 +366,24 @@ ppLam :: [ADoc] -> ADoc -> ADoc
ppLam args body = ppString "(" <> hang 2 (ppString "\\" <> sep (args ++ [ppString "->"])
<> softline <> body <> ppString ")")
-ppAcPrj :: SAcPrj p a b -> String
-ppAcPrj SAPHere = "@"
-ppAcPrj (SAPFst prj) = "(" ++ ppAcPrj prj ++ ",)"
-ppAcPrj (SAPSnd prj) = "(," ++ ppAcPrj prj ++ ")"
-ppAcPrj (SAPLeft prj) = "(" ++ ppAcPrj prj ++ "|)"
-ppAcPrj (SAPRight prj) = "(|" ++ ppAcPrj prj ++ ")"
-ppAcPrj (SAPJust prj) = "J" ++ ppAcPrj prj
-ppAcPrj (SAPArrIdx prj n) = "[" ++ ppAcPrj prj ++ "]" ++ intSubscript (fromSNat n)
+ppAcPrj :: SMTy a -> SAcPrj p a b -> String
+ppAcPrj _ SAPHere = "."
+ppAcPrj (SMTPair t _) (SAPFst prj) = "(" ++ ppAcPrj t prj ++ ",)"
+ppAcPrj (SMTPair _ t) (SAPSnd prj) = "(," ++ ppAcPrj t prj ++ ")"
+ppAcPrj (SMTLEither t _) (SAPLeft prj) = "(" ++ ppAcPrj t prj ++ "|)"
+ppAcPrj (SMTLEither _ t) (SAPRight prj) = "(|" ++ ppAcPrj t prj ++ ")"
+ppAcPrj (SMTMaybe t) (SAPJust prj) = "J" ++ ppAcPrj t prj
+ppAcPrj (SMTArr n t) (SAPArrIdx prj) = "[" ++ ppAcPrj t prj ++ "]" ++ intSubscript (fromSNat n)
+
+ppSparse :: SMTy a -> Sparse a b -> String
+ppSparse t sp | Just Refl <- isDense t sp = "D"
+ppSparse _ SpAbsent = "A"
+ppSparse t (SpSparse s) = "S" ++ ppSparse t s
+ppSparse (SMTPair t1 t2) (SpPair s1 s2) = "(" ++ ppSparse t1 s1 ++ "," ++ ppSparse t2 s2 ++ ")"
+ppSparse (SMTLEither t1 t2) (SpLEither s1 s2) = "(" ++ ppSparse t1 s1 ++ "|" ++ ppSparse t2 s2 ++ ")"
+ppSparse (SMTMaybe t) (SpMaybe s) = "M" ++ ppSparse t s
+ppSparse (SMTArr _ t) (SpArr s) = "A" ++ ppSparse t s
+ppSparse (SMTScal _) SpScal = "."
ppX :: PrettyX x => Expr x env t -> ADoc
ppX expr = annotate AExt $ ppString $ prettyXsuffix (extOf expr)
@@ -334,30 +408,42 @@ operator ORecip{} = (Prefix, "recip")
operator OExp{} = (Prefix, "exp")
operator OLog{} = (Prefix, "log")
operator OIDiv{} = (Infix, "`div`")
+operator OMod{} = (Infix, "`mod`")
ppSTy :: Int -> STy t -> String
-ppSTy d ty = ppTy d (unSTy ty)
+ppSTy d ty = render $ ppSTy' d ty
ppSTy' :: Int -> STy t -> Doc q
-ppSTy' d ty = ppTy' d (unSTy ty)
-
-ppTy :: Int -> Ty -> String
-ppTy d ty = render $ ppTy' d ty
-
-ppTy' :: Int -> Ty -> Doc q
-ppTy' _ TNil = ppString "1"
-ppTy' d (TPair a b) = ppParen (d > 7) $ ppTy' 8 a <> ppString " * " <> ppTy' 8 b
-ppTy' d (TEither a b) = ppParen (d > 6) $ ppTy' 7 a <> ppString " + " <> ppTy' 7 b
-ppTy' d (TMaybe t) = ppParen (d > 10) $ ppString "Maybe " <> ppTy' 11 t
-ppTy' d (TArr n t) = ppParen (d > 10) $
- ppString "Arr " <> ppString (show (fromNat n)) <> ppString " " <> ppTy' 11 t
-ppTy' _ (TScal sty) = ppString $ case sty of
- TI32 -> "i32"
- TI64 -> "i64"
- TF32 -> "f32"
- TF64 -> "f64"
- TBool -> "bool"
-ppTy' d (TAccum t) = ppParen (d > 10) $ ppString "Accum " <> ppTy' 11 t
+ppSTy' _ STNil = ppString "1"
+ppSTy' d (STPair a b) = ppParen (d > 7) $ ppSTy' 8 a <> ppString " * " <> ppSTy' 8 b
+ppSTy' d (STEither a b) = ppParen (d > 6) $ ppSTy' 7 a <> ppString " + " <> ppSTy' 7 b
+ppSTy' d (STLEither a b) = ppParen (d > 6) $ ppSTy' 7 a <> ppString " ⊕ " <> ppSTy' 7 b
+ppSTy' d (STMaybe t) = ppParen (d > 10) $ ppString "Maybe " <> ppSTy' 11 t
+ppSTy' d (STArr n t) = ppParen (d > 10) $
+ ppString "Arr " <> ppString (show (fromSNat n)) <> ppString " " <> ppSTy' 11 t
+ppSTy' _ (STScal sty) = ppString $ case sty of
+ STI32 -> "i32"
+ STI64 -> "i64"
+ STF32 -> "f32"
+ STF64 -> "f64"
+ STBool -> "bool"
+ppSTy' d (STAccum t) = ppParen (d > 10) $ ppString "Accum " <> ppSMTy' 11 t
+
+ppSMTy :: Int -> SMTy t -> String
+ppSMTy d ty = render $ ppSMTy' d ty
+
+ppSMTy' :: Int -> SMTy t -> Doc q
+ppSMTy' _ SMTNil = ppString "1"
+ppSMTy' d (SMTPair a b) = ppParen (d > 7) $ ppSMTy' 8 a <> ppString " * " <> ppSMTy' 8 b
+ppSMTy' d (SMTLEither a b) = ppParen (d > 6) $ ppSMTy' 7 a <> ppString " ⊕ " <> ppSMTy' 7 b
+ppSMTy' d (SMTMaybe t) = ppParen (d > 10) $ ppString "Maybe " <> ppSMTy' 11 t
+ppSMTy' d (SMTArr n t) = ppParen (d > 10) $
+ ppString "Arr " <> ppString (show (fromSNat n)) <> ppString " " <> ppSMTy' 11 t
+ppSMTy' _ (SMTScal sty) = ppString $ case sty of
+ STI32 -> "i32"
+ STI64 -> "i64"
+ STF32 -> "f32"
+ STF64 -> "f64"
ppString :: String -> Doc x
ppString = fromString