diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-11-26 22:00:39 +0100 |
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-11-26 22:00:39 +0100 |
| commit | d5ea985f9d252af55ea0a5c3f00374a41b562369 (patch) | |
| tree | a0663aac094b02c60935b6e651e4dd38fac99959 | |
| parent | d74a7b212f06fbfad1b7f578cb127613acfb3311 (diff) | |
WIP stuff
| -rw-r--r-- | chad-fast.cabal | 1 | ||||
| -rw-r--r-- | src/CHAD/AST/Pretty.hs | 96 | ||||
| -rw-r--r-- | src/CHAD/Fusion.hs | 82 | ||||
| -rw-r--r-- | src/CHAD/Fusion/AST.hs | 83 |
4 files changed, 180 insertions, 82 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal index 834f1d7..6a9147a 100644 --- a/chad-fast.cabal +++ b/chad-fast.cabal @@ -45,6 +45,7 @@ library CHAD.ForwardAD.DualNumbers CHAD.ForwardAD.DualNumbers.Types CHAD.Fusion + CHAD.Fusion.AST CHAD.Interpreter -- CHAD.Interpreter.AccumOld CHAD.Interpreter.Rep diff --git a/src/CHAD/AST/Pretty.hs b/src/CHAD/AST/Pretty.hs index b763efe..a9a8987 100644 --- a/src/CHAD/AST/Pretty.hs +++ b/src/CHAD/AST/Pretty.hs @@ -5,10 +5,17 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE TupleSections #-} -module CHAD.AST.Pretty (pprintExpr, ppExpr, ppSTy, ppSMTy, PrettyX(..)) where +module CHAD.AST.Pretty ( + pprintExpr, + ppExpr, + ppLoopNest, + ppSTy, ppSMTy, + PrettyX(..), +) where import Control.Monad (ap) import Data.List (intersperse, intercalate) +import Data.Foldable (toList) import Data.Functor.Const import Data.Functor.Product qualified as Product import Data.String (fromString) @@ -22,10 +29,12 @@ import System.IO (stdout) import System.IO.Unsafe (unsafePerformIO) import CHAD.AST +import CHAD.AST.Bindings import CHAD.AST.Count import CHAD.AST.Sparse.Types import CHAD.Data import CHAD.Drev.Types +import CHAD.Fusion.AST class PrettyX x where @@ -78,7 +87,7 @@ pprintExpr = putStrLn . ppExpr knownEnv ppExpr :: PrettyX x => SList STy env -> Expr NoExt x env t -> String ppExpr senv e = render $ fst . flip runM 1 $ do - val <- mkVal senv + val <- mkSVal senv e' <- ppExpr' 0 val e let lam = "λ" ++ intercalate " " (reverse (unSList (\(Product.Pair (Const name) ty) -> "(" ++ name ++ " : " ++ ppSTy 0 ty ++ ")") (slistZip val senv))) ++ "." return $ group $ flatAlt @@ -86,13 +95,13 @@ ppExpr senv e = render $ fst . flip runM 1 $ do ppString lam <> hardline <> e') (ppString lam <+> e') - where - mkVal :: SList f env -> M (SVal env) - mkVal SNil = return SNil - mkVal (SCons _ v) = do - val <- mkVal v - name <- genName' "arg" - return (Const name `SCons` val) + +mkSVal :: SList f env -> M (SVal env) +mkSVal SNil = return SNil +mkSVal (SCons _ v) = do + val <- mkSVal v + name <- genName' "arg" + return (Const name `SCons` val) ppExpr' :: PrettyX x => Int -> SVal env -> Expr NoExt x env t -> M ADoc ppExpr' d val expr = case expr of @@ -403,6 +412,75 @@ ppLam :: [ADoc] -> ADoc -> ADoc ppLam args body = ppString "(" <> hang 2 (ppString "\\" <> sep (args ++ [ppString "->"]) <> softline <> body <> ppString ")") +ppLoopNest :: SList f args -> LoopNest args outs -> String +ppLoopNest senv lnest = render $ fst . flip runM 1 $ do + val <- mkSVal senv + ppLoopNest' val lnest + +data RedKind = RKRet | RKBuild | RKSum + +ppLoopNest' :: SVal args -> LoopNest args outs -> M ADoc +ppLoopNest' = \env lnest -> do + (f, outs) <- go env lnest + return (f (slistMap (\(Const _) -> Const RKRet) outs)) + where + go :: SVal args -> LoopNest args outs -> M (SList (Const RedKind) outs -> ADoc, SVal outs) + go env (Inner bs outs) = do + (bs', names) <- goBindings env bs + let outNames = slistMap (\i -> slistIdx (sappend names env) i) outs + outDoc kinds = ppString "[" + <> mconcat (map ppString (intersperse ", " (zipWith decorate kinds (unSList getConst outNames)))) + <> ppString "]" + return (\kinds -> + vcat (toList bs') + <> hardline <> (annotate AKey (ppString "ret") <+> outDoc (unSList getConst kinds)) + ,outNames) + go env (Layer bs1 wid lnest part bs2 outs) = do + (bs1', names1) <- goBindings env bs1 + widname <- genName' "i" + (f, loopouts) <- go (Const widname `SCons` sappend names1 env) lnest + let (redkinds, mapouts, sumouts) = partition part loopouts + let lnest' = f redkinds + (bs2', names2) <- goBindings (sappend sumouts (sappend names1 env)) bs2 + let outNames = slistMap (\i -> slistIdx (sappend names2 (sappend names1 env)) i) outs + outDoc kinds = + [annotate AKey (ppString "ret") <+> + ppString "[" + <> mconcat (map ppString (intersperse ", " (unSList _ (slistZip kinds outNames)))) + -- <> ppString "] ++ [" + -- <> mconcat (map ppString (intersperse ", " (unSList getConst mapouts))) + <> ppString "]"] + return (\kinds -> + vcat (toList bs1' ++ + [hang 2 (annotate AKey (ppString "loop") <+> ppString widname <+> annotate AKey (ppString "to") <+> ppString (getConst (slistIdx names1 wid)) + <> hardline <> lnest')] ++ + toList bs2' ++ + outDoc kinds) + ,sappend outNames mapouts) + + decorate :: RedKind -> String -> String + decorate RKRet name = name + decorate RKBuild name = "#" ++ name + decorate RKSum name = "↓" ++ name + + partition :: Partition BuildUp RedSum loopouts mapouts sumouts -> SVal loopouts -> (SList (Const RedKind) loopouts, SVal mapouts, SVal sumouts) + partition PNil SNil = (SNil, SNil, SNil) + partition (Part1 BuildUp{} part) (Const name `SCons` names) = + let (loopouts, mapouts, sumouts) = partition part names + in (Const RKBuild `SCons` loopouts, Const name `SCons` mapouts, sumouts) + partition (Part2 RedSum{} part) (Const name `SCons` names) = + let (loopouts, mapouts, sumouts) = partition part names + in (Const RKSum `SCons` loopouts, mapouts, Const name `SCons` sumouts) + + goBindings :: SVal env -> Bindings Ex env bs -> M (Bag ADoc, SVal bs) + goBindings _ BTop = return (mempty, SNil) + goBindings env (bs `BPush` (_, e)) = do + (docs, names) <- goBindings env bs + name <- genName' "x" + e' <- ppExpr' 0 (sappend names env) e + let doc = ppString (name ++ " = ") <> e' + return (docs <> pure doc, Const name `SCons` names) + ppAcPrj :: SMTy a -> SAcPrj p a b -> String ppAcPrj _ SAPHere = "." ppAcPrj (SMTPair t _) (SAPFst prj) = "(" ++ ppAcPrj t prj ++ ",)" diff --git a/src/CHAD/Fusion.hs b/src/CHAD/Fusion.hs index 3358d30..f863944 100644 --- a/src/CHAD/Fusion.hs +++ b/src/CHAD/Fusion.hs @@ -11,26 +11,24 @@ {-# LANGUAGE TypeOperators #-} module CHAD.Fusion where -import Data.Dependent.Map (DMap) -- import Data.Dependent.Map qualified as DMap -import Data.Functor.Const -import Data.Kind (Type) import Data.Some -import Numeric.Natural import CHAD.AST import CHAD.AST.Bindings import CHAD.AST.Count import CHAD.AST.Env import CHAD.Data +import CHAD.Fusion.AST import CHAD.Lemmas -- TODO: --- A bunch of data types are defined here that should be able to express a --- graph of loop nests. A graph is a straight-line program whose statements --- are, in this case, loop nests. A loop nest corresponds to what fusion --- normally calls a "cluster", but is here represented as, well, a loop nest. +-- A bunch of data types are defined here (in CHAD.Fusion.AST) that should be +-- able to express a graph of loop nests. A graph is a straight-line program +-- whose statements are, in this case, loop nests. A loop nest corresponds to +-- what fusion normally calls a "cluster", but is here represented as, well, a +-- loop nest. -- -- No unzipping is done here, as I don't think it is necessary: I haven't been -- able to think of programs that get more fusion opportunities when unzipped @@ -57,72 +55,10 @@ import CHAD.Lemmas -- and compile that to an actual C loop nest. -- 6. Extend to other cool operations like EFold1InnerD1 +-- :m *CHAD.Fusion CHAD.AST.Pretty CHAD.Language +-- pprintExpr $ fromNamed $ body $ build (SS (SS SZ)) (pair (pair nil 3) 4) (#idx :-> snd_ #idx + snd_ (fst_ #idx)) +-- putStrLn $ case fromNamed $ body $ build (SS (SS SZ)) (pair (pair nil 3) 4) (#idx :-> snd_ #idx + snd_ (fst_ #idx)) of EBuild _ n esh ebody -> let env = knownEnv in buildLoopNest env n esh ebody $ \sub nest -> show sub ++ "\n" ++ ppLoopNest (subList env sub) nest -type FEx = Expr FGraph (Const ()) - -type FGraph :: (Ty -> Type) -> [Ty] -> Ty -> Type -data FGraph x env t where - FGraph :: DMap NodeId (Node env) -> Tuple NodeId t -> FGraph (Const ()) env t - -data Node env t where - FFreeVar :: STy t -> Idx env t -> Node env t - FLoop :: SList NodeId args - -> SList STy outs - -> LoopNest args outs - -> Tuple (Idx outs) t - -> Node env t - -data NodeId t = NodeId Natural (STy t) - deriving (Show) - -data Tuple f t where - TupNil :: Tuple f TNil - TupPair :: Tuple f a -> Tuple f b -> Tuple f (TPair a b) - TupSingle :: f t -> Tuple f t -deriving instance (forall a. Show (f a)) => Show (Tuple f t) - -data LoopNest args outs where - Inner :: Bindings Ex args bs - -> SList (Idx (Append bs args)) outs - -> LoopNest args outs - -- this should be able to express a simple nesting of builds and sums. - Layer :: Bindings Ex args bs1 - -> Idx bs1 TIx -- ^ loop width (number of (parallel) iterations) - -> LoopNest (TIx : Append bs1 args) loopouts - -> Partition BuildUp RedSum loopouts mapouts sumouts - -> Bindings Ex (Append sumouts (Append bs1 args)) bs2 - -> SList (Idx (Append bs2 (Append bs1 args))) outs - -> LoopNest args (Append outs mapouts) -deriving instance Show (LoopNest args outs) - -type Partition :: (Ty -> Ty -> Type) -> (Ty -> Ty -> Type) -> [Ty] -> [Ty] -> [Ty] -> Type -data Partition f1 f2 ts ts1 ts2 where - PNil :: Partition f1 f2 '[] '[] '[] - Part1 :: f1 t t1 -> Partition f1 f2 ts ts1 ts2 -> Partition f1 f2 (t : ts) (t1 : ts1) ts2 - Part2 :: f2 t t2 -> Partition f1 f2 ts ts1 ts2 -> Partition f1 f2 (t : ts) ts1 (t2 : ts2) -deriving instance (forall t t1. Show (f1 t t1), forall t t2. Show (f2 t t2)) => Show (Partition f1 f2 ts ts1 ts2) - -data BuildUp t t' where - BuildUp :: SNat n -> STy t -> BuildUp (TArr n t) (TArr (S n) t) -deriving instance Show (BuildUp t t') - -data RedSum t t' where - RedSum :: SMTy t -> RedSum t t -deriving instance Show (RedSum t t') - --- type family Unzip t where --- Unzip (TPair a b) = TPair (Unzip a) (Unzip b) --- Unzip (TArr n t) = UnzipA n t - --- type family UnzipA n t where --- UnzipA n (TPair a b) = TPair (UnzipA n a) (UnzipA n b) --- UnzipA n t = TArr n t - --- data Zipping ut t where --- ZId :: Zipping t t --- ZPair :: Zipping ua a -> Zipping ub b -> Zipping (TPair ua ub) (TPair a b) --- ZZip :: Zipping ua (TArr n a) -> Zipping ub (TArr n b) -> Zipping (TPair ua ub) (TArr n (TPair a b)) --- deriving instance Show (Zipping ut t) prependBinding :: forall args outs t. Ex args t -> LoopNest (t : args) outs -> LoopNest args outs prependBinding e (Inner (bs :: Bindings Ex (t : args) bs) outs) diff --git a/src/CHAD/Fusion/AST.hs b/src/CHAD/Fusion/AST.hs new file mode 100644 index 0000000..3cd188a --- /dev/null +++ b/src/CHAD/Fusion/AST.hs @@ -0,0 +1,83 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeOperators #-} +module CHAD.Fusion.AST where + +import Data.Dependent.Map (DMap) +import Data.Functor.Const +import Data.Kind (Type) +import Numeric.Natural + +import CHAD.AST +import CHAD.AST.Bindings +import CHAD.Data + + +type FEx = Expr FGraph (Const ()) + +type FGraph :: (Ty -> Type) -> [Ty] -> Ty -> Type +data FGraph x env t where + FGraph :: DMap NodeId (Node env) -> Tuple NodeId t -> FGraph (Const ()) env t + +data Node env t where + FFreeVar :: STy t -> Idx env t -> Node env t + FLoop :: SList NodeId args + -> SList STy outs + -> LoopNest args outs + -> Tuple (Idx outs) t + -> Node env t + +data NodeId t = NodeId Natural (STy t) + deriving (Show) + +data Tuple f t where + TupNil :: Tuple f TNil + TupPair :: Tuple f a -> Tuple f b -> Tuple f (TPair a b) + TupSingle :: f t -> Tuple f t +deriving instance (forall a. Show (f a)) => Show (Tuple f t) + +data LoopNest args outs where + Inner :: Bindings Ex args bs + -> SList (Idx (Append bs args)) outs + -> LoopNest args outs + -- this should be able to express a simple nesting of builds and sums. + Layer :: Bindings Ex args bs1 + -> Idx bs1 TIx -- ^ loop width (number of (parallel) iterations) + -> LoopNest (TIx : Append bs1 args) loopouts + -> Partition BuildUp RedSum loopouts mapouts sumouts + -> Bindings Ex (Append sumouts (Append bs1 args)) bs2 + -> SList (Idx (Append bs2 (Append bs1 args))) outs + -> LoopNest args (Append outs mapouts) +deriving instance Show (LoopNest args outs) + +type Partition :: (Ty -> Ty -> Type) -> (Ty -> Ty -> Type) -> [Ty] -> [Ty] -> [Ty] -> Type +data Partition f1 f2 ts ts1 ts2 where + PNil :: Partition f1 f2 '[] '[] '[] + Part1 :: f1 t t1 -> Partition f1 f2 ts ts1 ts2 -> Partition f1 f2 (t : ts) (t1 : ts1) ts2 + Part2 :: f2 t t2 -> Partition f1 f2 ts ts1 ts2 -> Partition f1 f2 (t : ts) ts1 (t2 : ts2) +deriving instance (forall t t1. Show (f1 t t1), forall t t2. Show (f2 t t2)) => Show (Partition f1 f2 ts ts1 ts2) + +data BuildUp t t' where + BuildUp :: SNat n -> STy t -> BuildUp (TArr n t) (TArr (S n) t) +deriving instance Show (BuildUp t t') + +data RedSum t t' where + RedSum :: SMTy t -> RedSum t t +deriving instance Show (RedSum t t') + +-- type family Unzip t where +-- Unzip (TPair a b) = TPair (Unzip a) (Unzip b) +-- Unzip (TArr n t) = UnzipA n t + +-- type family UnzipA n t where +-- UnzipA n (TPair a b) = TPair (UnzipA n a) (UnzipA n b) +-- UnzipA n t = TArr n t + +-- data Zipping ut t where +-- ZId :: Zipping t t +-- ZPair :: Zipping ua a -> Zipping ub b -> Zipping (TPair ua ub) (TPair a b) +-- ZZip :: Zipping ua (TArr n a) -> Zipping ub (TArr n b) -> Zipping (TPair ua ub) (TArr n (TPair a b)) +-- deriving instance Show (Zipping ut t) |
