1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
|
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module CHAD.Fusion where
import Data.Dependent.Map (DMap)
-- import Data.Dependent.Map qualified as DMap
import Data.Functor.Const
import Data.Kind (Type)
import Numeric.Natural
import CHAD.AST
import CHAD.AST.Bindings
import CHAD.Data
-- TODO:
-- A bunch of data types are defined here that should be able to express a
-- graph of loop nests. A graph is a straight-line program whose statements
-- are, in this case, loop nests. A loop nest corresponds to what fusion
-- normally calls a "cluster", but is here represented as, well, a loop nest.
--
-- No unzipping is done here, as I don't think it is necessary: I haven't been
-- able to think of programs that get more fusion opportunities when unzipped
-- than when zipped. Even if such programs exist, I conjecture that all of
-- those fusion opportunities can be recovered by a pre-pass that already
-- unzips array operations at the source level wherever possible (e.g.
-- build n (\i -> (E1, E2)) -> zip (build n (\i -> E1), build n (\i -> E2))).
-- If this conjecture is false, some reorganisation may be required.
--
-- Next steps, perhaps:
-- 1. Express a build operation as a LoopNest, constructed not from the EBuild
-- constructor specifically but from its fields. It will have a single output,
-- and its args will be its list of free variables.
-- 2. Express a sum operation as a LoopNest in the same way; 1 arg, 1 out.
-- 3. Write a "recognition" pass that eagerly constructs graphs for subterms of
-- a large expression that contain only "simple" AST constructors, and
-- replaces those subterms with an EExt containing that graph. In this
-- construction process, EBuild and ESum1Inner should be replaced with
-- FLoop.
-- 4. Implement fusion somehow on graphs!
-- 5. Add an AST constructor for a loop nest (on which most of the modules
-- throw an error, except Count, Simplify and Compile -- good luck with Count),
-- and compile that to an actual C loop nest.
-- 6. Extend to other cool operations like EFold1InnerD1
-- | 'Expr' instantiated with 'FGraph' as its first parameter and @Const ()@
-- as its second. Per the module's plan, fused graphs are meant to replace
-- subterms of an expression via an EExt node, so presumably the first
-- parameter is the extension payload type and the second a trivial
-- annotation -- NOTE(review): confirm against CHAD.AST.
type FEx = Expr FGraph (Const ())
-- | A straight-line graph of nodes (mostly loop nests) computing a value of
-- type @res@ in environment @env@. A graph can only be built with a
-- @Const ()@ first index.
type FGraph :: (Ty -> Type) -> [Ty] -> Ty -> Type
data FGraph ann env res where
  FGraph :: DMap NodeId (Node env)    -- ^ every node of the graph, keyed by its id
         -> Tuple NodeId res          -- ^ the nodes whose values form the graph's result
         -> FGraph (Const ()) env res
-- | A single node of an 'FGraph', producing a value of type @res@ in
-- environment @env@.
data Node env res where
  -- | Reads a variable of the surrounding environment, given its type and
  -- its index into @env@.
  FFreeVar :: STy t -> Idx env t -> Node env t
  -- | Runs a loop nest whose arguments are the values of other nodes.
  FLoop :: SList NodeId ins    -- ^ nodes supplying the loop nest's arguments
        -> SList STy outs      -- ^ types of the loop nest's outputs
        -> LoopNest ins outs   -- ^ the loop nest itself
        -> Tuple (Idx outs) t  -- ^ the node's value, assembled from outputs by index
        -> Node env t
-- | Identifier of a graph node: a number together with the type of the value
-- the node produces. NOTE(review): numbers are presumably unique within one
-- graph; nothing here enforces that.
data NodeId t where
  NodeId :: Natural  -- ^ numeric identifier
         -> STy t    -- ^ type of the node's value
         -> NodeId t
  deriving (Show)
-- | A binary tree of @g@-values whose shape mirrors the type @t@: empty,
-- a pair of subtrees, or a single leaf covering the whole of @t@.
data Tuple g t where
  -- | The empty tuple, at type 'TNil'.
  TupNil :: Tuple g TNil
  -- | Two subtrees, at the corresponding 'TPair' type.
  TupPair :: Tuple g l -> Tuple g r -> Tuple g (TPair l r)
  -- | One leaf holding a @g@ at exactly the type @t@.
  TupSingle :: g t -> Tuple g t

-- | Showable whenever @g@ is showable at every index.
deriving instance (forall u. Show (g u)) => Show (Tuple g t)
-- | A (possibly nested) loop taking arguments of types @ins@ and producing
-- outputs of types @outs@. This should be able to express a simple nesting
-- of builds and sums.
data LoopNest ins outs where
  -- | Innermost level: a sequence of bindings over the arguments, followed by
  -- a selection of outputs from the extended environment @Append binds ins@.
  Inner :: Bindings Ex ins binds
        -> SList (Idx (Append binds ins)) outs
        -> LoopNest ins outs
  -- | One loop level, consisting of:
  --
  -- 1. bindings @pre@ over the arguments, evaluated around the loop;
  -- 2. an index into @pre@ giving the loop width;
  -- 3. the loop body, whose environment has an extra 'TIx' in front of
  --    @pre@ and the arguments (presumably the iteration index -- confirm);
  -- 4. a split of the body's outputs into those collected with 'BuildUp'
  --    (@built@) and those combined with 'RedSum' (@summed@);
  -- 5. bindings @post@ over the summed results, @pre@ and the arguments;
  -- 6. the selected outputs, indexed into @Append post ins@.
  --
  -- The level yields the selected outputs followed by the built arrays.
  --
  -- NOTE(review): @post@ is scoped over @Append summed (Append pre ins)@, yet
  -- the selected outputs index into @Append post ins@, skipping @summed@ and
  -- @pre@. 'Inner', by contrast, indexes into its full post-binding
  -- environment. Confirm that this narrowing is intended.
  Layer :: Bindings Ex ins pre
        -> Idx pre TIx  -- ^ loop width (number of (parallel) iterations)
        -> LoopNest (TIx : Append pre ins) body
        -> Partition BuildUp RedSum body built summed
        -> Bindings Ex (Append summed (Append pre ins)) post
        -> SList (Idx (Append post ins)) outs
        -> LoopNest ins (Append outs built)
-- | An order-preserving split of the list @xs@ into two lists @ls@ and @rs@.
-- Each element goes to exactly one side, and is transformed by a witness of
-- @l@ (left) or @r@ (right) on the way.
type Partition :: (Ty -> Ty -> Type) -> (Ty -> Ty -> Type) -> [Ty] -> [Ty] -> [Ty] -> Type
data Partition l r xs ls rs where
  -- | Splitting the empty list yields two empty lists.
  PNil  :: Partition l r '[] '[] '[]
  -- | The head element goes to the left list.
  Part1 :: l x y -> Partition l r xs ls rs -> Partition l r (x : xs) (y : ls) rs
  -- | The head element goes to the right list.
  Part2 :: r x y -> Partition l r xs ls rs -> Partition l r (x : xs) ls (y : rs)
-- | Witness that a loop output of array type @TArr d e@ becomes an array with
-- one more dimension, @TArr (S d) e@, after the loop (presumably by stacking
-- the per-iteration results -- confirm).
data BuildUp i o where
  BuildUp :: SNat d -> STy e -> BuildUp (TArr d e) (TArr (S d) e)
-- | Witness that a loop output is reduced across iterations without changing
-- its type. It carries an 'SMTy' for that type (presumably the monoid used
-- for summation -- confirm).
data RedSum i o where
  RedSum :: SMTy s -> RedSum s s
-- type family Unzip t where
-- Unzip (TPair a b) = TPair (Unzip a) (Unzip b)
-- Unzip (TArr n t) = UnzipA n t
-- type family UnzipA n t where
-- UnzipA n (TPair a b) = TPair (UnzipA n a) (UnzipA n b)
-- UnzipA n t = TArr n t
-- data Zipping ut t where
-- ZId :: Zipping t t
-- ZPair :: Zipping ua a -> Zipping ub b -> Zipping (TPair ua ub) (TPair a b)
-- ZZip :: Zipping ua (TArr n a) -> Zipping ub (TArr n b) -> Zipping (TPair ua ub) (TArr n (TPair a b))
-- deriving instance Show (Zipping ut t)
|