From fd729684a0f15dab7d0b97df3da65718bd5fbba9 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Sun, 23 Nov 2025 23:27:20 +0100 Subject: fusion: buildLoopNest --- src/CHAD/Fusion.hs | 55 +++++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 52 insertions(+), 3 deletions(-) (limited to 'src/CHAD/Fusion.hs') diff --git a/src/CHAD/Fusion.hs b/src/CHAD/Fusion.hs index 757667f..3358d30 100644 --- a/src/CHAD/Fusion.hs +++ b/src/CHAD/Fusion.hs @@ -2,8 +2,11 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneKindSignatures #-} {-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} module CHAD.Fusion where @@ -12,11 +15,15 @@ import Data.Dependent.Map (DMap) -- import Data.Dependent.Map qualified as DMap import Data.Functor.Const import Data.Kind (Type) +import Data.Some import Numeric.Natural import CHAD.AST import CHAD.AST.Bindings +import CHAD.AST.Count +import CHAD.AST.Env import CHAD.Data +import CHAD.Lemmas -- TODO: @@ -36,7 +43,8 @@ import CHAD.Data -- Next steps, perhaps: -- 1. Express a build operation as a LoopNest, not from the EBuild constructor -- specifically but its fields. It will have a single output, and its args --- will be its list of free variables. +-- will be its list of free variables. DONE +-- 1a. Write a pretty-printer for LoopNest lol -- 2. Express a sum operation as a LoopNest in the same way; 1 arg, 1 out. -- 3. Write a "recognition" pass that eagerly constructs graphs for subterms of -- a large expression that contain only "simple" AST constructors, and @@ -83,20 +91,24 @@ data LoopNest args outs where -> LoopNest (TIx : Append bs1 args) loopouts -> Partition BuildUp RedSum loopouts mapouts sumouts -> Bindings Ex (Append sumouts (Append bs1 args)) bs2 - -> SList (Idx (Append bs2 args)) outs + -> SList (Idx (Append bs2 (Append bs1 args))) outs -> LoopNest args (Append outs mapouts) +deriving instance Show (LoopNest args outs) type Partition :: (Ty -> Ty -> Type) -> (Ty -> Ty -> Type) -> [Ty] -> [Ty] -> [Ty] -> Type data Partition f1 f2 ts ts1 ts2 where PNil :: Partition f1 f2 '[] '[] '[] Part1 :: f1 t t1 -> Partition f1 f2 ts ts1 ts2 -> Partition f1 f2 (t : ts) (t1 : ts1) ts2 Part2 :: f2 t t2 -> Partition f1 f2 ts ts1 ts2 -> Partition f1 f2 (t : ts) ts1 (t2 : ts2) +deriving instance (forall t t1. Show (f1 t t1), forall t t2. Show (f2 t t2)) => Show (Partition f1 f2 ts ts1 ts2) data BuildUp t t' where BuildUp :: SNat n -> STy t -> BuildUp (TArr n t) (TArr (S n) t) +deriving instance Show (BuildUp t t') data RedSum t t' where RedSum :: SMTy t -> RedSum t t +deriving instance Show (RedSum t t') -- type family Unzip t where -- Unzip (TPair a b) = TPair (Unzip a) (Unzip b) @@ -112,4 +124,41 @@ data RedSum t t' where -- ZZip :: Zipping ua (TArr n a) -> Zipping ub (TArr n b) -> Zipping (TPair ua ub) (TArr n (TPair a b)) -- deriving instance Show (Zipping ut t) - +prependBinding :: forall args outs t. Ex args t -> LoopNest (t : args) outs -> LoopNest args outs +prependBinding e (Inner (bs :: Bindings Ex (t : args) bs) outs) + | Refl <- lemAppendAssoc @bs @'[t] @args + = Inner (bconcat (BTop `bpush` e) bs) outs +prependBinding e (Layer (bs1 :: Bindings Ex (t : args) bs1) wid nest part bs2 outs) + | Refl <- lemAppendAssoc @bs1 @'[t] @args + , Refl <- lemAppendNil @bs1 + = Layer (bconcat (BTop `bpush` e) bs1) + (wCopies (bindingsBinds bs1) (WSink @t @'[]) @> wid) + nest part bs2 outs + +buildLoopNest :: SList STy env -> SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex (Tup (Replicate n TIx) : env) t + -> (forall args. Subenv env args -> LoopNest args '[TArr n t] -> r) -> r +buildLoopNest = \env sn esh ebody k -> + withSome (occCountAll ebody) $ \occBody' -> + occEnvPop' occBody' $ \occBody -> + withSome (occCountAll esh <> Some occBody) $ \occ _ -> + deleteUnused env (Some occ) $ \deleteSub -> + let esh' = unsafeWeakenWithSubenv deleteSub esh + ebody' = unsafeWeakenWithSubenv (SEYesR deleteSub) ebody + in k deleteSub $ + prependBinding esh' $ + nestMapN sn (typeOf ebody) + (unTup (\e -> (EFst ext e, ESnd ext e)) (sreplicate sn tIx) (EVar ext (tTup (sreplicate sn tIx)) IZ)) $ \w idx -> + Inner (BTop `bpush` elet idx (EUnit ext (weakenExpr (WCopy (WPop w)) ebody'))) (IZ `SCons` SNil) + where + nestMapN :: SNat n -> STy t -> SList (Ex args) (Replicate n TIx) + -> (forall args'. args :> args' -> Ex args' (Tup (Replicate n TIx)) -> LoopNest args' '[TArr Z t]) + -> LoopNest args '[TArr n t] + nestMapN SZ _ SNil inner = inner WId (ENil ext) + nestMapN (SS sn) ty (wid `SCons` sh) inner = + Layer (BTop `bpush` wid) + IZ + (nestMapN sn ty (slistMap (weakenExpr (WSink .> WSink)) sh) $ \w idx -> + inner (w .> WSink .> WSink) (EPair ext idx (EVar ext tIx (w @> IZ)))) + (Part1 (BuildUp sn ty) PNil) + BTop + SNil -- cgit v1.2.3-70-g09d2