aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD/Fusion.hs
blob: 757667ff85d1d1dbdfbc01605a510360aa10a4f6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module CHAD.Fusion where

import Data.Dependent.Map (DMap)
-- import Data.Dependent.Map qualified as DMap
import Data.Functor.Const
import Data.Kind (Type)
import Numeric.Natural

import CHAD.AST
import CHAD.AST.Bindings
import CHAD.Data


-- TODO:
-- A bunch of data types are defined here that should be able to express a
-- graph of loop nests. A graph is a straight-line program whose statements
-- are, in this case, loop nests. A loop nest corresponds to what fusion
-- normally calls a "cluster", but is here represented as, well, a loop nest.
--
-- No unzipping is done here, as I don't think it is necessary: I haven't been
-- able to think of programs that get more fusion opportunities when unzipped
-- than when zipped. If any such programs exist, I in any case conjecture that
-- with a pre-pass that splits array operations that can be unzipped already at
-- the source-level (e.g. build n (\i -> (E1, E2)) -> zip (build n (\i -> E1),
-- build n (\i -> E2))), all such fusion opportunities can be recovered. If
-- this conjecture is false, some reorganisation may be required.
--
-- Next steps, perhaps:
-- 1. Express a build operation as a LoopNest, constructed not from the EBuild
--    constructor itself but from its fields. It will have a single output,
--    and its args will be its list of free variables.
-- 2. Express a sum operation as a LoopNest in the same way; 1 arg, 1 out.
-- 3. Write a "recognition" pass that eagerly constructs graphs for subterms of
--    a large expression that contain only "simple" AST constructors, and
--    replaces those subterms with an EExt containing that graph. In this
--    construction process, EBuild and ESum1Inner should be replaced with
--    FLoop.
-- 4. Implement fusion somehow on graphs!
-- 5. Add an AST constructor for a loop nest (which most of the modules throw
--    an error on, except Count, Simplify and Compile -- good luck with Count),
--    and compile that to an actual C loop nest.
-- 6. Extend to other cool operations like EFold1InnerD1


-- | Expressions that may embed an 'FGraph' through 'Expr''s extension
-- parameter (presumably via an @EExt@ constructor -- confirm in "CHAD.AST").
-- The extension's index is fixed to @'Const' ()@, which is the only index the
-- 'FGraph' constructor can produce.
type FEx = Expr FGraph (Const ())

-- | A graph of loop nests: a straight-line program whose statements are
-- 'Node's, each keyed by its 'NodeId'.
--
-- The 'Tuple' of node ids picks out the nodes whose values make up a result
-- of type @res@ (presumably the graph's overall result -- confirm against
-- consumers). The only constructor fixes the first index to @'Const' ()@.
type FGraph :: (Ty -> Type) -> [Ty] -> Ty -> Type
data FGraph ann env res where
  FGraph
    :: DMap NodeId (Node e)
    -> Tuple NodeId r
    -> FGraph (Const ()) e r

-- | One statement of an 'FGraph', producing a value of type @t@ while living
-- in an enclosing environment @env@.
data Node env t where
  -- | A free variable: the entry at the given index of the enclosing
  -- environment, together with its type.
  FFreeVar
    :: STy a
    -> Idx e a
    -> Node e a
  -- | A loop nest. There is one 'NodeId' per argument of the nest and one
  -- 'STy' per output of the nest. The 'Tuple' of indices into the outputs
  -- assembles the node's value from those outputs.
  FLoop
    :: SList NodeId ins
    -> SList STy results
    -> LoopNest ins results
    -> Tuple (Idx results) a
    -> Node e a

-- | Identifier of a node in an 'FGraph': a natural-number label paired with
-- the type of the value the node produces.
data NodeId t where
  NodeId :: Natural -> STy t -> NodeId t
  deriving (Show)

-- | A tree of @g@-values indexed by a type:
--
-- * 'TupNil' stands for 'TNil'.
-- * 'TupPair' stands for a 'TPair' of two sub-trees.
-- * 'TupSingle' is a leaf that may stand for a value of /any/ type, including
--   a pair or 'TNil' taken as a whole.
data Tuple g ty where
  TupNil
    :: Tuple g TNil
  TupPair
    :: Tuple g l
    -> Tuple g r
    -> Tuple g (TPair l r)
  TupSingle
    :: g leaf
    -> Tuple g leaf
deriving instance (forall x. Show (g x)) => Show (Tuple g ty)

-- | A loop nest evaluated in environment @args@ and producing values whose
-- types are listed in @outs@.
data LoopNest args outs where
  -- | Innermost nest: no loop, only bindings @bs@ over @args@. Each output is
  -- an index into the full environment @Append bs args@.
  Inner :: Bindings Ex args bs
        -> SList (Idx (Append bs args)) outs
        -> LoopNest args outs
  -- | One loop layer. This should be able to express a simple nesting of
  -- builds and sums.
  --
  -- The structure of a layer is as follows:
  --
  -- * Bindings @bs1@ are made over @args@.
  -- * The inner nest runs with an extra 'TIx' variable on top of
  --   @Append bs1 args@ (presumably the iteration index).
  -- * The inner nest's outputs are partitioned. 'BuildUp' outputs become the
  --   layer's @mapouts@. 'RedSum' outputs (@sumouts@) are put in scope for
  --   the trailing bindings @bs2@.
  -- * The layer's outputs are @outs@, selected from the environment after
  --   @bs2@, followed by @mapouts@.
  Layer :: Bindings Ex args bs1
        -> Idx bs1 TIx  -- ^ loop width (number of (parallel) iterations)
        -> LoopNest (TIx : Append bs1 args) loopouts
        -> Partition BuildUp RedSum loopouts mapouts sumouts
        -> Bindings Ex (Append sumouts (Append bs1 args)) bs2
        -- Outputs index the full environment in scope after bs2, mirroring
        -- 'Inner'. Indexing into @Append bs2 args@ instead would skip
        -- @sumouts@ and @bs1@, so positions past bs2 would not line up with
        -- the environment actually built by the bindings above.
        -> SList (Idx (Append bs2 (Append sumouts (Append bs1 args)))) outs
        -> LoopNest args (Append outs mapouts)

-- | @Partition p q xs ls rs@ splits the list @xs@ into @ls@ and @rs@,
-- preserving relative order. Each element @x@ either:
--
-- * goes left, becoming some @y@ in @ls@ as witnessed by @p x y@; or
-- * goes right, becoming some @y@ in @rs@ as witnessed by @q x y@.
type Partition :: (Ty -> Ty -> Type) -> (Ty -> Ty -> Type) -> [Ty] -> [Ty] -> [Ty] -> Type
data Partition p q xs ls rs where
  PNil
    :: Partition p q '[] '[] '[]
  Part1
    :: p x y
    -> Partition p q xs ls rs
    -> Partition p q (x : xs) (y : ls) rs
  Part2
    :: q x y
    -> Partition p q xs ls rs
    -> Partition p q (x : xs) ls (y : rs)

-- | Witness relating a per-iteration output @TArr k e@ to the type it has
-- after the loop, @TArr ('S' k) e@. This is presumably one more array
-- dimension, for the build/map outputs of a 'Layer'.
data BuildUp inner outer where
  BuildUp
    :: SNat k
    -> STy e
    -> BuildUp (TArr k e) (TArr (S k) e)

-- | Witness for a loop output that is reduced by summation; its type is the
-- same before and after the loop. The 'SMTy' presumably certifies that the
-- type supports that summation (confirm in "CHAD.Data"/"CHAD.AST").
data RedSum a b where
  RedSum
    :: SMTy a
    -> RedSum a a

-- type family Unzip t where
--   Unzip (TPair a b) = TPair (Unzip a) (Unzip b)
--   Unzip (TArr n t) = UnzipA n t

-- type family UnzipA n t where
--   UnzipA n (TPair a b) = TPair (UnzipA n a) (UnzipA n b)
--   UnzipA n t = TArr n t

-- data Zipping ut t where
--   ZId :: Zipping t t
--   ZPair :: Zipping ua a -> Zipping ub b -> Zipping (TPair ua ub) (TPair a b)
--   ZZip :: Zipping ua (TArr n a) -> Zipping ub (TArr n b) -> Zipping (TPair ua ub) (TArr n (TPair a b))
-- deriving instance Show (Zipping ut t)