diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-11-22 22:41:09 +0100 |
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-11-22 22:41:09 +0100 |
| commit | 9b7c3eea7e34f5eb0d91f93b803e853028c2cec8 (patch) | |
| tree | 25b906bb49218d2743631d0c83e23717012e3b9b /src/CHAD/Fusion.hs | |
| parent | b4f07c673b7c710f5861bb84e67233c63336c53d (diff) | |
WIP: Think about fusionfusion
Diffstat (limited to 'src/CHAD/Fusion.hs')
| -rw-r--r-- | src/CHAD/Fusion.hs | 115 |
1 files changed, 115 insertions, 0 deletions
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
-- | Work-in-progress data types for representing a graph of loop nests, meant
-- as the representation on which fusion will operate. This module currently
-- contains only type definitions; see the TODO below for the design notes and
-- the planned next steps.
module CHAD.Fusion where

import Data.Dependent.Map (DMap)
-- import Data.Dependent.Map qualified as DMap
import Data.Functor.Const
import Data.Kind (Type)
import Numeric.Natural

import CHAD.AST
import CHAD.AST.Bindings
import CHAD.Data


-- TODO:
-- A bunch of data types are defined here that should be able to express a
-- graph of loop nests. A graph is a straight-line program whose statements
-- are, in this case, loop nests. A loop nest corresponds to what fusion
-- normally calls a "cluster", but is here represented as, well, a loop nest.
--
-- No unzipping is done here, as I don't think it is necessary: I haven't been
-- able to think of programs that get more fusion opportunities when unzipped
-- than when zipped. If any such programs exist, I in any case conjecture that
-- with a pre-pass that splits array operations that can be unzipped already at
-- the source-level (e.g. build n (\i -> (E1, E2)) -> zip (build n (\i -> E1),
-- build n (\i -> E2))), all such fusion opportunities can be recovered. If
-- this conjecture is false, some reorganisation may be required.
--
-- Next steps, perhaps:
-- 1. Express a build operation as a LoopNest, not from the EBuild constructor
--    specifically but its fields. It will have a single output, and its args
--    will be its list of free variables.
-- 2. Express a sum operation as a LoopNest in the same way; 1 arg, 1 out.
-- 3. Write a "recognition" pass that eagerly constructs graphs for subterms of
--    a large expression that contain only "simple" AST constructors, and
--    replaces those subterms with an EExt containing that graph. In this
--    construction process, EBuild and ESum1Inner should be replaced with
--    FLoop.
-- 4. Implement fusion somehow on graphs!
-- 5. Add an AST constructor for a loop nest (which most of the modules throw
--    an error on, except Count, Simplify and Compile -- good luck with Count),
--    and compile that to an actual C loop nest.
-- 6. Extend to other cool operations like EFold1InnerD1


-- | Expressions instantiated with 'FGraph' and with @Const ()@.
--
-- NOTE(review): presumably the first parameter of 'Expr' is the payload of the
-- @EExt@ constructor mentioned in the TODO above, and the second the
-- per-node annotation; confirm against "CHAD.AST".
type FEx = Expr FGraph (Const ())

-- | A graph of nodes, indexed by an annotation functor, an environment @env@
-- and a result type @t@.
type FGraph :: (Ty -> Type) -> [Ty] -> Ty -> Type
data FGraph x env t where
  -- | A map from node IDs to nodes over @env@, together with a 'Tuple' of node
  -- IDs of overall type @t@ (presumably the nodes that make up the graph's
  -- result; confirm). The constructor fixes the annotation parameter to
  -- @Const ()@.
  FGraph :: DMap NodeId (Node env) -> Tuple NodeId t -> FGraph (Const ()) env t

-- | A single node of an 'FGraph' over environment @env@, producing a value of
-- type @t@.
data Node env t where
  -- | A reference to a variable of the surrounding environment: the
  -- singleton for its type and its index in @env@.
  FFreeVar :: STy t -> Idx env t -> Node env t
  -- | A loop nest. The argument and output type lists @args@ and @outs@ are
  -- existentially quantified. The fields are: the node IDs supplying the
  -- nest's @args@, the singletons for the types in @outs@, the nest itself,
  -- and a 'Tuple' of indices into @outs@ that assembles the node's value of
  -- type @t@.
  FLoop :: SList NodeId args
        -> SList STy outs
        -> LoopNest args outs
        -> Tuple (Idx outs) t
        -> Node env t

-- | Identifier of a graph node: a natural number paired with the singleton
-- for the type @t@ of the node it names. This is the key type of the 'DMap'
-- in 'FGraph'.
--
-- NOTE(review): operating on that 'DMap' will presumably need @GEq@ /
-- @GCompare@ instances for 'NodeId'; none are defined in this module.
data NodeId t = NodeId Natural (STy t)
  deriving (Show)

-- | A tuple-shaped collection of @f@-values whose overall type is @t@.
--
-- Because 'TupSingle' is available at every type, a value of type
-- @Tuple f (TPair a b)@ can be either a 'TupPair' of two sub-tuples or a
-- single @f (TPair a b)@ (and similarly for 'TNil').
data Tuple f t where
  -- | The empty tuple.
  TupNil :: Tuple f TNil
  -- | A pair of two sub-tuples.
  TupPair :: Tuple f a -> Tuple f b -> Tuple f (TPair a b)
  -- | A leaf holding one @f@-value.
  TupSingle :: f t -> Tuple f t
-- The quantified constraint states that @f a@ is showable at every index @a@,
-- which is what the leaves need.
deriving instance (forall a. Show (f a)) => Show (Tuple f t)

-- | A loop nest taking inputs of types @args@ and producing outputs of types
-- @outs@.
data LoopNest args outs where
  -- | A nest without a loop: bindings @bs@ over @args@, followed by
  -- indices into @Append bs args@ that select the outputs.
  Inner :: Bindings Ex args bs
        -> SList (Idx (Append bs args)) outs
        -> LoopNest args outs
  -- this should be able to express a simple nesting of builds and sums.
  -- | One loop around an inner nest. Reading the fields in order:
  --
  --   1. bindings @bs1@ over @args@;
  --   2. the loop width, which is one of the @bs1@ bindings;
  --   3. the loop body: a nest whose arguments are a 'TIx' (presumably the
  --      loop counter; confirm) followed by @Append bs1 args@, with outputs
  --      @loopouts@;
  --   4. a classification of each body output as either 'BuildUp' (giving
  --      @mapouts@) or 'RedSum' (giving @sumouts@);
  --   5. bindings @bs2@ over @Append sumouts (Append bs1 args)@;
  --   6. indices selecting @outs@.
  --
  -- The outputs of the whole nest are @Append outs mapouts@.
  --
  -- NOTE(review): the @outs@ indices range over @Append bs2 args@, whereas
  -- @bs2@ itself is bound in @Append sumouts (Append bs1 args)@. Check whether
  -- @Append bs2 (Append sumouts (Append bs1 args))@ was intended here.
  Layer :: Bindings Ex args bs1
        -> Idx bs1 TIx  -- ^ loop width (number of (parallel) iterations)
        -> LoopNest (TIx : Append bs1 args) loopouts
        -> Partition BuildUp RedSum loopouts mapouts sumouts
        -> Bindings Ex (Append sumouts (Append bs1 args)) bs2
        -> SList (Idx (Append bs2 args)) outs
        -> LoopNest args (Append outs mapouts)

-- | A witness that splits the list @ts@, element by element, into two lists
-- @ts1@ and @ts2@. Each element @t@ is tagged with either an @f1 t t1@, which
-- sends it to @ts1@ as @t1@, or an @f2 t t2@, which sends it to @ts2@ as @t2@.
-- The relative order of elements is preserved within each output list.
type Partition :: (Ty -> Ty -> Type) -> (Ty -> Ty -> Type) -> [Ty] -> [Ty] -> [Ty] -> Type
data Partition f1 f2 ts ts1 ts2 where
  -- | The empty list splits into two empty lists.
  PNil :: Partition f1 f2 '[] '[] '[]
  -- | The head element goes to the first output list.
  Part1 :: f1 t t1 -> Partition f1 f2 ts ts1 ts2 -> Partition f1 f2 (t : ts) (t1 : ts1) ts2
  -- | The head element goes to the second output list.
  Part2 :: f2 t t2 -> Partition f1 f2 ts ts1 ts2 -> Partition f1 f2 (t : ts) ts1 (t2 : ts2)

-- | Relates a loop-body output of type @TArr n t@ to the type @TArr (S n) t@.
-- Used as the first tag of the 'Partition' in 'Layer'.
--
-- NOTE(review): presumably @n@ is the array rank, so this adds one outer
-- dimension for the enclosing loop; confirm against the definition of 'TArr'.
data BuildUp t t' where
  BuildUp :: SNat n -> STy t -> BuildUp (TArr n t) (TArr (S n) t)

-- | Relates a loop-body output of type @t@ to the same type @t@, carrying an
-- 'SMTy' witness for it. Used as the second tag of the 'Partition' in
-- 'Layer', i.e. for the outputs collected in @sumouts@.
data RedSum t t' where
  RedSum :: SMTy t -> RedSum t t

-- Unused sketch of an unzipping representation, kept commented out; the TODO
-- at the top of this module explains why no unzipping is done for now.

-- type family Unzip t where
--   Unzip (TPair a b) = TPair (Unzip a) (Unzip b)
--   Unzip (TArr n t) = UnzipA n t

-- type family UnzipA n t where
--   UnzipA n (TPair a b) = TPair (UnzipA n a) (UnzipA n b)
--   UnzipA n t = TArr n t

-- data Zipping ut t where
--   ZId :: Zipping t t
--   ZPair :: Zipping ua a -> Zipping ub b -> Zipping (TPair ua ub) (TPair a b)
--   ZZip :: Zipping ua (TArr n a) -> Zipping ub (TArr n b) -> Zipping (TPair ua ub) (TArr n (TPair a b))
-- deriving instance Show (Zipping ut t)
