1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
|
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module CHAD.Fusion where
import Data.Dependent.Map (DMap)
-- import Data.Dependent.Map qualified as DMap
import Data.Functor.Const
import Data.Kind (Type)
import Data.Some
import Numeric.Natural
import CHAD.AST
import CHAD.AST.Bindings
import CHAD.AST.Count
import CHAD.AST.Env
import CHAD.Data
import CHAD.Lemmas
-- TODO:
-- A bunch of data types are defined here that should be able to express a
-- graph of loop nests. A graph is a straight-line program whose statements
-- are, in this case, loop nests. A loop nest corresponds to what fusion
-- normally calls a "cluster", but is here represented as, well, a loop nest.
--
-- No unzipping is done here, as I don't think it is necessary: I haven't been
-- able to think of programs that get more fusion opportunities when unzipped
-- than when zipped. If any such programs exist, I in any case conjecture that
-- with a pre-pass that splits array operations that can be unzipped already at
-- the source-level (e.g. build n (\i -> (E1, E2)) -> zip (build n (\i -> E1),
-- build n (\i -> E2))), all such fusion opportunities can be recovered. If
-- this conjecture is false, some reorganisation may be required.
--
-- Next steps, perhaps:
-- 1. Express a build operation as a LoopNest, not from the EBuild constructor
-- specifically but its fields. It will have a single output, and its args
-- will be its list of free variables. DONE
-- 1a. Write a pretty-printer for LoopNest lol
-- 2. Express a sum operation as a LoopNest in the same way; 1 arg, 1 out.
-- 3. Write a "recognition" pass that eagerly constructs graphs for subterms of
-- a large expression that contain only "simple" AST constructors, and
-- replaces those subterms with an EExt containing that graph. In this
-- construction process, EBuild and ESum1Inner should be replaced with
-- FLoop.
-- 4. Implement fusion somehow on graphs!
-- 5. Add an AST constructor for a loop nest (which most of the modules throw
-- an error on, except Count, Simplify and Compile -- good luck with Count),
-- and compile that to an actual C loop nest.
-- 6. Extend to other cool operations like EFold1InnerD1
-- | 'Expr' instantiated with 'FGraph' and @'Const' ()@ as its two parameters.
-- Presumably 'FGraph' is what these expressions carry in their extension
-- nodes (the notes in this module talk about placing a graph inside an
-- @EExt@) -- TODO confirm against the definition of 'Expr'.
type FEx = Expr FGraph (Const ())
-- | A graph of loop nests over the environment @env@, producing a value of
-- type @ty@. The module notes describe a graph as a straight-line program
-- whose statements are loop nests.
--
-- The first index is pinned to @'Const' ()@ by the only constructor.
type FGraph :: (Ty -> Type) -> [Ty] -> Ty -> Type
data FGraph x env ty where
  FGraph
    :: DMap NodeId (Node env)  -- ^ the nodes of the graph, keyed by their identifier
    -> Tuple NodeId ty         -- ^ identifiers of the nodes that make up the result
    -> FGraph (Const ()) env ty
-- | A single node of an 'FGraph' over environment @env@, with value type @ty@.
type Node :: [Ty] -> Ty -> Type
data Node env ty where
  -- | A variable of the surrounding environment, together with its type.
  FFreeVar
    :: STy ty
    -> Idx env ty
    -> Node env ty
  -- | A loop nest whose arguments are other nodes.
  FLoop
    :: SList NodeId args     -- ^ the nodes supplying the nest's arguments
    -> SList STy outs        -- ^ types of the nest's outputs
    -> LoopNest args outs    -- ^ the loop nest itself
    -> Tuple (Idx outs) ty   -- ^ how the node's value is assembled from the outputs
    -> Node env ty
-- | Key of a node in an 'FGraph': a natural number paired with the singleton
-- for the type of the node it names.
type NodeId :: Ty -> Type
data NodeId ty where
  NodeId :: Natural -> STy ty -> NodeId ty

deriving instance Show (NodeId ty)
-- | A tuple-shaped tree of @f@-values. The type index follows the shape:
-- 'TNil' for the empty tuple, 'TPair' for a pair of sub-tuples, and an
-- arbitrary type for a single leaf.
type Tuple :: (Ty -> Type) -> Ty -> Type
data Tuple f ty where
  -- | The empty tuple.
  TupNil :: Tuple f TNil
  -- | A pair of two sub-tuples.
  TupPair :: Tuple f l -> Tuple f r -> Tuple f (TPair l r)
  -- | A leaf holding one @f@-value.
  TupSingle :: f ty -> Tuple f ty

deriving instance (forall x. Show (f x)) => Show (Tuple f ty)
-- | A loop nest taking arguments @args@ and producing outputs @outs@.
type LoopNest :: [Ty] -> [Ty] -> Type
data LoopNest args outs where
  -- | No further loop: a sequence of bindings over the arguments, followed by
  -- a selection of outputs from the bindings and arguments.
  Inner
    :: Bindings Ex args binds
    -> SList (Idx (Append binds args)) outs
    -> LoopNest args outs
  -- this should be able to express a simple nesting of builds and sums.
  -- | One loop around a nested 'LoopNest'. The nested nest sees an additional
  -- 'TIx' variable on top of the first bindings and the arguments (presumably
  -- the loop counter -- TODO confirm). Its outputs are split into those
  -- witnessed by 'BuildUp' (@mapouts@) and those witnessed by 'RedSum'
  -- (@sumouts@); only the latter are in scope of the second group of
  -- bindings. The @mapouts@ are appended after the explicitly selected
  -- outputs in the result.
  Layer
    :: Bindings Ex args pre
    -> Idx pre TIx  -- ^ loop width (number of (parallel) iterations)
    -> LoopNest (TIx : Append pre args) bodyouts
    -> Partition BuildUp RedSum bodyouts mapouts sumouts
    -> Bindings Ex (Append sumouts (Append pre args)) post
    -> SList (Idx (Append post (Append pre args))) outs
    -> LoopNest args (Append outs mapouts)

deriving instance Show (LoopNest args outs)
-- | A split of the list @xs@ into two lists @ls@ and @rs@. Every element @x@
-- of @xs@ goes to exactly one side, and is related to the corresponding
-- element on that side by a witness: @f x l@ for the first side, @g x r@ for
-- the second.
type Partition :: (Ty -> Ty -> Type) -> (Ty -> Ty -> Type) -> [Ty] -> [Ty] -> [Ty] -> Type
data Partition f g xs ls rs where
  -- | The empty list splits into two empty lists.
  PNil :: Partition f g '[] '[] '[]
  -- | The head element goes to the first side.
  Part1 :: f x l -> Partition f g xs ls rs -> Partition f g (x : xs) (l : ls) rs
  -- | The head element goes to the second side.
  Part2 :: g x r -> Partition f g xs ls rs -> Partition f g (x : xs) ls (r : rs)

deriving instance
  (forall x l. Show (f x l), forall x r. Show (g x r))
  => Show (Partition f g xs ls rs)
-- | Witness relating @TArr n elt@ to @TArr (S n) elt@: the same element type
-- with the first index of 'TArr' incremented by one. Carries the singletons
-- for @n@ and @elt@.
type BuildUp :: Ty -> Ty -> Type
data BuildUp inner outer where
  BuildUp :: SNat n -> STy elt -> BuildUp (TArr n elt) (TArr (S n) elt)

deriving instance Show (BuildUp inner outer)
-- | Witness relating a type to itself, carrying its 'SMTy' singleton. Going by
-- the name, this marks a loop output that is reduced by summation -- TODO
-- confirm.
type RedSum :: Ty -> Ty -> Type
data RedSum from to where
  RedSum :: SMTy ty -> RedSum ty ty

deriving instance Show (RedSum from to)
-- type family Unzip t where
-- Unzip (TPair a b) = TPair (Unzip a) (Unzip b)
-- Unzip (TArr n t) = UnzipA n t
-- type family UnzipA n t where
-- UnzipA n (TPair a b) = TPair (UnzipA n a) (UnzipA n b)
-- UnzipA n t = TArr n t
-- data Zipping ut t where
-- ZId :: Zipping t t
-- ZPair :: Zipping ua a -> Zipping ub b -> Zipping (TPair ua ub) (TPair a b)
-- ZZip :: Zipping ua (TArr n a) -> Zipping ub (TArr n b) -> Zipping (TPair ua ub) (TArr n (TPair a b))
-- deriving instance Show (Zipping ut t)
-- | Discharge the first argument of a loop nest by binding an expression for
-- it: the given expression is placed in front of the nest's first group of
-- bindings, so the resulting nest only needs the remaining arguments.
prependBinding :: forall args outs t. Ex args t -> LoopNest (t : args) outs -> LoopNest args outs
prependBinding e nest0 = case nest0 of
  Inner (binds :: Bindings Ex (t : args) bs) results
    -- Re-associate the environment so that the output indices still fit.
    | Refl <- lemAppendAssoc @bs @'[t] @args ->
        Inner (bconcat front binds) results
  Layer (pre :: Bindings Ex (t : args) bs1) width body split post results
    | Refl <- lemAppendAssoc @bs1 @'[t] @args
    , Refl <- lemAppendNil @bs1 ->
        Layer (bconcat front pre)
              -- The binding list grew by one entry at its far end, so the
              -- width index is weakened past the copied original bindings.
              (wCopies (bindingsBinds pre) (WSink @t @'[]) @> width)
              body split post results
  where
    -- The single new binding that supplies the discharged argument.
    front = BTop `bpush` e
-- | Express a build operation as a 'LoopNest', starting from the fields of a
-- build rather than from an AST constructor: the types of the environment,
-- the number of dimensions, the shape expression, and the element expression
-- (which has the index tuple as an extra variable on top of the environment).
--
-- The result is handed to the continuation together with a 'Subenv' that
-- selects the nest's arguments out of @env@; it is computed from the
-- occurrence counts of the shape and element expressions. The nest has a
-- single output of type @TArr n t@.
buildLoopNest :: SList STy env -> SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex (Tup (Replicate n TIx) : env) t
              -> (forall args. Subenv env args -> LoopNest args '[TArr n t] -> r) -> r
buildLoopNest env sn esh ebody k =
  withSome (occCountAll ebody) $ \bodyOccFull ->
  -- Drop the entry for the index variable so the counts range over @env@.
  occEnvPop' bodyOccFull $ \bodyOcc ->
  -- NOTE(review): the ignored second parameter is presumably supplied by the
  -- continuation of occEnvPop' -- confirm against its signature.
  withSome (occCountAll esh <> Some bodyOcc) $ \usedOcc _ ->
  deleteUnused env (Some usedOcc) $ \keep ->
    let shapeE = unsafeWeakenWithSubenv keep esh
        bodyE = unsafeWeakenWithSubenv (SEYesR keep) ebody
        ixTys = sreplicate sn tIx
    in k keep $
         -- The shape becomes the nest's first binding; the loop widths are
         -- its components, projected out of that variable.
         prependBinding shapeE $
           nestMapN sn (typeOf ebody)
                    (unTup (\e -> (EFst ext e, ESnd ext e)) ixTys (EVar ext (tTup ixTys) IZ)) $ \wk ixs ->
             Inner (BTop `bpush` elet ixs (EUnit ext (weakenExpr (WCopy (WPop wk)) bodyE)))
                   (IZ `SCons` SNil)
  where
    -- Wrap one 'Layer' per dimension around an innermost nest. The innermost
    -- nest is built by the callback, which receives a weakening into its
    -- environment and an expression for the tuple of loop indices collected
    -- so far.
    nestMapN :: SNat n -> STy t -> SList (Ex args) (Replicate n TIx)
             -> (forall args'. args :> args' -> Ex args' (Tup (Replicate n TIx)) -> LoopNest args' '[TArr Z t])
             -> LoopNest args '[TArr n t]
    nestMapN SZ _ SNil body = body WId (ENil ext)
    nestMapN (SS m) ty (extent `SCons` rest) body =
      Layer (BTop `bpush` extent)
            IZ
            -- Inside the layer two more variables are in scope (the 'TIx'
            -- variable of the layer and the width binding), hence the double
            -- sink.
            (nestMapN m ty (slistMap (weakenExpr (WSink .> WSink)) rest) $ \wk ixs ->
              body (wk .> WSink .> WSink) (EPair ext ixs (EVar ext tIx (wk @> IZ))))
            (Part1 (BuildUp m ty) PNil)
            BTop
            SNil
|