summaryrefslogtreecommitdiff
path: root/src/AST/Pretty.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-04-27 23:34:59 +0200
committerTom Smeding <tom@tomsmeding.com>2025-04-27 23:34:59 +0200
commitb1664532eaebdf0409ab6d93fc0ba2ef8dfbf372 (patch)
treea40c16fd082bbe4183e7b4194b8cea1408cec379 /src/AST/Pretty.hs
parentc750f8f9f1275d49ff74297e6648e1bfc1c6d918 (diff)
WIP revamp accumulators again: explicit monoid types
No more D2 in accumulators! Paving the way for configurable sparsity of products and arrays. The idea is to make separate monoid types for a "product cotangent" and an "array cotangent" that can be lowered to either a sparse monoid or a non-sparse monoid. Downsides of this approach: lots of API duplication.
Diffstat (limited to 'src/AST/Pretty.hs')
-rw-r--r--src/AST/Pretty.hs76
1 files changed, 60 insertions, 16 deletions
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index fb5e138..b6ad7d2 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -7,7 +7,7 @@
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeOperators #-}
-module AST.Pretty (pprintExpr, ppExpr, ppSTy, PrettyX(..)) where
+module AST.Pretty (pprintExpr, ppExpr, ppSTy, ppSMTy, PrettyX(..)) where
import Control.Monad (ap)
import Data.List (intersperse, intercalate)
@@ -152,6 +152,31 @@ ppExpr' d val expr = case expr of
return $ ppParen (d > 10) $
ppApp (ppString "maybe" <> ppX expr) [a', ppLam [ppString name] b', e']
+ ELNil _ _ _ -> return (ppString "LNil")
+
+ ELInl _ _ e -> do
+ e' <- ppExpr' 11 val e
+ return $ ppParen (d > 10) $ ppString "LInl" <> ppX expr <+> e'
+
+ ELInr _ _ e -> do
+ e' <- ppExpr' 11 val e
+ return $ ppParen (d > 10) $ ppString "LInr" <> ppX expr <+> e'
+
+ ELCase _ e a b c -> do
+ e' <- ppExpr' 0 val e
+ let STLEither t1 t2 = typeOf e
+ a' <- ppExpr' 11 val a
+ name1 <- genNameIfUsedIn t1 IZ b
+ b' <- ppExpr' 0 (Const name1 `SCons` val) b
+ name2 <- genNameIfUsedIn t2 IZ c
+ c' <- ppExpr' 0 (Const name2 `SCons` val) c
+ return $ ppParen (d > 0) $
+ hang 2 $
+ annotate AKey (ppString "lcase") <> ppX expr <+> e' <+> annotate AKey (ppString "of")
+ <> hardline <> ppString "LNil" <+> ppString "->" <+> a'
+ <> hardline <> ppString "LInl" <+> ppString name1 <+> ppString "->" <+> b'
+ <> hardline <> ppString "LInr" <+> ppString name2 <+> ppString "->" <+> c'
+
EConstArr _ _ ty v
| Dict <- scalRepIsShow ty -> return $ ppString (showsPrec d v "") <> ppX expr
@@ -267,15 +292,17 @@ ppExpr' d val expr = case expr of
<> hardline <> e2')
(ppApp (annotate AWith (ppString "with") <> ppX expr) [e1', ppLam [ppString name] e2'])
- EAccum _ _ prj e1 e2 e3 -> do
+ EAccum _ t prj e1 e2 e3 -> do
e1' <- ppExpr' 11 val e1
e2' <- ppExpr' 11 val e2
e3' <- ppExpr' 11 val e3
return $ ppParen (d > 10) $
- ppApp (annotate AMonoid (ppString "accum") <> ppX expr) [ppString (ppAcPrj prj), e1', e2', e3']
+ ppApp (annotate AMonoid (ppString "accum") <> ppX expr) [ppString (ppAcPrj t prj), e1', e2', e3']
- EZero _ t -> return $ ppParen (d > 0) $
- annotate AMonoid (ppString "zero") <> ppX expr <+> ppString "@" <> ppSTy' 11 t
+ EZero _ t e1 -> do
+ e1' <- ppExpr' 11 val e1
+ return $ ppParen (d > 0) $
+ annotate AMonoid (ppString "zero") <> ppX expr <+> ppString "@" <> ppSMTy' 11 t <+> e1'
EPlus _ _ a b -> do
a' <- ppExpr' 11 val a
@@ -283,11 +310,11 @@ ppExpr' d val expr = case expr of
return $ ppParen (d > 10) $
ppApp (annotate AMonoid (ppString "plus") <> ppX expr) [a', b']
- EOneHot _ _ prj a b -> do
+ EOneHot _ t prj a b -> do
a' <- ppExpr' 11 val a
b' <- ppExpr' 11 val b
return $ ppParen (d > 10) $
- ppApp (annotate AMonoid (ppString "onehot") <> ppX expr) [ppString (ppAcPrj prj), a', b']
+ ppApp (annotate AMonoid (ppString "onehot") <> ppX expr) [ppString (ppAcPrj t prj), a', b']
EError _ _ s -> return $ ppParen (d > 10) $ ppString "error" <> ppX expr <+> ppString (show s)
@@ -320,14 +347,14 @@ ppLam :: [ADoc] -> ADoc -> ADoc
ppLam args body = ppString "(" <> hang 2 (ppString "\\" <> sep (args ++ [ppString "->"])
<> softline <> body <> ppString ")")
-ppAcPrj :: SAcPrj p a b -> String
-ppAcPrj SAPHere = "@"
-ppAcPrj (SAPFst prj) = "(" ++ ppAcPrj prj ++ ",)"
-ppAcPrj (SAPSnd prj) = "(," ++ ppAcPrj prj ++ ")"
-ppAcPrj (SAPLeft prj) = "(" ++ ppAcPrj prj ++ "|)"
-ppAcPrj (SAPRight prj) = "(|" ++ ppAcPrj prj ++ ")"
-ppAcPrj (SAPJust prj) = "J" ++ ppAcPrj prj
-ppAcPrj (SAPArrIdx prj n) = "[" ++ ppAcPrj prj ++ "]" ++ intSubscript (fromSNat n)
+ppAcPrj :: SMTy a -> SAcPrj p a b -> String
+ppAcPrj _ SAPHere = "@"
+ppAcPrj (SMTPair t _) (SAPFst prj) = "(" ++ ppAcPrj t prj ++ ",)"
+ppAcPrj (SMTPair _ t) (SAPSnd prj) = "(," ++ ppAcPrj t prj ++ ")"
+ppAcPrj (SMTLEither t _) (SAPLeft prj) = "(" ++ ppAcPrj t prj ++ "|)"
+ppAcPrj (SMTLEither _ t) (SAPRight prj) = "(|" ++ ppAcPrj t prj ++ ")"
+ppAcPrj (SMTMaybe t) (SAPJust prj) = "J" ++ ppAcPrj t prj
+ppAcPrj (SMTArr n t) (SAPArrIdx prj) = "[" ++ ppAcPrj t prj ++ "]" ++ intSubscript (fromSNat n)
ppX :: PrettyX x => Expr x env t -> ADoc
ppX expr = annotate AExt $ ppString $ prettyXsuffix (extOf expr)
@@ -370,7 +397,24 @@ ppSTy' _ (STScal sty) = ppString $ case sty of
STF32 -> "f32"
STF64 -> "f64"
STBool -> "bool"
-ppSTy' d (STAccum t) = ppParen (d > 10) $ ppString "Accum " <> ppSTy' 11 t
+ppSTy' d (STAccum t) = ppParen (d > 10) $ ppString "Accum " <> ppSMTy' 11 t
+ppSTy' d (STLEither a b) = ppParen (d > 6) $ ppSTy' 7 a <> ppString " ⊕ " <> ppSTy' 7 b
+
+ppSMTy :: Int -> SMTy t -> String
+ppSMTy d ty = render $ ppSMTy' d ty
+
+ppSMTy' :: Int -> SMTy t -> Doc q
+ppSMTy' _ SMTNil = ppString "1"
+ppSMTy' d (SMTPair a b) = ppParen (d > 7) $ ppSMTy' 8 a <> ppString " * " <> ppSMTy' 8 b
+ppSMTy' d (SMTLEither a b) = ppParen (d > 6) $ ppSMTy' 7 a <> ppString " ⊕ " <> ppSMTy' 7 b
+ppSMTy' d (SMTMaybe t) = ppParen (d > 10) $ ppString "Maybe " <> ppSMTy' 11 t
+ppSMTy' d (SMTArr n t) = ppParen (d > 10) $
+ ppString "Arr " <> ppString (show (fromSNat n)) <> ppString " " <> ppSMTy' 11 t
+ppSMTy' _ (SMTScal sty) = ppString $ case sty of
+ STI32 -> "i32"
+ STI64 -> "i64"
+ STF32 -> "f32"
+ STF64 -> "f64"
ppString :: String -> Doc x
ppString = fromString