{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module CHAD.Fusion where

import Data.Dependent.Map (DMap)
-- import Data.Dependent.Map qualified as DMap
import Data.Functor.Const
import Data.Kind (Type)
import Numeric.Natural

import CHAD.AST
import CHAD.AST.Bindings
import CHAD.Data


-- TODO:
-- A bunch of data types are defined here that should be able to express a
-- graph of loop nests. A graph is a straight-line program whose statements
-- are, in this case, loop nests. A loop nest corresponds to what fusion
-- normally calls a "cluster", but is here represented as, well, a loop nest.
--
-- No unzipping is done here, as I don't think it is necessary: I haven't been
-- able to think of programs that get more fusion opportunities when unzipped
-- than when zipped. If any such programs exist, I in any case conjecture that
-- with a pre-pass that splits array operations that can be unzipped already at
-- the source-level (e.g. build n (\i -> (E1, E2)) -> zip (build n (\i -> E1),
-- build n (\i -> E2))), all such fusion opportunities can be recovered. If
-- this conjecture is false, some reorganisation may be required.
--
-- Next steps, perhaps:
-- 1. Express a build operation as a LoopNest, not from the EBuild constructor
--    specifically but its fields. It will have a single output, and its args
--    will be its list of free variables.
-- 2. Express a sum operation as a LoopNest in the same way; 1 arg, 1 out.
-- 3. Write a "recognition" pass that eagerly constructs graphs for subterms of
--    a large expression that contain only "simple" AST constructors, and
--    replaces those subterms with an EExt containing that graph. In this
--    construction process, EBuild and ESum1Inner should be replaced with
--    FLoop.
-- 4. Implement fusion somehow on graphs!
-- 5. Add an AST constructor for a loop nest (which most of the modules throw
--    an error on, except Count, Simplify and Compile -- good luck with Count),
--    and compile that to an actual C loop nest.
-- 6. Extend to other cool operations like EFold1InnerD1


-- | 'Expr' instantiated with 'FGraph' and @Const ()@ as its first two
-- parameters.
--
-- NOTE(review): presumably the first parameter of 'Expr' is the extension-node
-- type (cf. the @EExt@ mentioned in step 3 above) and the second the per-node
-- annotation; confirm against "CHAD.AST".
type FEx = Expr FGraph (Const ())

-- | A graph of loop nests over the environment @env@ with result type @t@.
--
-- The first type parameter is fixed to @Const ()@ by the only constructor.
type FGraph :: (Ty -> Type) -> [Ty] -> Ty -> Type
data FGraph x env t where
  -- | The nodes of the graph, keyed by 'NodeId', together with the tuple of
  -- node ids that make up the result of type @t@.
  FGraph :: DMap NodeId (Node env)
         -> Tuple NodeId t
         -> FGraph (Const ()) env t

-- | A single node of an 'FGraph' over environment @env@, producing a value of
-- type @t@.
data Node env t where
  -- | A reference to a variable of the surrounding environment, with its type.
  FFreeVar :: STy t -> Idx env t -> Node env t
  -- | A loop nest applied to other nodes of the graph.
  FLoop :: SList NodeId args    -- ^ nodes supplying the arguments of the nest
        -> SList STy outs       -- ^ types of the outputs of the nest
        -> LoopNest args outs   -- ^ the loop nest itself
        -> Tuple (Idx outs) t   -- ^ how the outputs are assembled into a @t@
        -> Node env t

-- | Identifier of a graph node: a natural number paired with the type of the
-- value that the node produces.
data NodeId t = NodeId Natural (STy t)
  deriving (Show)

-- | A tuple-shaped tree of @f@-values whose shape follows the type index:
-- 'TupNil' at @TNil@, 'TupPair' at @TPair a b@, and 'TupSingle' holding a
-- single @f t@ at any type @t@.
data Tuple f t where
  TupNil :: Tuple f TNil
  TupPair :: Tuple f a -> Tuple f b -> Tuple f (TPair a b)
  TupSingle :: f t -> Tuple f t
-- The quantified constraint is needed because the leaves are @f@ applied to
-- different indices, not just to @t@.
deriving instance (forall a. Show (f a)) => Show (Tuple f t)

-- | A loop nest taking inputs of types @args@ and producing outputs of types
-- @outs@.
data LoopNest args outs where
  -- | Loop-free body: a block of bindings over @args@, with each output
  -- selected from the bound variables or the arguments.
  Inner :: Bindings Ex args bs
        -> SList (Idx (Append bs args)) outs
        -> LoopNest args outs
  -- | One loop around a nested 'LoopNest'.
  --
  -- this should be able to express a simple nesting of builds and sums.
  --
  -- NOTE(review): the final @outs@ index into @Append bs2 args@, whereas @bs2@
  -- is bound in the scope @Append sumouts (Append bs1 args)@; confirm that
  -- leaving @sumouts@ and @bs1@ out of the output scope is intended.
  Layer :: Bindings Ex args bs1
           -- ^ bindings evaluated before the loop
        -> Idx bs1 TIx  -- ^ loop width (number of (parallel) iterations)
        -> LoopNest (TIx : Append bs1 args) loopouts
           -- ^ loop body; its scope has one extra @TIx@ on top of the
           -- pre-loop scope (presumably the loop counter -- TODO confirm)
        -> Partition BuildUp RedSum loopouts mapouts sumouts
           -- ^ splits the body outputs into those related by 'BuildUp' to
           -- @mapouts@ and those related by 'RedSum' to @sumouts@
        -> Bindings Ex (Append sumouts (Append bs1 args)) bs2
           -- ^ bindings after the loop, with @sumouts@ in scope
        -> SList (Idx (Append bs2 args)) outs
           -- ^ additional outputs selected after the loop
        -> LoopNest args (Append outs mapouts)

-- | Witness that the list @ts@ is split, in order, over the lists @ts1@ and
-- @ts2@: each element @t@ of @ts@ goes either to @ts1@ as some @t1@ with
-- evidence @f1 t t1@, or to @ts2@ as some @t2@ with evidence @f2 t t2@.
type Partition :: (Ty -> Ty -> Type) -> (Ty -> Ty -> Type) -> [Ty] -> [Ty] -> [Ty] -> Type
data Partition f1 f2 ts ts1 ts2 where
  -- | All three lists are empty.
  PNil :: Partition f1 f2 '[] '[] '[]
  -- | The head of @ts@ goes to the first output list.
  Part1 :: f1 t t1 -> Partition f1 f2 ts ts1 ts2 -> Partition f1 f2 (t : ts) (t1 : ts1) ts2
  -- | The head of @ts@ goes to the second output list.
  Part2 :: f2 t t2 -> Partition f1 f2 ts ts1 ts2 -> Partition f1 f2 (t : ts) ts1 (t2 : ts2)

-- | Relates an array type @TArr n t@ to the array type @TArr (S n) t@ with the
-- same element type and one more dimension.
data BuildUp t t' where
  BuildUp :: SNat n -> STy t -> BuildUp (TArr n t) (TArr (S n) t)

-- | Relates a type to itself, carrying an 'SMTy' witness for it.
data RedSum t t' where
  RedSum :: SMTy t -> RedSum t t

-- type family Unzip t where
--   Unzip (TPair a b) = TPair (Unzip a) (Unzip b)
--   Unzip (TArr n t) = UnzipA n t

-- type family UnzipA n t where
--   UnzipA n (TPair a b) = TPair (UnzipA n a) (UnzipA n b)
--   UnzipA n t = TArr n t

-- data Zipping ut t where
--   ZId :: Zipping t t
--   ZPair :: Zipping ua a -> Zipping ub b -> Zipping (TPair ua ub) (TPair a b)
--   ZZip :: Zipping ua (TArr n a) -> Zipping ub (TArr n b) -> Zipping (TPair ua ub) (TArr n (TPair a b))
-- deriving instance Show (Zipping ut t)