diff options
Diffstat (limited to 'src/CHAD/AST/Pretty.hs')
| -rw-r--r-- | src/CHAD/AST/Pretty.hs | 96 |
1 files changed, 87 insertions, 9 deletions
diff --git a/src/CHAD/AST/Pretty.hs b/src/CHAD/AST/Pretty.hs index b763efe..a9a8987 100644 --- a/src/CHAD/AST/Pretty.hs +++ b/src/CHAD/AST/Pretty.hs @@ -5,10 +5,17 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE TupleSections #-} -module CHAD.AST.Pretty (pprintExpr, ppExpr, ppSTy, ppSMTy, PrettyX(..)) where +module CHAD.AST.Pretty ( + pprintExpr, + ppExpr, + ppLoopNest, + ppSTy, ppSMTy, + PrettyX(..), +) where import Control.Monad (ap) import Data.List (intersperse, intercalate) +import Data.Foldable (toList) import Data.Functor.Const import Data.Functor.Product qualified as Product import Data.String (fromString) @@ -22,10 +29,12 @@ import System.IO (stdout) import System.IO.Unsafe (unsafePerformIO) import CHAD.AST +import CHAD.AST.Bindings import CHAD.AST.Count import CHAD.AST.Sparse.Types import CHAD.Data import CHAD.Drev.Types +import CHAD.Fusion.AST class PrettyX x where @@ -78,7 +87,7 @@ pprintExpr = putStrLn . ppExpr knownEnv ppExpr :: PrettyX x => SList STy env -> Expr NoExt x env t -> String ppExpr senv e = render $ fst . flip runM 1 $ do - val <- mkVal senv + val <- mkSVal senv e' <- ppExpr' 0 val e let lam = "λ" ++ intercalate " " (reverse (unSList (\(Product.Pair (Const name) ty) -> "(" ++ name ++ " : " ++ ppSTy 0 ty ++ ")") (slistZip val senv))) ++ "." return $ group $ flatAlt @@ -86,13 +95,13 @@ ppExpr senv e = render $ fst . flip runM 1 $ do ppString lam <> hardline <> e') (ppString lam <+> e') - where - mkVal :: SList f env -> M (SVal env) - mkVal SNil = return SNil - mkVal (SCons _ v) = do - val <- mkVal v - name <- genName' "arg" - return (Const name `SCons` val) + +mkSVal :: SList f env -> M (SVal env) +mkSVal SNil = return SNil +mkSVal (SCons _ v) = do + val <- mkSVal v + name <- genName' "arg" + return (Const name `SCons` val) ppExpr' :: PrettyX x => Int -> SVal env -> Expr NoExt x env t -> M ADoc ppExpr' d val expr = case expr of @@ -403,6 +412,75 @@ ppLam :: [ADoc] -> ADoc -> ADoc ppLam args body = ppString "(" <> hang 2 (ppString "\\" <> sep (args ++ [ppString "->"]) <> softline <> body <> ppString ")") +ppLoopNest :: SList f args -> LoopNest args outs -> String +ppLoopNest senv lnest = render $ fst . flip runM 1 $ do + val <- mkSVal senv + ppLoopNest' val lnest + +data RedKind = RKRet | RKBuild | RKSum + +ppLoopNest' :: SVal args -> LoopNest args outs -> M ADoc +ppLoopNest' = \env lnest -> do + (f, outs) <- go env lnest + return (f (slistMap (\(Const _) -> Const RKRet) outs)) + where + go :: SVal args -> LoopNest args outs -> M (SList (Const RedKind) outs -> ADoc, SVal outs) + go env (Inner bs outs) = do + (bs', names) <- goBindings env bs + let outNames = slistMap (\i -> slistIdx (sappend names env) i) outs + outDoc kinds = ppString "[" + <> mconcat (map ppString (intersperse ", " (zipWith decorate kinds (unSList getConst outNames)))) + <> ppString "]" + return (\kinds -> + vcat (toList bs') + <> hardline <> (annotate AKey (ppString "ret") <+> outDoc (unSList getConst kinds)) + ,outNames) + go env (Layer bs1 wid lnest part bs2 outs) = do + (bs1', names1) <- goBindings env bs1 + widname <- genName' "i" + (f, loopouts) <- go (Const widname `SCons` sappend names1 env) lnest + let (redkinds, mapouts, sumouts) = partition part loopouts + let lnest' = f redkinds + (bs2', names2) <- goBindings (sappend sumouts (sappend names1 env)) bs2 + let outNames = slistMap (\i -> slistIdx (sappend names2 (sappend names1 env)) i) outs + outDoc kinds = + [annotate AKey (ppString "ret") <+> + ppString "[" + <> mconcat (map ppString (intersperse ", " (unSList _ (slistZip kinds outNames)))) + -- <> ppString "] ++ [" + -- <> mconcat (map ppString (intersperse ", " (unSList getConst mapouts))) + <> ppString "]"] + return (\kinds -> + vcat (toList bs1' ++ + [hang 2 (annotate AKey (ppString "loop") <+> ppString widname <+> annotate AKey (ppString "to") <+> ppString (getConst (slistIdx names1 wid)) + <> hardline <> lnest')] ++ + toList bs2' ++ + outDoc kinds) + ,sappend outNames mapouts) + + decorate :: RedKind -> String -> String + decorate RKRet name = name + decorate RKBuild name = "#" ++ name + decorate RKSum name = "↓" ++ name + + partition :: Partition BuildUp RedSum loopouts mapouts sumouts -> SVal loopouts -> (SList (Const RedKind) loopouts, SVal mapouts, SVal sumouts) + partition PNil SNil = (SNil, SNil, SNil) + partition (Part1 BuildUp{} part) (Const name `SCons` names) = + let (loopouts, mapouts, sumouts) = partition part names + in (Const RKBuild `SCons` loopouts, Const name `SCons` mapouts, sumouts) + partition (Part2 RedSum{} part) (Const name `SCons` names) = + let (loopouts, mapouts, sumouts) = partition part names + in (Const RKSum `SCons` loopouts, mapouts, Const name `SCons` sumouts) + + goBindings :: SVal env -> Bindings Ex env bs -> M (Bag ADoc, SVal bs) + goBindings _ BTop = return (mempty, SNil) + goBindings env (bs `BPush` (_, e)) = do + (docs, names) <- goBindings env bs + name <- genName' "x" + e' <- ppExpr' 0 (sappend names env) e + let doc = ppString (name ++ " = ") <> e' + return (docs <> pure doc, Const name `SCons` names) + ppAcPrj :: SMTy a -> SAcPrj p a b -> String ppAcPrj _ SAPHere = "." ppAcPrj (SMTPair t _) (SAPFst prj) = "(" ++ ppAcPrj t prj ++ ",)" |
