From 8cdef3f10594a69037d45340029f8e795ecfee4a Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Tue, 9 Dec 2025 21:48:39 +0100 Subject: WIP fusion stuff --- src/CHAD/AST/Pretty.hs | 20 +++++++++++--------- src/CHAD/Fusion.hs | 8 ++++---- src/CHAD/Fusion/AST.hs | 18 ++++++++++-------- 3 files changed, 25 insertions(+), 21 deletions(-) diff --git a/src/CHAD/AST/Pretty.hs b/src/CHAD/AST/Pretty.hs index a9a8987..04c3d30 100644 --- a/src/CHAD/AST/Pretty.hs +++ b/src/CHAD/AST/Pretty.hs @@ -412,19 +412,19 @@ ppLam :: [ADoc] -> ADoc -> ADoc ppLam args body = ppString "(" <> hang 2 (ppString "\\" <> sep (args ++ [ppString "->"]) <> softline <> body <> ppString ")") -ppLoopNest :: SList f args -> LoopNest args outs -> String +ppLoopNest :: SList f args -> LoopNest args outs bouts -> String ppLoopNest senv lnest = render $ fst . flip runM 1 $ do val <- mkSVal senv ppLoopNest' val lnest data RedKind = RKRet | RKBuild | RKSum -ppLoopNest' :: SVal args -> LoopNest args outs -> M ADoc +ppLoopNest' :: SVal args -> LoopNest args outs bouts -> M ADoc ppLoopNest' = \env lnest -> do - (f, outs) <- go env lnest + (f, outs, _bouts) <- go env lnest return (f (slistMap (\(Const _) -> Const RKRet) outs)) where - go :: SVal args -> LoopNest args outs -> M (SList (Const RedKind) outs -> ADoc, SVal outs) + go :: SVal args -> LoopNest args outs bouts -> M (SList (Const RedKind) outs -> ADoc, SVal outs, SVal bouts) go env (Inner bs outs) = do (bs', names) <- goBindings env bs let outNames = slistMap (\i -> slistIdx (sappend names env) i) outs @@ -434,19 +434,20 @@ ppLoopNest' = \env lnest -> do return (\kinds -> vcat (toList bs') <> hardline <> (annotate AKey (ppString "ret") <+> outDoc (unSList getConst kinds)) - ,outNames) + ,outNames + ,SNil) go env (Layer bs1 wid lnest part bs2 outs) = do (bs1', names1) <- goBindings env bs1 widname <- genName' "i" - (f, loopouts) <- go (Const widname `SCons` sappend names1 env) lnest - let (redkinds, mapouts, sumouts) = partition part loopouts + (f, loopouts, loopbouts) <- go (Const widname `SCons` sappend names1 env) lnest + let (redkinds, newbouts, sumouts) = partition part loopouts let lnest' = f redkinds (bs2', names2) <- goBindings (sappend sumouts (sappend names1 env)) bs2 let outNames = slistMap (\i -> slistIdx (sappend names2 (sappend names1 env)) i) outs outDoc kinds = [annotate AKey (ppString "ret") <+> ppString "[" - <> mconcat (map ppString (intersperse ", " (unSList _ (slistZip kinds outNames)))) + <> mconcat (map ppString (intersperse ", " (unSList (\(Product.Pair (Const k) (Const n)) -> decorate k n) (slistZip kinds outNames)))) -- <> ppString "] ++ [" -- <> mconcat (map ppString (intersperse ", " (unSList getConst mapouts))) <> ppString "]"] @@ -456,7 +457,8 @@ ppLoopNest' = \env lnest -> do <> hardline <> lnest')] ++ toList bs2' ++ outDoc kinds) - ,sappend outNames mapouts) + ,outNames + ,sappend newbouts loopbouts) decorate :: RedKind -> String -> String decorate RKRet name = name diff --git a/src/CHAD/Fusion.hs b/src/CHAD/Fusion.hs index f863944..29c1f12 100644 --- a/src/CHAD/Fusion.hs +++ b/src/CHAD/Fusion.hs @@ -60,7 +60,7 @@ import CHAD.Lemmas -- putStrLn $ case fromNamed $ body $ build (SS (SS SZ)) (pair (pair nil 3) 4) (#idx :-> snd_ #idx + snd_ (fst_ #idx)) of EBuild _ n esh ebody -> let env = knownEnv in buildLoopNest env n esh ebody $ \sub nest -> show sub ++ "\n" ++ ppLoopNest (subList env sub) nest -prependBinding :: forall args outs t. Ex args t -> LoopNest (t : args) outs -> LoopNest args outs +prependBinding :: forall args outs bouts t. Ex args t -> LoopNest (t : args) outs bouts -> LoopNest args outs bouts prependBinding e (Inner (bs :: Bindings Ex (t : args) bs) outs) | Refl <- lemAppendAssoc @bs @'[t] @args = Inner (bconcat (BTop `bpush` e) bs) outs @@ -72,7 +72,7 @@ prependBinding e (Layer (bs1 :: Bindings Ex (t : args) bs1) wid nest part bs2 ou nest part bs2 outs buildLoopNest :: SList STy env -> SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex (Tup (Replicate n TIx) : env) t - -> (forall args. Subenv env args -> LoopNest args '[TArr n t] -> r) -> r + -> (forall args. Subenv env args -> LoopNest args '[] '[TArr n t] -> r) -> r buildLoopNest = \env sn esh ebody k -> withSome (occCountAll ebody) $ \occBody' -> occEnvPop' occBody' $ \occBody -> @@ -87,8 +87,8 @@ buildLoopNest = \env sn esh ebody k -> Inner (BTop `bpush` elet idx (EUnit ext (weakenExpr (WCopy (WPop w)) ebody'))) (IZ `SCons` SNil) where nestMapN :: SNat n -> STy t -> SList (Ex args) (Replicate n TIx) - -> (forall args'. args :> args' -> Ex args' (Tup (Replicate n TIx)) -> LoopNest args' '[TArr Z t]) - -> LoopNest args '[TArr n t] + -> (forall args'. args :> args' -> Ex args' (Tup (Replicate n TIx)) -> LoopNest args' '[TArr Z t] '[]) + -> LoopNest args '[] '[TArr n t] nestMapN SZ _ SNil inner = inner WId (ENil ext) nestMapN (SS sn) ty (wid `SCons` sh) inner = Layer (BTop `bpush` wid) diff --git a/src/CHAD/Fusion/AST.hs b/src/CHAD/Fusion/AST.hs index 3cd188a..a84e575 100644 --- a/src/CHAD/Fusion/AST.hs +++ b/src/CHAD/Fusion/AST.hs @@ -26,8 +26,8 @@ data Node env t where FFreeVar :: STy t -> Idx env t -> Node env t FLoop :: SList NodeId args -> SList STy outs - -> LoopNest args outs - -> Tuple (Idx outs) t + -> LoopNest args outs bouts + -> Tuple (Idx (Append outs bouts)) t -> Node env t data NodeId t = NodeId Natural (STy t) @@ -39,19 +39,21 @@ data Tuple f t where TupSingle :: f t -> Tuple f t deriving instance (forall a. Show (f a)) => Show (Tuple f t) -data LoopNest args outs where +-- bouts: "build outs", outputs that were marked as build-style (elementwise) +-- above and cannot be handled differently any more +data LoopNest args outs bouts where Inner :: Bindings Ex args bs -> SList (Idx (Append bs args)) outs - -> LoopNest args outs + -> LoopNest args outs '[] -- this should be able to express a simple nesting of builds and sums. Layer :: Bindings Ex args bs1 -> Idx bs1 TIx -- ^ loop width (number of (parallel) iterations) - -> LoopNest (TIx : Append bs1 args) loopouts - -> Partition BuildUp RedSum loopouts mapouts sumouts + -> LoopNest (TIx : Append bs1 args) loopouts bouts + -> Partition BuildUp RedSum loopouts newbouts sumouts -> Bindings Ex (Append sumouts (Append bs1 args)) bs2 -> SList (Idx (Append bs2 (Append bs1 args))) outs - -> LoopNest args (Append outs mapouts) -deriving instance Show (LoopNest args outs) + -> LoopNest args outs (Append newbouts bouts) +deriving instance Show (LoopNest args outs bouts) type Partition :: (Ty -> Ty -> Type) -> (Ty -> Ty -> Type) -> [Ty] -> [Ty] -> [Ty] -> Type data Partition f1 f2 ts ts1 ts2 where -- cgit v1.2.3-70-g09d2