aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD/Fusion.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD/Fusion.hs')
-rw-r--r--src/CHAD/Fusion.hs115
1 files changed, 115 insertions, 0 deletions
diff --git a/src/CHAD/Fusion.hs b/src/CHAD/Fusion.hs
new file mode 100644
index 0000000..757667f
--- /dev/null
+++ b/src/CHAD/Fusion.hs
@@ -0,0 +1,115 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE QuantifiedConstraints #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+module CHAD.Fusion where
+
+import Data.Dependent.Map (DMap)
+-- import Data.Dependent.Map qualified as DMap
+import Data.Functor.Const
+import Data.Kind (Type)
+import Numeric.Natural
+
+import CHAD.AST
+import CHAD.AST.Bindings
+import CHAD.Data
+
+
+-- TODO:
+-- A bunch of data types are defined here that should be able to express a
+-- graph of loop nests. A graph is a straight-line program whose statements
+-- are, in this case, loop nests. A loop nest corresponds to what fusion
+-- normally calls a "cluster", but is here represented as, well, a loop nest.
+--
+-- No unzipping is done here, as I don't think it is necessary: I haven't been
+-- able to think of programs that get more fusion opportunities when unzipped
+-- than when zipped. If any such programs exist, I in any case conjecture that
+-- with a pre-pass that splits array operations that can be unzipped already at
+-- the source level (e.g. build n (\i -> (E1, E2)) ~> zip (build n (\i -> E1),
+-- build n (\i -> E2))), all such fusion opportunities can be recovered. If
+-- this conjecture is false, some reorganisation may be required.
+--
+-- Next steps, perhaps:
+-- 1. Express a build operation as a LoopNest, not from the EBuild constructor
+-- specifically but its fields. It will have a single output, and its args
+-- will be its list of free variables.
+-- 2. Express a sum operation as a LoopNest in the same way; 1 arg, 1 out.
+-- 3. Write a "recognition" pass that eagerly constructs graphs for subterms of
+-- a large expression that contain only "simple" AST constructors, and
+-- replaces those subterms with an EExt containing that graph. In this
+-- construction process, EBuild and ESum1Inner should be replaced with
+-- FLoop.
+-- 4. Implement fusion somehow on graphs!
+-- 5. Add an AST constructor for a loop nest (which most of the modules throw
+-- an error on, except Count, Simplify and Compile -- good luck with Count),
+-- and compile that to an actual C loop nest.
+-- 6. Extend to other cool operations like EFold1InnerD1
+
+
+type FEx = Expr FGraph (Const ())  -- 'Expr' (from CHAD.AST) instantiated with 'FGraph' and 'Const ()'; presumably FGraph fills the EExt extension slot mentioned in the TODO above -- confirm against CHAD.AST
+
+type FGraph :: (Ty -> Type) -> [Ty] -> Ty -> Type  -- parameters: x (a Ty-indexed functor), env (list of types), t (a single type)
+data FGraph x env t where  -- a graph of nodes over environment env with value type t; the only constructor fixes x to Const ()
+ FGraph :: DMap NodeId (Node env) -> Tuple NodeId t -> FGraph (Const ()) env t  -- node table keyed by 'NodeId', plus a t-shaped tuple of node ids (presumably the graph's result -- confirm)
+
+data Node env t where  -- one node of an 'FGraph': a value of type t that lives in environment env
+ FFreeVar :: STy t -> Idx env t -> Node env t  -- a variable of env (given by its index) together with its type singleton
+ FLoop :: SList NodeId args  -- ids of the nodes passed as arguments to the loop nest
+ -> SList STy outs  -- type singletons for the outputs of the loop nest
+ -> LoopNest args outs  -- the loop nest itself, from args to outs
+ -> Tuple (Idx outs) t  -- t-shaped tuple of indices into outs (presumably selects how outputs form this node's value -- confirm)
+ -> Node env t
+
+data NodeId t = NodeId Natural (STy t)  -- typed node identifier: a natural number paired with the singleton for the node's type t; used as the key type of the DMap in 'FGraph'
+ deriving (Show)
+
+data Tuple f t where  -- a tree of f-values whose shape follows the TNil/TPair structure of the index type t
+ TupNil :: Tuple f TNil  -- empty tuple, at type TNil
+ TupPair :: Tuple f a -> Tuple f b -> Tuple f (TPair a b)  -- pair of two sub-tuples, at type TPair a b
+ TupSingle :: f t -> Tuple f t  -- leaf holding a single f-value; allowed at any t, so TNil/TPair types need not be split into TupNil/TupPair
+deriving instance (forall a. Show (f a)) => Show (Tuple f t)  -- the quantified constraint demands Show (f a) for every a, since leaves occur at types other than t
+
+data LoopNest args outs where  -- a loop nest taking inputs of types args and producing outputs of types outs
+ Inner :: Bindings Ex args bs  -- loop-free case: bindings over args that introduce bs
+ -> SList (Idx (Append bs args)) outs  -- each output is picked from the new bindings bs or from args
+ -> LoopNest args outs
+ -- this should be able to express a simple nesting of builds and sums.
+ Layer :: Bindings Ex args bs1  -- bindings over args that introduce bs1, placed before the nested loop nest
+ -> Idx bs1 TIx -- ^ loop width (number of (parallel) iterations)
+ -> LoopNest (TIx : Append bs1 args) loopouts  -- nested loop nest; sees one extra TIx (presumably the loop counter -- confirm) in front of bs1 and args
+ -> Partition BuildUp RedSum loopouts mapouts sumouts  -- splits loopouts into 'BuildUp' outputs (mapouts) and 'RedSum' outputs (sumouts)
+ -> Bindings Ex (Append sumouts (Append bs1 args)) bs2  -- further bindings that see sumouts, bs1 and args, and introduce bs2
+ -> SList (Idx (Append bs2 args)) outs  -- NOTE(review): indexes Append bs2 args, although bs2 is bound over Append sumouts (Append bs1 args) -- confirm that leaving out sumouts and bs1 here is intended
+ -> LoopNest args (Append outs mapouts)  -- the nest's outputs are outs followed by mapouts
+
+type Partition :: (Ty -> Ty -> Type) -> (Ty -> Ty -> Type) -> [Ty] -> [Ty] -> [Ty] -> Type  -- two binary relations on Ty, then the input list and the two output lists
+data Partition f1 f2 ts ts1 ts2 where  -- sends every element of ts, in order, to either ts1 (with an f1 witness) or ts2 (with an f2 witness)
+ PNil :: Partition f1 f2 '[] '[] '[]  -- the empty list partitions into two empty lists
+ Part1 :: f1 t t1 -> Partition f1 f2 ts ts1 ts2 -> Partition f1 f2 (t : ts) (t1 : ts1) ts2  -- head t goes to the first list as t1, related by f1 t t1
+ Part2 :: f2 t t2 -> Partition f1 f2 ts ts1 ts2 -> Partition f1 f2 (t : ts) ts1 (t2 : ts2)  -- head t goes to the second list as t2, related by f2 t t2
+
+data BuildUp t t' where  -- relation used as the f1 argument of 'Partition' in 'LoopNest'
+ BuildUp :: SNat n -> STy t -> BuildUp (TArr n t) (TArr (S n) t)  -- relates TArr n t to TArr (S n) t, i.e. the first index of TArr goes up by one (presumably the array rank -- confirm against CHAD.AST)
+
+data RedSum t t' where  -- relation used as the f2 argument of 'Partition' in 'LoopNest'
+ RedSum :: SMTy t -> RedSum t t  -- relates a type to itself, so the type is unchanged; carries an 'SMTy' singleton for t (presumably the evidence needed to sum values of t -- confirm)
+
+-- type family Unzip t where
+-- Unzip (TPair a b) = TPair (Unzip a) (Unzip b)
+-- Unzip (TArr n t) = UnzipA n t
+
+-- type family UnzipA n t where
+-- UnzipA n (TPair a b) = TPair (UnzipA n a) (UnzipA n b)
+-- UnzipA n t = TArr n t
+
+-- data Zipping ut t where
+-- ZId :: Zipping t t
+-- ZPair :: Zipping ua a -> Zipping ub b -> Zipping (TPair ua ub) (TPair a b)
+-- ZZip :: Zipping ua (TArr n a) -> Zipping ub (TArr n b) -> Zipping (TPair ua ub) (TArr n (TPair a b))
+-- deriving instance Show (Zipping ut t)
+
+