diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-12-09 21:48:39 +0100 |
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-12-09 21:48:39 +0100 |
| commit | 8cdef3f10594a69037d45340029f8e795ecfee4a (patch) | |
| tree | b0edf7e7692c4da34bcbb4332607d1511975116f /src/CHAD/Fusion/AST.hs | |
| parent | d5ea985f9d252af55ea0a5c3f00374a41b562369 (diff) | |
WIP fusion stufffusion
Diffstat (limited to 'src/CHAD/Fusion/AST.hs')
| -rw-r--r-- | src/CHAD/Fusion/AST.hs | 18 |
1 files changed, 10 insertions, 8 deletions
diff --git a/src/CHAD/Fusion/AST.hs b/src/CHAD/Fusion/AST.hs index 3cd188a..a84e575 100644 --- a/src/CHAD/Fusion/AST.hs +++ b/src/CHAD/Fusion/AST.hs @@ -26,8 +26,8 @@ data Node env t where FFreeVar :: STy t -> Idx env t -> Node env t FLoop :: SList NodeId args -> SList STy outs - -> LoopNest args outs - -> Tuple (Idx outs) t + -> LoopNest args outs bouts + -> Tuple (Idx (Append outs bouts)) t -> Node env t data NodeId t = NodeId Natural (STy t) @@ -39,19 +39,21 @@ data Tuple f t where TupSingle :: f t -> Tuple f t deriving instance (forall a. Show (f a)) => Show (Tuple f t) -data LoopNest args outs where +-- bouts: "build outs", outputs that were marked as build-style (elementwise) +-- above and cannot be handled differently any more +data LoopNest args outs bouts where Inner :: Bindings Ex args bs -> SList (Idx (Append bs args)) outs - -> LoopNest args outs + -> LoopNest args outs '[] -- this should be able to express a simple nesting of builds and sums. Layer :: Bindings Ex args bs1 -> Idx bs1 TIx -- ^ loop width (number of (parallel) iterations) - -> LoopNest (TIx : Append bs1 args) loopouts - -> Partition BuildUp RedSum loopouts mapouts sumouts + -> LoopNest (TIx : Append bs1 args) loopouts bouts + -> Partition BuildUp RedSum loopouts newbouts sumouts -> Bindings Ex (Append sumouts (Append bs1 args)) bs2 -> SList (Idx (Append bs2 (Append bs1 args))) outs - -> LoopNest args (Append outs mapouts) -deriving instance Show (LoopNest args outs) + -> LoopNest args outs (Append newbouts bouts) +deriving instance Show (LoopNest args outs bouts) type Partition :: (Ty -> Ty -> Type) -> (Ty -> Ty -> Type) -> [Ty] -> [Ty] -> [Ty] -> Type data Partition f1 f2 ts ts1 ts2 where |
