-- Source: src/CHAD/Fusion.hs (blob 29c1f12147644b5c5ba98612601f05276da5f16e)
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module CHAD.Fusion where

-- import Data.Dependent.Map qualified as DMap
import Data.Some

import CHAD.AST
import CHAD.AST.Bindings
import CHAD.AST.Count
import CHAD.AST.Env
import CHAD.Data
import CHAD.Fusion.AST
import CHAD.Lemmas


-- TODO:
-- A bunch of data types are defined here (in CHAD.Fusion.AST) that should be
-- able to express a graph of loop nests. A graph is a straight-line program
-- whose statements are, in this case, loop nests. A loop nest corresponds to
-- what fusion normally calls a "cluster", but is here represented as, well, a
-- loop nest.
--
-- No unzipping is done here, as I don't think it is necessary: I haven't been
-- able to think of programs that get more fusion opportunities when unzipped
-- than when zipped. If any such programs exist, I in any case conjecture that
-- with a pre-pass that splits array operations that can be unzipped already at
-- the source-level (e.g. build n (\i -> (E1, E2)) -> zip (build n (\i -> E1),
-- build n (\i -> E2))), all such fusion opportunities can be recovered. If
-- this conjecture is false, some reorganisation may be required.
--
-- Next steps, perhaps:
-- 1. Express a build operation as a LoopNest, not from the EBuild constructor
--    specifically but its fields. It will have a single output, and its args
--    will be its list of free variables. DONE
-- 1a. Write a pretty-printer for LoopNest lol
-- 2. Express a sum operation as a LoopNest in the same way; 1 arg, 1 out.
-- 3. Write a "recognition" pass that eagerly constructs graphs for subterms of
--    a large expression that contain only "simple" AST constructors, and
--    replaces those subterms with an EExt containing that graph. In this
--    construction process, EBuild and ESum1Inner should be replaced with
--    FLoop.
-- 4. Implement fusion somehow on graphs!
-- 5. Add an AST constructor for a loop nest (which most of the modules throw
--    an error on, except Count, Simplify and Compile -- good luck with Count),
--    and compile that to an actual C loop nest.
-- 6. Extend to other cool operations like EFold1InnerD1

-- :m *CHAD.Fusion CHAD.AST.Pretty CHAD.Language
-- pprintExpr $ fromNamed $ body $ build (SS (SS SZ)) (pair (pair nil 3) 4) (#idx :-> snd_ #idx + snd_ (fst_ #idx))
-- putStrLn $ case fromNamed $ body $ build (SS (SS SZ)) (pair (pair nil 3) 4) (#idx :-> snd_ #idx + snd_ (fst_ #idx)) of EBuild _ n esh ebody -> let env = knownEnv in buildLoopNest env n esh ebody $ \sub nest -> show sub ++ "\n" ++ ppLoopNest (subList env sub) nest


-- | Discharge the head argument of a loop nest by binding it to an expression.
--
-- The incoming nest expects a value of type @t@ at the head of its argument
-- environment. The result takes only @args@: the given expression (typed in
-- @args@) is pushed as a single binding and concatenated with the nest's own
-- leading 'Bindings', which are typed in an environment that includes @t@.
-- The @outs@ and @bouts@ indices of the nest are left unchanged.
prependBinding :: forall args outs bouts t. Ex args t -> LoopNest (t : args) outs bouts -> LoopNest args outs bouts
prependBinding headExpr (Inner (binds :: Bindings Ex (t : args) bs) outIdxs) =
  -- The associativity lemma is needed so that the existing output indices
  -- are accepted against the environment shape produced by 'bconcat'.
  case lemAppendAssoc @bs @'[t] @args of
    Refl -> Inner (bconcat (bpush BTop headExpr) binds) outIdxs
prependBinding headExpr (Layer (pre :: Bindings Ex (t : args) bs1) width body part post outIdxs) =
  case lemAppendAssoc @bs1 @'[t] @args of
    Refl -> case lemAppendNil @bs1 of
      Refl ->
        Layer (bconcat (bpush BTop headExpr) pre)
              -- The loop-width index is passed through a weakening built by
              -- 'wCopies' over the existing bindings around a single 'WSink'
              -- for the newly bound @t@.
              (wCopies (bindingsBinds pre) (WSink @t @'[]) @> width)
              body part post outIdxs

-- | Express a build operation as a 'LoopNest', in continuation-passing style.
--
-- Arguments, in order: the types of the environment, the rank @n@ of the
-- array being built, the shape expression, the element expression (which has
-- the index tuple as its innermost variable), and a continuation.
--
-- The occurrence counts of the element expression (with the index variable
-- popped off) and of the shape expression are combined and handed to
-- 'deleteUnused'. The 'Subenv' it yields is used to re-type both expressions
-- and is passed to the continuation together with the nest, whose arguments
-- are exactly the selected part of @env@. The nest has no plain outputs and a
-- single build output of type @TArr n t@.
buildLoopNest :: SList STy env -> SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex (Tup (Replicate n TIx) : env) t
              -> (forall args. Subenv env args -> LoopNest args '[] '[TArr n t] -> r) -> r
buildLoopNest = \env sn esh ebody k ->
  withSome (occCountAll ebody) $ \occBody' ->
  occEnvPop' occBody' $ \occBody ->
  -- 'withSome' hands its continuation exactly one argument (the unwrapped
  -- value), as in its use two lines up; a second lambda parameter here would
  -- turn the result into a function and no longer match the result type @r@.
  withSome (occCountAll esh <> Some occBody) $ \occ ->
  deleteUnused env (Some occ) $ \deleteSub ->
  let esh' = unsafeWeakenWithSubenv deleteSub esh
      -- The body has one extra (index) variable in front, which is always kept.
      ebody' = unsafeWeakenWithSubenv (SEYesR deleteSub) ebody
  in k deleteSub $
       -- Bind the shape tuple first; the nest below refers to it as variable IZ
       -- and projects out its components with 'unTup'.
       prependBinding esh' $
         nestMapN sn (typeOf ebody)
                  (unTup (\e -> (EFst ext e, ESnd ext e)) (sreplicate sn tIx) (EVar ext (tTup (sreplicate sn tIx)) IZ)) $ \w idx ->
           -- Innermost statement: let-bind the assembled index tuple, evaluate
           -- the (weakened) body under it, and wrap the result with 'EUnit'.
           Inner (BTop `bpush` elet idx (EUnit ext (weakenExpr (WCopy (WPop w)) ebody'))) (IZ `SCons` SNil)
  where
    -- Build one 'Layer' per dimension. Each layer binds its width expression
    -- and uses that binding (index IZ) as the loop width; the remaining width
    -- expressions are weakened by two 'WSink's before recursing. The
    -- continuation @inner@ receives a weakening into the innermost
    -- environment and the index tuple assembled so far.
    --
    -- NOTE(review): in the 'SZ' equation the result is whatever @inner@
    -- returns, i.e. @LoopNest args '[TArr Z t] '[]@ by the signature below,
    -- whereas the declared result is @LoopNest args '[] '[TArr n t]@ with
    -- @n ~ Z@. Confirm against the definition of 'LoopNest' whether the
    -- rank-0 case is intended to type-check as written.
    nestMapN :: SNat n -> STy t -> SList (Ex args) (Replicate n TIx)
             -> (forall args'. args :> args' -> Ex args' (Tup (Replicate n TIx)) -> LoopNest args' '[TArr Z t] '[])
             -> LoopNest args '[] '[TArr n t]
    nestMapN SZ _ SNil inner = inner WId (ENil ext)
    nestMapN (SS sn) ty (wid `SCons` sh) inner =
      Layer (BTop `bpush` wid)
            IZ
            (nestMapN sn ty (slistMap (weakenExpr (WSink .> WSink)) sh) $ \w idx ->
                -- Extend the index tuple with this layer's loop variable,
                -- located by pushing index IZ through the weakening @w@.
                inner (w .> WSink .> WSink) (EPair ext idx (EVar ext tIx (w @> IZ))))
            (Part1 (BuildUp sn ty) PNil)
            BTop
            SNil