{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module CHAD.Fusion where

-- import Data.Dependent.Map qualified as DMap
import Data.Some

import CHAD.AST
import CHAD.AST.Bindings
import CHAD.AST.Count
import CHAD.AST.Env
import CHAD.Data
import CHAD.Fusion.AST
import CHAD.Lemmas


-- TODO:
-- A bunch of data types are defined here (in CHAD.Fusion.AST) that should be
-- able to express a graph of loop nests. A graph is a straight-line program
-- whose statements are, in this case, loop nests. A loop nest corresponds to
-- what fusion normally calls a "cluster", but is here represented as, well, a
-- loop nest.
--
-- No unzipping is done here, as I don't think it is necessary: I haven't been
-- able to think of programs that get more fusion opportunities when unzipped
-- than when zipped. If any such programs exist, I in any case conjecture that
-- with a pre-pass that splits array operations that can be unzipped already at
-- the source-level (e.g. build n (\i -> (E1, E2)) -> zip (build n (\i -> E1),
-- build n (\i -> E2))), all such fusion opportunities can be recovered. If
-- this conjecture is false, some reorganisation may be required.
--
-- Next steps, perhaps:
-- 1. Express a build operation as a LoopNest, not from the EBuild constructor
--    specifically but its fields. It will have a single output, and its args
--    will be its list of free variables. DONE
-- 1a. Write a pretty-printer for LoopNest lol
-- 2. Express a sum operation as a LoopNest in the same way; 1 arg, 1 out.
-- 3. Write a "recognition" pass that eagerly constructs graphs for subterms of
--    a large expression that contain only "simple" AST constructors, and
--    replaces those subterms with an EExt containing that graph. In this
--    construction process, EBuild and ESum1Inner should be replaced with
--    FLoop.
-- 4. Implement fusion somehow on graphs!
-- 5. Add an AST constructor for a loop nest (which most of the modules throw
--    an error on, except Count, Simplify and Compile -- good luck with Count),
--    and compile that to an actual C loop nest.
-- 6. Extend to other cool operations like EFold1InnerD1

-- :m *CHAD.Fusion CHAD.AST.Pretty CHAD.Language
-- pprintExpr $ fromNamed $ body $ build (SS (SS SZ)) (pair (pair nil 3) 4) (#idx :-> snd_ #idx + snd_ (fst_ #idx))
-- putStrLn $ case fromNamed $ body $ build (SS (SS SZ)) (pair (pair nil 3) 4) (#idx :-> snd_ #idx + snd_ (fst_ #idx)) of EBuild _ n esh ebody -> let env = knownEnv in buildLoopNest env n esh ebody $ \sub nest -> show sub ++ "\n" ++ ppLoopNest (subList env sub) nest

-- | Discharge the first argument of a loop nest by binding it to an
-- expression.
--
-- The given expression is pushed as a single binding and concatenated in
-- front of the leading 'Bindings' of the nest, so the result only takes the
-- remaining arguments. Output lists are passed through untouched.
prependBinding :: forall args outs bouts t. Ex args t -> LoopNest (t : args) outs bouts -> LoopNest args outs bouts
prependBinding rhs loopNest = case loopNest of
  Inner (binds :: Bindings Ex (t : args) binds) results
    -- Reassociate the appended binder list so that the concatenated
    -- bindings have the type that 'Inner' expects.
    | Refl <- lemAppendAssoc @binds @'[t] @args ->
        Inner (bconcat newBinding binds) results
  Layer (pre :: Bindings Ex (t : args) pre) widthIx inner part post results
    | Refl <- lemAppendAssoc @pre @'[t] @args
    , Refl <- lemAppendNil @pre ->
        -- The width index is re-targeted through a weakening that copies the
        -- binders of the leading bindings over a sink for the new binder.
        -- Presumably this is needed because the new binding now sits below
        -- them -- confirm against the definition of 'Layer'.
        let pastNewBinder = wCopies (bindingsBinds pre) (WSink @t @'[])
        in Layer (bconcat newBinding pre) (pastNewBinder @> widthIx) inner part post results
  where
    -- The binding group consisting of just the new expression.
    newBinding = bpush BTop rhs

-- | Express a build operation as a 'LoopNest', starting from the fields of
-- an @EBuild@ rather than from the constructor itself.
--
-- Arguments, in order: the types of the environment, the rank of the array
-- being built, the shape expression, and the element expression (which has
-- the index tuple as an additional variable in scope).
--
-- The result is delivered in continuation-passing style because the argument
-- list of the nest is only known after occurrence counting: the continuation
-- receives the sub-environment kept by 'deleteUnused' together with a nest
-- that takes exactly those variables as arguments.
buildLoopNest :: SList STy env -> SNat n
              -> Ex env (Tup (Replicate n TIx))
              -> Ex (Tup (Replicate n TIx) : env) t
              -> (forall args. Subenv env args -> LoopNest args '[] '[TArr n t] -> r) -> r
buildLoopNest env rank shapeExpr bodyExpr cont =
  -- Count occurrences in the body, then pop the entry of the index variable
  -- so that the count ranges over the same environment as the shape.
  withSome (occCountAll bodyExpr) $ \bodyOccWithIdx ->
  occEnvPop' bodyOccWithIdx $ \bodyOcc ->
  -- NOTE(review): this continuation takes a second, ignored argument; confirm
  -- that this matches the type of the 'withSome' that is in scope here.
  withSome (occCountAll shapeExpr <> Some bodyOcc) $ \totalOcc _ ->
  deleteUnused env (Some totalOcc) $ \usedSub ->
    let -- Both expressions, moved to the reduced environment. The body has
        -- one extra variable in scope, hence the extended sub-environment.
        shapeExpr' = unsafeWeakenWithSubenv usedSub shapeExpr
        bodyExpr' = unsafeWeakenWithSubenv (SEYesR usedSub) bodyExpr
        -- After 'prependBinding', the shape tuple is variable 0 of the nest;
        -- split it into one expression per dimension.
        indexTypes = sreplicate rank tIx
        shapeVar = EVar ext (tTup indexTypes) IZ
        extents = unTup (\e -> (EFst ext e, ESnd ext e)) indexTypes shapeVar
    in cont usedSub
         (prependBinding shapeExpr'
            (nestMapN rank (typeOf bodyExpr) extents
               (\toInner indexTuple ->
                  -- Innermost computation: bind the assembled index tuple and
                  -- evaluate the body, wrapped in 'EUnit'.
                  Inner (bpush BTop (elet indexTuple (EUnit ext (weakenExpr (WCopy (WPop toInner)) bodyExpr'))))
                        (SCons IZ SNil))))
  where
    -- Build one 'Layer' per dimension around the innermost nest. The
    -- continuation receives a weakening from the outer scope to the innermost
    -- one and an expression for the full index tuple.
    --
    -- NOTE(review): in the base case the result of the continuation is
    -- returned directly, but its index lists ('[TArr Z t] and '[]) look
    -- swapped relative to the declared result ('[] and '[TArr n t]); confirm
    -- this typechecks as intended.
    nestMapN :: SNat n -> STy t -> SList (Ex args) (Replicate n TIx)
             -> (forall args'. args :> args' -> Ex args' (Tup (Replicate n TIx)) -> LoopNest args' '[TArr Z t] '[])
             -> LoopNest args '[] '[TArr n t]
    nestMapN SZ _ SNil mkInner = mkInner WId (ENil ext)
    nestMapN (SS m) elemTy (extent `SCons` rest) mkInner =
      Layer (bpush BTop extent)
            IZ
            -- The remaining extents are weakened past two additional
            -- variables before recursing; the index of this dimension is then
            -- read from variable 0 of that extended scope (moved along the
            -- weakening supplied by the inner layers).
            (nestMapN m elemTy (slistMap (weakenExpr (WSink .> WSink)) rest)
               (\w outerIndex ->
                  mkInner (w .> WSink .> WSink)
                          (EPair ext outerIndex (EVar ext tIx (w @> IZ)))))
            (Part1 (BuildUp m elemTy) PNil)
            BTop
            SNil