diff options
Diffstat (limited to 'src/AST/Pretty.hs')
-rw-r--r-- | src/AST/Pretty.hs | 186 |
1 files changed, 136 insertions, 50 deletions
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index 604133b..fef9686 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -7,11 +7,12 @@ {-# LANGUAGE PolyKinds #-} {-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeOperators #-} -module AST.Pretty (pprintExpr, ppExpr, ppSTy, ppTy, PrettyX(..)) where +module AST.Pretty (pprintExpr, ppExpr, ppSTy, ppSMTy, PrettyX(..)) where import Control.Monad (ap) -import Data.List (intersperse) +import Data.List (intersperse, intercalate) import Data.Functor.Const +import qualified Data.Functor.Product as Product import Data.String (fromString) import Prettyprinter import Prettyprinter.Render.String @@ -24,6 +25,7 @@ import System.IO.Unsafe (unsafePerformIO) import AST import AST.Count +import AST.Sparse.Types import CHAD.Types import Data @@ -49,12 +51,20 @@ instance Monad M where { M f >>= g = M (\i -> let (x, j) = f i in runM (g x) j) genId :: M Int genId = M (\i -> (i, i + 1)) +nameBaseForType :: STy t -> String +nameBaseForType STNil = "nil" +nameBaseForType (STPair{}) = "p" +nameBaseForType (STEither{}) = "e" +nameBaseForType (STMaybe{}) = "m" +nameBaseForType (STScal STI32) = "n" +nameBaseForType (STScal STI64) = "n" +nameBaseForType (STArr{}) = "a" +nameBaseForType (STAccum{}) = "ac" +nameBaseForType _ = "x" + genName' :: String -> M String genName' prefix = (prefix ++) . show <$> genId -genName :: M String -genName = genName' "x" - genNameIfUsedIn' :: String -> STy a -> Idx env a -> Expr x env t -> M String genNameIfUsedIn' prefix ty idx ex | occCount idx ex == mempty = case ty of STNil -> return "()" @@ -62,19 +72,27 @@ genNameIfUsedIn' prefix ty idx ex | otherwise = genName' prefix genNameIfUsedIn :: STy a -> Idx env a -> Expr x env t -> M String -genNameIfUsedIn = genNameIfUsedIn' "x" +genNameIfUsedIn = \t -> genNameIfUsedIn' (nameBaseForType t) t pprintExpr :: (KnownEnv env, PrettyX x) => Expr x env t -> IO () pprintExpr = putStrLn . ppExpr knownEnv -ppExpr :: PrettyX x => SList f env -> Expr x env t -> String -ppExpr senv e = render $ fst (runM (mkVal senv >>= \val -> ppExpr' 0 val e) 1) +ppExpr :: PrettyX x => SList STy env -> Expr x env t -> String +ppExpr senv e = render $ fst . flip runM 1 $ do + val <- mkVal senv + e' <- ppExpr' 0 val e + let lam = "λ" ++ intercalate " " (reverse (unSList (\(Product.Pair (Const name) ty) -> "(" ++ name ++ " : " ++ ppSTy 0 ty ++ ")") (slistZip val senv))) ++ "." + return $ group $ flatAlt + (hang 2 $ + ppString lam + <> hardline <> e') + (ppString lam <+> e') where mkVal :: SList f env -> M (SVal env) mkVal SNil = return SNil mkVal (SCons _ v) = do val <- mkVal v - name <- genName + name <- genName' "arg" return (Const name `SCons` val) ppExpr' :: PrettyX x => Int -> SVal env -> Expr x env t -> M ADoc @@ -128,12 +146,45 @@ ppExpr' d val expr = case expr of EMaybe _ a b e -> do let STMaybe t = typeOf e - a' <- ppExpr' 11 val a + e' <- ppExpr' 0 val e + a' <- ppExpr' 0 val a name <- genNameIfUsedIn t IZ b b' <- ppExpr' 0 (Const name `SCons` val) b + return $ ppParen (d > 0) $ + align $ + group (flatAlt + (annotate AKey (ppString "case") <> ppX expr <+> e' + <> hardline <> annotate AKey (ppString "of")) + (annotate AKey (ppString "case") <> ppX expr <+> e' <+> annotate AKey (ppString "of"))) + <> hardline + <> indent 2 + (ppString "Nothing" <+> ppString "->" <+> a' + <> hardline <> ppString "Just" <+> ppString name <+> ppString "->" <+> b') + + ELNil _ _ _ -> return (ppString "LNil") + + ELInl _ _ e -> do e' <- ppExpr' 11 val e - return $ ppParen (d > 10) $ - ppApp (ppString "maybe" <> ppX expr) [a', ppLam [ppString name] b', e'] + return $ ppParen (d > 10) $ ppString "LInl" <> ppX expr <+> e' + + ELInr _ _ e -> do + e' <- ppExpr' 11 val e + return $ ppParen (d > 10) $ ppString "LInr" <> ppX expr <+> e' + + ELCase _ e a b c -> do + e' <- ppExpr' 0 val e + let STLEither t1 t2 = typeOf e + a' <- ppExpr' 11 val a + name1 <- genNameIfUsedIn t1 IZ b + b' <- ppExpr' 0 (Const name1 `SCons` val) b + name2 <- genNameIfUsedIn t2 IZ c + c' <- ppExpr' 0 (Const name2 `SCons` val) c + return $ ppParen (d > 0) $ + hang 2 $ + annotate AKey (ppString "lcase") <> ppX expr <+> e' <+> annotate AKey (ppString "of") + <> hardline <> ppString "LNil" <+> ppString "->" <+> a' + <> hardline <> ppString "LInl" <+> ppString name1 <+> ppString "->" <+> b' + <> hardline <> ppString "LInr" <+> ppString name2 <+> ppString "->" <+> c' EConstArr _ _ ty v | Dict <- scalRepIsShow ty -> return $ ppString (showsPrec d v "") <> ppX expr @@ -142,13 +193,14 @@ ppExpr' d val expr = case expr of a' <- ppExpr' 11 val a name <- genNameIfUsedIn' "i" (tTup (sreplicate n tIx)) IZ b e' <- ppExpr' 0 (Const name `SCons` val) b + let primName = ppString ("build" ++ intSubscript (fromSNat n)) return $ ppParen (d > 0) $ group $ flatAlt (hang 2 $ - annotate AHighlight (ppString "build") <> ppX expr <+> a' + annotate AHighlight primName <> ppX expr <+> a' <+> ppString "$" <+> ppString "\\" <> ppString name <+> ppString "->" <> hardline <> e') - (ppApp (annotate AHighlight (ppString "build") <> ppX expr) [a', ppLam [ppString name] e']) + (ppApp (annotate AHighlight primName <> ppX expr) [a', ppLam [ppString name] e']) EFold1Inner _ cm a b c -> do name1 <- genNameIfUsedIn (typeOf a) (IS IZ) a @@ -237,6 +289,10 @@ ppExpr' d val expr = case expr of ,e1' ,e2'] + ERecompute _ e -> do + e' <- ppExpr' 11 val e + return $ ppParen (d > 10) $ ppApp (ppString "recompute" <> ppX expr) [e'] + EWith _ t e1 e2 -> do e1' <- ppExpr' 11 val e1 name <- genNameIfUsedIn' "ac" (STAccum t) IZ e2 @@ -249,27 +305,35 @@ ppExpr' d val expr = case expr of <> hardline <> e2') (ppApp (annotate AWith (ppString "with") <> ppX expr) [e1', ppLam [ppString name] e2']) - EAccum _ _ prj e1 e2 e3 -> do + EAccum _ t prj e1 sp e2 e3 -> do e1' <- ppExpr' 11 val e1 e2' <- ppExpr' 11 val e2 e3' <- ppExpr' 11 val e3 return $ ppParen (d > 10) $ - ppApp (annotate AMonoid (ppString "accum") <> ppX expr) [ppString (ppAcPrj prj), e1', e2', e3'] + ppApp (annotate AMonoid (ppString "accum") <> ppX expr <+> ppString "@" <> ppSMTy' 11 (applySparse sp (acPrjTy prj t))) + [ppString (ppAcPrj t prj), ppString (ppSparse (acPrjTy prj t) sp), e1', e2', e3'] - EZero _ t -> return $ ppParen (d > 0) $ - annotate AMonoid (ppString "zero") <> ppX expr <+> ppString "@" <> ppSTy' 11 t + EZero _ t e1 -> do + e1' <- ppExpr' 11 val e1 + return $ ppParen (d > 0) $ + annotate AMonoid (ppString "zero") <> ppX expr <+> ppString "@" <> ppSMTy' 11 t <+> e1' + + EDeepZero _ t e1 -> do + e1' <- ppExpr' 11 val e1 + return $ ppParen (d > 0) $ + annotate AMonoid (ppString "deepzero") <> ppX expr <+> ppString "@" <> ppSMTy' 11 t <+> e1' - EPlus _ _ a b -> do + EPlus _ t a b -> do a' <- ppExpr' 11 val a b' <- ppExpr' 11 val b return $ ppParen (d > 10) $ - ppApp (annotate AMonoid (ppString "plus") <> ppX expr) [a', b'] + ppApp (annotate AMonoid (ppString "plus") <> ppX expr <+> ppString "@" <> ppSMTy' 11 t) [a', b'] - EOneHot _ _ prj a b -> do + EOneHot _ t prj a b -> do a' <- ppExpr' 11 val a b' <- ppExpr' 11 val b return $ ppParen (d > 10) $ - ppApp (annotate AMonoid (ppString "onehot") <> ppX expr) [ppString (ppAcPrj prj), a', b'] + ppApp (annotate AMonoid (ppString "onehot") <> ppX expr <+> ppString "@" <> ppSMTy' 11 (acPrjTy prj t)) [ppString (ppAcPrj t prj), a', b'] EError _ _ s -> return $ ppParen (d > 10) $ ppString "error" <> ppX expr <+> ppString (show s) @@ -302,14 +366,24 @@ ppLam :: [ADoc] -> ADoc -> ADoc ppLam args body = ppString "(" <> hang 2 (ppString "\\" <> sep (args ++ [ppString "->"]) <> softline <> body <> ppString ")") -ppAcPrj :: SAcPrj p a b -> String -ppAcPrj SAPHere = "@" -ppAcPrj (SAPFst prj) = "(" ++ ppAcPrj prj ++ ",)" -ppAcPrj (SAPSnd prj) = "(," ++ ppAcPrj prj ++ ")" -ppAcPrj (SAPLeft prj) = "(" ++ ppAcPrj prj ++ "|)" -ppAcPrj (SAPRight prj) = "(|" ++ ppAcPrj prj ++ ")" -ppAcPrj (SAPJust prj) = "J" ++ ppAcPrj prj -ppAcPrj (SAPArrIdx prj n) = "[" ++ ppAcPrj prj ++ "]" ++ intSubscript (fromSNat n) +ppAcPrj :: SMTy a -> SAcPrj p a b -> String +ppAcPrj _ SAPHere = "." +ppAcPrj (SMTPair t _) (SAPFst prj) = "(" ++ ppAcPrj t prj ++ ",)" +ppAcPrj (SMTPair _ t) (SAPSnd prj) = "(," ++ ppAcPrj t prj ++ ")" +ppAcPrj (SMTLEither t _) (SAPLeft prj) = "(" ++ ppAcPrj t prj ++ "|)" +ppAcPrj (SMTLEither _ t) (SAPRight prj) = "(|" ++ ppAcPrj t prj ++ ")" +ppAcPrj (SMTMaybe t) (SAPJust prj) = "J" ++ ppAcPrj t prj +ppAcPrj (SMTArr n t) (SAPArrIdx prj) = "[" ++ ppAcPrj t prj ++ "]" ++ intSubscript (fromSNat n) + +ppSparse :: SMTy a -> Sparse a b -> String +ppSparse t sp | Just Refl <- isDense t sp = "D" +ppSparse _ SpAbsent = "A" +ppSparse t (SpSparse s) = "S" ++ ppSparse t s +ppSparse (SMTPair t1 t2) (SpPair s1 s2) = "(" ++ ppSparse t1 s1 ++ "," ++ ppSparse t2 s2 ++ ")" +ppSparse (SMTLEither t1 t2) (SpLEither s1 s2) = "(" ++ ppSparse t1 s1 ++ "|" ++ ppSparse t2 s2 ++ ")" +ppSparse (SMTMaybe t) (SpMaybe s) = "M" ++ ppSparse t s +ppSparse (SMTArr _ t) (SpArr s) = "A" ++ ppSparse t s +ppSparse (SMTScal _) SpScal = "." ppX :: PrettyX x => Expr x env t -> ADoc ppX expr = annotate AExt $ ppString $ prettyXsuffix (extOf expr) @@ -334,30 +408,42 @@ operator ORecip{} = (Prefix, "recip") operator OExp{} = (Prefix, "exp") operator OLog{} = (Prefix, "log") operator OIDiv{} = (Infix, "`div`") +operator OMod{} = (Infix, "`mod`") ppSTy :: Int -> STy t -> String -ppSTy d ty = ppTy d (unSTy ty) +ppSTy d ty = render $ ppSTy' d ty ppSTy' :: Int -> STy t -> Doc q -ppSTy' d ty = ppTy' d (unSTy ty) - -ppTy :: Int -> Ty -> String -ppTy d ty = render $ ppTy' d ty - -ppTy' :: Int -> Ty -> Doc q -ppTy' _ TNil = ppString "1" -ppTy' d (TPair a b) = ppParen (d > 7) $ ppTy' 8 a <> ppString " * " <> ppTy' 8 b -ppTy' d (TEither a b) = ppParen (d > 6) $ ppTy' 7 a <> ppString " + " <> ppTy' 7 b -ppTy' d (TMaybe t) = ppParen (d > 10) $ ppString "Maybe " <> ppTy' 11 t -ppTy' d (TArr n t) = ppParen (d > 10) $ - ppString "Arr " <> ppString (show (fromNat n)) <> ppString " " <> ppTy' 11 t -ppTy' _ (TScal sty) = ppString $ case sty of - TI32 -> "i32" - TI64 -> "i64" - TF32 -> "f32" - TF64 -> "f64" - TBool -> "bool" -ppTy' d (TAccum t) = ppParen (d > 10) $ ppString "Accum " <> ppTy' 11 t +ppSTy' _ STNil = ppString "1" +ppSTy' d (STPair a b) = ppParen (d > 7) $ ppSTy' 8 a <> ppString " * " <> ppSTy' 8 b +ppSTy' d (STEither a b) = ppParen (d > 6) $ ppSTy' 7 a <> ppString " + " <> ppSTy' 7 b +ppSTy' d (STLEither a b) = ppParen (d > 6) $ ppSTy' 7 a <> ppString " ⊕ " <> ppSTy' 7 b +ppSTy' d (STMaybe t) = ppParen (d > 10) $ ppString "Maybe " <> ppSTy' 11 t +ppSTy' d (STArr n t) = ppParen (d > 10) $ + ppString "Arr " <> ppString (show (fromSNat n)) <> ppString " " <> ppSTy' 11 t +ppSTy' _ (STScal sty) = ppString $ case sty of + STI32 -> "i32" + STI64 -> "i64" + STF32 -> "f32" + STF64 -> "f64" + STBool -> "bool" +ppSTy' d (STAccum t) = ppParen (d > 10) $ ppString "Accum " <> ppSMTy' 11 t + +ppSMTy :: Int -> SMTy t -> String +ppSMTy d ty = render $ ppSMTy' d ty + +ppSMTy' :: Int -> SMTy t -> Doc q +ppSMTy' _ SMTNil = ppString "1" +ppSMTy' d (SMTPair a b) = ppParen (d > 7) $ ppSMTy' 8 a <> ppString " * " <> ppSMTy' 8 b +ppSMTy' d (SMTLEither a b) = ppParen (d > 6) $ ppSMTy' 7 a <> ppString " ⊕ " <> ppSMTy' 7 b +ppSMTy' d (SMTMaybe t) = ppParen (d > 10) $ ppString "Maybe " <> ppSMTy' 11 t +ppSMTy' d (SMTArr n t) = ppParen (d > 10) $ + ppString "Arr " <> ppString (show (fromSNat n)) <> ppString " " <> ppSMTy' 11 t +ppSMTy' _ (SMTScal sty) = ppString $ case sty of + STI32 -> "i32" + STI64 -> "i64" + STF32 -> "f32" + STF64 -> "f64" ppString :: String -> Doc x ppString = fromString |