aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-12-09 21:48:39 +0100
committerTom Smeding <tom@tomsmeding.com>2025-12-09 21:48:39 +0100
commit8cdef3f10594a69037d45340029f8e795ecfee4a (patch)
treeb0edf7e7692c4da34bcbb4332607d1511975116f
parentd5ea985f9d252af55ea0a5c3f00374a41b562369 (diff)
WIP fusion stufffusion
-rw-r--r--src/CHAD/AST/Pretty.hs20
-rw-r--r--src/CHAD/Fusion.hs8
-rw-r--r--src/CHAD/Fusion/AST.hs18
3 files changed, 25 insertions, 21 deletions
diff --git a/src/CHAD/AST/Pretty.hs b/src/CHAD/AST/Pretty.hs
index a9a8987..04c3d30 100644
--- a/src/CHAD/AST/Pretty.hs
+++ b/src/CHAD/AST/Pretty.hs
@@ -412,19 +412,19 @@ ppLam :: [ADoc] -> ADoc -> ADoc
ppLam args body = ppString "(" <> hang 2 (ppString "\\" <> sep (args ++ [ppString "->"])
<> softline <> body <> ppString ")")
-ppLoopNest :: SList f args -> LoopNest args outs -> String
+ppLoopNest :: SList f args -> LoopNest args outs bouts -> String
ppLoopNest senv lnest = render $ fst . flip runM 1 $ do
val <- mkSVal senv
ppLoopNest' val lnest
data RedKind = RKRet | RKBuild | RKSum
-ppLoopNest' :: SVal args -> LoopNest args outs -> M ADoc
+ppLoopNest' :: SVal args -> LoopNest args outs bouts -> M ADoc
ppLoopNest' = \env lnest -> do
- (f, outs) <- go env lnest
+ (f, outs, _bouts) <- go env lnest
return (f (slistMap (\(Const _) -> Const RKRet) outs))
where
- go :: SVal args -> LoopNest args outs -> M (SList (Const RedKind) outs -> ADoc, SVal outs)
+ go :: SVal args -> LoopNest args outs bouts -> M (SList (Const RedKind) outs -> ADoc, SVal outs, SVal bouts)
go env (Inner bs outs) = do
(bs', names) <- goBindings env bs
let outNames = slistMap (\i -> slistIdx (sappend names env) i) outs
@@ -434,19 +434,20 @@ ppLoopNest' = \env lnest -> do
return (\kinds ->
vcat (toList bs')
<> hardline <> (annotate AKey (ppString "ret") <+> outDoc (unSList getConst kinds))
- ,outNames)
+ ,outNames
+ ,SNil)
go env (Layer bs1 wid lnest part bs2 outs) = do
(bs1', names1) <- goBindings env bs1
widname <- genName' "i"
- (f, loopouts) <- go (Const widname `SCons` sappend names1 env) lnest
- let (redkinds, mapouts, sumouts) = partition part loopouts
+ (f, loopouts, loopbouts) <- go (Const widname `SCons` sappend names1 env) lnest
+ let (redkinds, newbouts, sumouts) = partition part loopouts
let lnest' = f redkinds
(bs2', names2) <- goBindings (sappend sumouts (sappend names1 env)) bs2
let outNames = slistMap (\i -> slistIdx (sappend names2 (sappend names1 env)) i) outs
outDoc kinds =
[annotate AKey (ppString "ret") <+>
ppString "["
- <> mconcat (map ppString (intersperse ", " (unSList _ (slistZip kinds outNames))))
+ <> mconcat (map ppString (intersperse ", " (unSList (\(Product.Pair (Const k) (Const n)) -> decorate k n) (slistZip kinds outNames))))
-- <> ppString "] ++ ["
-- <> mconcat (map ppString (intersperse ", " (unSList getConst mapouts)))
<> ppString "]"]
@@ -456,7 +457,8 @@ ppLoopNest' = \env lnest -> do
<> hardline <> lnest')] ++
toList bs2' ++
outDoc kinds)
- ,sappend outNames mapouts)
+ ,outNames
+ ,sappend newbouts loopbouts)
decorate :: RedKind -> String -> String
decorate RKRet name = name
diff --git a/src/CHAD/Fusion.hs b/src/CHAD/Fusion.hs
index f863944..29c1f12 100644
--- a/src/CHAD/Fusion.hs
+++ b/src/CHAD/Fusion.hs
@@ -60,7 +60,7 @@ import CHAD.Lemmas
-- putStrLn $ case fromNamed $ body $ build (SS (SS SZ)) (pair (pair nil 3) 4) (#idx :-> snd_ #idx + snd_ (fst_ #idx)) of EBuild _ n esh ebody -> let env = knownEnv in buildLoopNest env n esh ebody $ \sub nest -> show sub ++ "\n" ++ ppLoopNest (subList env sub) nest
-prependBinding :: forall args outs t. Ex args t -> LoopNest (t : args) outs -> LoopNest args outs
+prependBinding :: forall args outs bouts t. Ex args t -> LoopNest (t : args) outs bouts -> LoopNest args outs bouts
prependBinding e (Inner (bs :: Bindings Ex (t : args) bs) outs)
| Refl <- lemAppendAssoc @bs @'[t] @args
= Inner (bconcat (BTop `bpush` e) bs) outs
@@ -72,7 +72,7 @@ prependBinding e (Layer (bs1 :: Bindings Ex (t : args) bs1) wid nest part bs2 ou
nest part bs2 outs
buildLoopNest :: SList STy env -> SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex (Tup (Replicate n TIx) : env) t
- -> (forall args. Subenv env args -> LoopNest args '[TArr n t] -> r) -> r
+ -> (forall args. Subenv env args -> LoopNest args '[] '[TArr n t] -> r) -> r
buildLoopNest = \env sn esh ebody k ->
withSome (occCountAll ebody) $ \occBody' ->
occEnvPop' occBody' $ \occBody ->
@@ -87,8 +87,8 @@ buildLoopNest = \env sn esh ebody k ->
Inner (BTop `bpush` elet idx (EUnit ext (weakenExpr (WCopy (WPop w)) ebody'))) (IZ `SCons` SNil)
where
nestMapN :: SNat n -> STy t -> SList (Ex args) (Replicate n TIx)
- -> (forall args'. args :> args' -> Ex args' (Tup (Replicate n TIx)) -> LoopNest args' '[TArr Z t])
- -> LoopNest args '[TArr n t]
+ -> (forall args'. args :> args' -> Ex args' (Tup (Replicate n TIx)) -> LoopNest args' '[TArr Z t] '[])
+ -> LoopNest args '[] '[TArr n t]
nestMapN SZ _ SNil inner = inner WId (ENil ext)
nestMapN (SS sn) ty (wid `SCons` sh) inner =
Layer (BTop `bpush` wid)
diff --git a/src/CHAD/Fusion/AST.hs b/src/CHAD/Fusion/AST.hs
index 3cd188a..a84e575 100644
--- a/src/CHAD/Fusion/AST.hs
+++ b/src/CHAD/Fusion/AST.hs
@@ -26,8 +26,8 @@ data Node env t where
FFreeVar :: STy t -> Idx env t -> Node env t
FLoop :: SList NodeId args
-> SList STy outs
- -> LoopNest args outs
- -> Tuple (Idx outs) t
+ -> LoopNest args outs bouts
+ -> Tuple (Idx (Append outs bouts)) t
-> Node env t
data NodeId t = NodeId Natural (STy t)
@@ -39,19 +39,21 @@ data Tuple f t where
TupSingle :: f t -> Tuple f t
deriving instance (forall a. Show (f a)) => Show (Tuple f t)
-data LoopNest args outs where
+-- bouts: "build outs", outputs that were marked as build-style (elementwise)
+-- above and cannot be handled differently any more
+data LoopNest args outs bouts where
Inner :: Bindings Ex args bs
-> SList (Idx (Append bs args)) outs
- -> LoopNest args outs
+ -> LoopNest args outs '[]
-- this should be able to express a simple nesting of builds and sums.
Layer :: Bindings Ex args bs1
-> Idx bs1 TIx -- ^ loop width (number of (parallel) iterations)
- -> LoopNest (TIx : Append bs1 args) loopouts
- -> Partition BuildUp RedSum loopouts mapouts sumouts
+ -> LoopNest (TIx : Append bs1 args) loopouts bouts
+ -> Partition BuildUp RedSum loopouts newbouts sumouts
-> Bindings Ex (Append sumouts (Append bs1 args)) bs2
-> SList (Idx (Append bs2 (Append bs1 args))) outs
- -> LoopNest args (Append outs mapouts)
-deriving instance Show (LoopNest args outs)
+ -> LoopNest args outs (Append newbouts bouts)
+deriving instance Show (LoopNest args outs bouts)
type Partition :: (Ty -> Ty -> Type) -> (Ty -> Ty -> Type) -> [Ty] -> [Ty] -> [Ty] -> Type
data Partition f1 f2 ts ts1 ts2 where