aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-11-23 23:27:20 +0100
committerTom Smeding <tom@tomsmeding.com>2025-11-23 23:27:20 +0100
commitfd729684a0f15dab7d0b97df3da65718bd5fbba9 (patch)
tree33cfd64157273336bc5b32165650c5a5a0ddea85 /src/CHAD
parent9b7c3eea7e34f5eb0d91f93b803e853028c2cec8 (diff)
fusion: buildLoopNestfusion
Diffstat (limited to 'src/CHAD')
-rw-r--r--src/CHAD/Fusion.hs53
1 files changed, 51 insertions, 2 deletions
diff --git a/src/CHAD/Fusion.hs b/src/CHAD/Fusion.hs
index 757667f..3358d30 100644
--- a/src/CHAD/Fusion.hs
+++ b/src/CHAD/Fusion.hs
@@ -2,8 +2,11 @@
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE QuantifiedConstraints #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module CHAD.Fusion where
@@ -12,11 +15,15 @@ import Data.Dependent.Map (DMap)
-- import Data.Dependent.Map qualified as DMap
import Data.Functor.Const
import Data.Kind (Type)
+import Data.Some
import Numeric.Natural
import CHAD.AST
import CHAD.AST.Bindings
+import CHAD.AST.Count
+import CHAD.AST.Env
import CHAD.Data
+import CHAD.Lemmas
-- TODO:
@@ -36,7 +43,8 @@ import CHAD.Data
-- Next steps, perhaps:
-- 1. Express a build operation as a LoopNest, not from the EBuild constructor
-- specifically but its fields. It will have a single output, and its args
--- will be its list of free variables.
+-- will be its list of free variables. DONE
+-- 1a. Write a pretty-printer for LoopNest lol
-- 2. Express a sum operation as a LoopNest in the same way; 1 arg, 1 out.
-- 3. Write a "recognition" pass that eagerly constructs graphs for subterms of
-- a large expression that contain only "simple" AST constructors, and
@@ -83,20 +91,24 @@ data LoopNest args outs where
-> LoopNest (TIx : Append bs1 args) loopouts
-> Partition BuildUp RedSum loopouts mapouts sumouts
-> Bindings Ex (Append sumouts (Append bs1 args)) bs2
- -> SList (Idx (Append bs2 args)) outs
+ -> SList (Idx (Append bs2 (Append bs1 args))) outs
-> LoopNest args (Append outs mapouts)
+deriving instance Show (LoopNest args outs)
type Partition :: (Ty -> Ty -> Type) -> (Ty -> Ty -> Type) -> [Ty] -> [Ty] -> [Ty] -> Type
data Partition f1 f2 ts ts1 ts2 where
PNil :: Partition f1 f2 '[] '[] '[]
Part1 :: f1 t t1 -> Partition f1 f2 ts ts1 ts2 -> Partition f1 f2 (t : ts) (t1 : ts1) ts2
Part2 :: f2 t t2 -> Partition f1 f2 ts ts1 ts2 -> Partition f1 f2 (t : ts) ts1 (t2 : ts2)
+deriving instance (forall t t1. Show (f1 t t1), forall t t2. Show (f2 t t2)) => Show (Partition f1 f2 ts ts1 ts2)
data BuildUp t t' where
BuildUp :: SNat n -> STy t -> BuildUp (TArr n t) (TArr (S n) t)
+deriving instance Show (BuildUp t t')
data RedSum t t' where
RedSum :: SMTy t -> RedSum t t
+deriving instance Show (RedSum t t')
-- type family Unzip t where
-- Unzip (TPair a b) = TPair (Unzip a) (Unzip b)
@@ -112,4 +124,41 @@ data RedSum t t' where
-- ZZip :: Zipping ua (TArr n a) -> Zipping ub (TArr n b) -> Zipping (TPair ua ub) (TArr n (TPair a b))
-- deriving instance Show (Zipping ut t)
+prependBinding :: forall args outs t. Ex args t -> LoopNest (t : args) outs -> LoopNest args outs
+prependBinding e (Inner (bs :: Bindings Ex (t : args) bs) outs)
+ | Refl <- lemAppendAssoc @bs @'[t] @args
+ = Inner (bconcat (BTop `bpush` e) bs) outs
+prependBinding e (Layer (bs1 :: Bindings Ex (t : args) bs1) wid nest part bs2 outs)
+ | Refl <- lemAppendAssoc @bs1 @'[t] @args
+ , Refl <- lemAppendNil @bs1
+ = Layer (bconcat (BTop `bpush` e) bs1)
+ (wCopies (bindingsBinds bs1) (WSink @t @'[]) @> wid)
+ nest part bs2 outs
+buildLoopNest :: SList STy env -> SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex (Tup (Replicate n TIx) : env) t
+ -> (forall args. Subenv env args -> LoopNest args '[TArr n t] -> r) -> r
+buildLoopNest = \env sn esh ebody k ->
+ withSome (occCountAll ebody) $ \occBody' ->
+ occEnvPop' occBody' $ \occBody ->
+ withSome (occCountAll esh <> Some occBody) $ \occ _ ->
+ deleteUnused env (Some occ) $ \deleteSub ->
+ let esh' = unsafeWeakenWithSubenv deleteSub esh
+ ebody' = unsafeWeakenWithSubenv (SEYesR deleteSub) ebody
+ in k deleteSub $
+ prependBinding esh' $
+ nestMapN sn (typeOf ebody)
+ (unTup (\e -> (EFst ext e, ESnd ext e)) (sreplicate sn tIx) (EVar ext (tTup (sreplicate sn tIx)) IZ)) $ \w idx ->
+ Inner (BTop `bpush` elet idx (EUnit ext (weakenExpr (WCopy (WPop w)) ebody'))) (IZ `SCons` SNil)
+ where
+ nestMapN :: SNat n -> STy t -> SList (Ex args) (Replicate n TIx)
+ -> (forall args'. args :> args' -> Ex args' (Tup (Replicate n TIx)) -> LoopNest args' '[TArr Z t])
+ -> LoopNest args '[TArr n t]
+ nestMapN SZ _ SNil inner = inner WId (ENil ext)
+ nestMapN (SS sn) ty (wid `SCons` sh) inner =
+ Layer (BTop `bpush` wid)
+ IZ
+ (nestMapN sn ty (slistMap (weakenExpr (WSink .> WSink)) sh) $ \w idx ->
+ inner (w .> WSink .> WSink) (EPair ext idx (EVar ext tIx (w @> IZ))))
+ (Part1 (BuildUp sn ty) PNil)
+ BTop
+ SNil