aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-11-26 22:00:39 +0100
committerTom Smeding <tom@tomsmeding.com>2025-11-26 22:00:39 +0100
commitd5ea985f9d252af55ea0a5c3f00374a41b562369 (patch)
treea0663aac094b02c60935b6e651e4dd38fac99959
parentd74a7b212f06fbfad1b7f578cb127613acfb3311 (diff)
WIP stuff
-rw-r--r--chad-fast.cabal1
-rw-r--r--src/CHAD/AST/Pretty.hs96
-rw-r--r--src/CHAD/Fusion.hs82
-rw-r--r--src/CHAD/Fusion/AST.hs83
4 files changed, 180 insertions, 82 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal
index 834f1d7..6a9147a 100644
--- a/chad-fast.cabal
+++ b/chad-fast.cabal
@@ -45,6 +45,7 @@ library
CHAD.ForwardAD.DualNumbers
CHAD.ForwardAD.DualNumbers.Types
CHAD.Fusion
+ CHAD.Fusion.AST
CHAD.Interpreter
-- CHAD.Interpreter.AccumOld
CHAD.Interpreter.Rep
diff --git a/src/CHAD/AST/Pretty.hs b/src/CHAD/AST/Pretty.hs
index b763efe..a9a8987 100644
--- a/src/CHAD/AST/Pretty.hs
+++ b/src/CHAD/AST/Pretty.hs
@@ -5,10 +5,17 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TupleSections #-}
-module CHAD.AST.Pretty (pprintExpr, ppExpr, ppSTy, ppSMTy, PrettyX(..)) where
+module CHAD.AST.Pretty (
+ pprintExpr,
+ ppExpr,
+ ppLoopNest,
+ ppSTy, ppSMTy,
+ PrettyX(..),
+) where
import Control.Monad (ap)
import Data.List (intersperse, intercalate)
+import Data.Foldable (toList)
import Data.Functor.Const
import Data.Functor.Product qualified as Product
import Data.String (fromString)
@@ -22,10 +29,12 @@ import System.IO (stdout)
import System.IO.Unsafe (unsafePerformIO)
import CHAD.AST
+import CHAD.AST.Bindings
import CHAD.AST.Count
import CHAD.AST.Sparse.Types
import CHAD.Data
import CHAD.Drev.Types
+import CHAD.Fusion.AST
class PrettyX x where
@@ -78,7 +87,7 @@ pprintExpr = putStrLn . ppExpr knownEnv
ppExpr :: PrettyX x => SList STy env -> Expr NoExt x env t -> String
ppExpr senv e = render $ fst . flip runM 1 $ do
- val <- mkVal senv
+ val <- mkSVal senv
e' <- ppExpr' 0 val e
let lam = "λ" ++ intercalate " " (reverse (unSList (\(Product.Pair (Const name) ty) -> "(" ++ name ++ " : " ++ ppSTy 0 ty ++ ")") (slistZip val senv))) ++ "."
return $ group $ flatAlt
@@ -86,13 +95,13 @@ ppExpr senv e = render $ fst . flip runM 1 $ do
ppString lam
<> hardline <> e')
(ppString lam <+> e')
- where
- mkVal :: SList f env -> M (SVal env)
- mkVal SNil = return SNil
- mkVal (SCons _ v) = do
- val <- mkVal v
- name <- genName' "arg"
- return (Const name `SCons` val)
+
+mkSVal :: SList f env -> M (SVal env)
+mkSVal SNil = return SNil
+mkSVal (SCons _ v) = do
+ val <- mkSVal v
+ name <- genName' "arg"
+ return (Const name `SCons` val)
ppExpr' :: PrettyX x => Int -> SVal env -> Expr NoExt x env t -> M ADoc
ppExpr' d val expr = case expr of
@@ -403,6 +412,75 @@ ppLam :: [ADoc] -> ADoc -> ADoc
ppLam args body = ppString "(" <> hang 2 (ppString "\\" <> sep (args ++ [ppString "->"])
<> softline <> body <> ppString ")")
+ppLoopNest :: SList f args -> LoopNest args outs -> String
+ppLoopNest senv lnest = render $ fst . flip runM 1 $ do
+ val <- mkSVal senv
+ ppLoopNest' val lnest
+
+data RedKind = RKRet | RKBuild | RKSum
+
+ppLoopNest' :: SVal args -> LoopNest args outs -> M ADoc
+ppLoopNest' = \env lnest -> do
+ (f, outs) <- go env lnest
+ return (f (slistMap (\(Const _) -> Const RKRet) outs))
+ where
+ go :: SVal args -> LoopNest args outs -> M (SList (Const RedKind) outs -> ADoc, SVal outs)
+ go env (Inner bs outs) = do
+ (bs', names) <- goBindings env bs
+ let outNames = slistMap (\i -> slistIdx (sappend names env) i) outs
+ outDoc kinds = ppString "["
+ <> mconcat (map ppString (intersperse ", " (zipWith decorate kinds (unSList getConst outNames))))
+ <> ppString "]"
+ return (\kinds ->
+ vcat (toList bs')
+ <> hardline <> (annotate AKey (ppString "ret") <+> outDoc (unSList getConst kinds))
+ ,outNames)
+ go env (Layer bs1 wid lnest part bs2 outs) = do
+ (bs1', names1) <- goBindings env bs1
+ widname <- genName' "i"
+ (f, loopouts) <- go (Const widname `SCons` sappend names1 env) lnest
+ let (redkinds, mapouts, sumouts) = partition part loopouts
+ let lnest' = f redkinds
+ (bs2', names2) <- goBindings (sappend sumouts (sappend names1 env)) bs2
+ let outNames = slistMap (\i -> slistIdx (sappend names2 (sappend names1 env)) i) outs
+ outDoc kinds =
+ [annotate AKey (ppString "ret") <+>
+ ppString "["
+ <> mconcat (map ppString (intersperse ", " (unSList _ (slistZip kinds outNames))))
+ -- <> ppString "] ++ ["
+ -- <> mconcat (map ppString (intersperse ", " (unSList getConst mapouts)))
+ <> ppString "]"]
+ return (\kinds ->
+ vcat (toList bs1' ++
+ [hang 2 (annotate AKey (ppString "loop") <+> ppString widname <+> annotate AKey (ppString "to") <+> ppString (getConst (slistIdx names1 wid))
+ <> hardline <> lnest')] ++
+ toList bs2' ++
+ outDoc kinds)
+ ,sappend outNames mapouts)
+
+ decorate :: RedKind -> String -> String
+ decorate RKRet name = name
+ decorate RKBuild name = "#" ++ name
+ decorate RKSum name = "↓" ++ name
+
+ partition :: Partition BuildUp RedSum loopouts mapouts sumouts -> SVal loopouts -> (SList (Const RedKind) loopouts, SVal mapouts, SVal sumouts)
+ partition PNil SNil = (SNil, SNil, SNil)
+ partition (Part1 BuildUp{} part) (Const name `SCons` names) =
+ let (loopouts, mapouts, sumouts) = partition part names
+ in (Const RKBuild `SCons` loopouts, Const name `SCons` mapouts, sumouts)
+ partition (Part2 RedSum{} part) (Const name `SCons` names) =
+ let (loopouts, mapouts, sumouts) = partition part names
+ in (Const RKSum `SCons` loopouts, mapouts, Const name `SCons` sumouts)
+
+ goBindings :: SVal env -> Bindings Ex env bs -> M (Bag ADoc, SVal bs)
+ goBindings _ BTop = return (mempty, SNil)
+ goBindings env (bs `BPush` (_, e)) = do
+ (docs, names) <- goBindings env bs
+ name <- genName' "x"
+ e' <- ppExpr' 0 (sappend names env) e
+ let doc = ppString (name ++ " = ") <> e'
+ return (docs <> pure doc, Const name `SCons` names)
+
ppAcPrj :: SMTy a -> SAcPrj p a b -> String
ppAcPrj _ SAPHere = "."
ppAcPrj (SMTPair t _) (SAPFst prj) = "(" ++ ppAcPrj t prj ++ ",)"
diff --git a/src/CHAD/Fusion.hs b/src/CHAD/Fusion.hs
index 3358d30..f863944 100644
--- a/src/CHAD/Fusion.hs
+++ b/src/CHAD/Fusion.hs
@@ -11,26 +11,24 @@
{-# LANGUAGE TypeOperators #-}
module CHAD.Fusion where
-import Data.Dependent.Map (DMap)
-- import Data.Dependent.Map qualified as DMap
-import Data.Functor.Const
-import Data.Kind (Type)
import Data.Some
-import Numeric.Natural
import CHAD.AST
import CHAD.AST.Bindings
import CHAD.AST.Count
import CHAD.AST.Env
import CHAD.Data
+import CHAD.Fusion.AST
import CHAD.Lemmas
-- TODO:
--- A bunch of data types are defined here that should be able to express a
--- graph of loop nests. A graph is a straight-line program whose statements
--- are, in this case, loop nests. A loop nest corresponds to what fusion
--- normally calls a "cluster", but is here represented as, well, a loop nest.
+-- A bunch of data types are defined here (in CHAD.Fusion.AST) that should be
+-- able to express a graph of loop nests. A graph is a straight-line program
+-- whose statements are, in this case, loop nests. A loop nest corresponds to
+-- what fusion normally calls a "cluster", but is here represented as, well, a
+-- loop nest.
--
-- No unzipping is done here, as I don't think it is necessary: I haven't been
-- able to think of programs that get more fusion opportunities when unzipped
@@ -57,72 +55,10 @@ import CHAD.Lemmas
-- and compile that to an actual C loop nest.
-- 6. Extend to other cool operations like EFold1InnerD1
+-- :m *CHAD.Fusion CHAD.AST.Pretty CHAD.Language
+-- pprintExpr $ fromNamed $ body $ build (SS (SS SZ)) (pair (pair nil 3) 4) (#idx :-> snd_ #idx + snd_ (fst_ #idx))
+-- putStrLn $ case fromNamed $ body $ build (SS (SS SZ)) (pair (pair nil 3) 4) (#idx :-> snd_ #idx + snd_ (fst_ #idx)) of EBuild _ n esh ebody -> let env = knownEnv in buildLoopNest env n esh ebody $ \sub nest -> show sub ++ "\n" ++ ppLoopNest (subList env sub) nest
-type FEx = Expr FGraph (Const ())
-
-type FGraph :: (Ty -> Type) -> [Ty] -> Ty -> Type
-data FGraph x env t where
- FGraph :: DMap NodeId (Node env) -> Tuple NodeId t -> FGraph (Const ()) env t
-
-data Node env t where
- FFreeVar :: STy t -> Idx env t -> Node env t
- FLoop :: SList NodeId args
- -> SList STy outs
- -> LoopNest args outs
- -> Tuple (Idx outs) t
- -> Node env t
-
-data NodeId t = NodeId Natural (STy t)
- deriving (Show)
-
-data Tuple f t where
- TupNil :: Tuple f TNil
- TupPair :: Tuple f a -> Tuple f b -> Tuple f (TPair a b)
- TupSingle :: f t -> Tuple f t
-deriving instance (forall a. Show (f a)) => Show (Tuple f t)
-
-data LoopNest args outs where
- Inner :: Bindings Ex args bs
- -> SList (Idx (Append bs args)) outs
- -> LoopNest args outs
- -- this should be able to express a simple nesting of builds and sums.
- Layer :: Bindings Ex args bs1
- -> Idx bs1 TIx -- ^ loop width (number of (parallel) iterations)
- -> LoopNest (TIx : Append bs1 args) loopouts
- -> Partition BuildUp RedSum loopouts mapouts sumouts
- -> Bindings Ex (Append sumouts (Append bs1 args)) bs2
- -> SList (Idx (Append bs2 (Append bs1 args))) outs
- -> LoopNest args (Append outs mapouts)
-deriving instance Show (LoopNest args outs)
-
-type Partition :: (Ty -> Ty -> Type) -> (Ty -> Ty -> Type) -> [Ty] -> [Ty] -> [Ty] -> Type
-data Partition f1 f2 ts ts1 ts2 where
- PNil :: Partition f1 f2 '[] '[] '[]
- Part1 :: f1 t t1 -> Partition f1 f2 ts ts1 ts2 -> Partition f1 f2 (t : ts) (t1 : ts1) ts2
- Part2 :: f2 t t2 -> Partition f1 f2 ts ts1 ts2 -> Partition f1 f2 (t : ts) ts1 (t2 : ts2)
-deriving instance (forall t t1. Show (f1 t t1), forall t t2. Show (f2 t t2)) => Show (Partition f1 f2 ts ts1 ts2)
-
-data BuildUp t t' where
- BuildUp :: SNat n -> STy t -> BuildUp (TArr n t) (TArr (S n) t)
-deriving instance Show (BuildUp t t')
-
-data RedSum t t' where
- RedSum :: SMTy t -> RedSum t t
-deriving instance Show (RedSum t t')
-
--- type family Unzip t where
--- Unzip (TPair a b) = TPair (Unzip a) (Unzip b)
--- Unzip (TArr n t) = UnzipA n t
-
--- type family UnzipA n t where
--- UnzipA n (TPair a b) = TPair (UnzipA n a) (UnzipA n b)
--- UnzipA n t = TArr n t
-
--- data Zipping ut t where
--- ZId :: Zipping t t
--- ZPair :: Zipping ua a -> Zipping ub b -> Zipping (TPair ua ub) (TPair a b)
--- ZZip :: Zipping ua (TArr n a) -> Zipping ub (TArr n b) -> Zipping (TPair ua ub) (TArr n (TPair a b))
--- deriving instance Show (Zipping ut t)
prependBinding :: forall args outs t. Ex args t -> LoopNest (t : args) outs -> LoopNest args outs
prependBinding e (Inner (bs :: Bindings Ex (t : args) bs) outs)
diff --git a/src/CHAD/Fusion/AST.hs b/src/CHAD/Fusion/AST.hs
new file mode 100644
index 0000000..3cd188a
--- /dev/null
+++ b/src/CHAD/Fusion/AST.hs
@@ -0,0 +1,83 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE QuantifiedConstraints #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TypeOperators #-}
+module CHAD.Fusion.AST where
+
+import Data.Dependent.Map (DMap)
+import Data.Functor.Const
+import Data.Kind (Type)
+import Numeric.Natural
+
+import CHAD.AST
+import CHAD.AST.Bindings
+import CHAD.Data
+
+
+type FEx = Expr FGraph (Const ())
+
+type FGraph :: (Ty -> Type) -> [Ty] -> Ty -> Type
+data FGraph x env t where
+ FGraph :: DMap NodeId (Node env) -> Tuple NodeId t -> FGraph (Const ()) env t
+
+data Node env t where
+ FFreeVar :: STy t -> Idx env t -> Node env t
+ FLoop :: SList NodeId args
+ -> SList STy outs
+ -> LoopNest args outs
+ -> Tuple (Idx outs) t
+ -> Node env t
+
+data NodeId t = NodeId Natural (STy t)
+ deriving (Show)
+
+data Tuple f t where
+ TupNil :: Tuple f TNil
+ TupPair :: Tuple f a -> Tuple f b -> Tuple f (TPair a b)
+ TupSingle :: f t -> Tuple f t
+deriving instance (forall a. Show (f a)) => Show (Tuple f t)
+
+data LoopNest args outs where
+ Inner :: Bindings Ex args bs
+ -> SList (Idx (Append bs args)) outs
+ -> LoopNest args outs
+ -- this should be able to express a simple nesting of builds and sums.
+ Layer :: Bindings Ex args bs1
+ -> Idx bs1 TIx -- ^ loop width (number of (parallel) iterations)
+ -> LoopNest (TIx : Append bs1 args) loopouts
+ -> Partition BuildUp RedSum loopouts mapouts sumouts
+ -> Bindings Ex (Append sumouts (Append bs1 args)) bs2
+ -> SList (Idx (Append bs2 (Append bs1 args))) outs
+ -> LoopNest args (Append outs mapouts)
+deriving instance Show (LoopNest args outs)
+
+type Partition :: (Ty -> Ty -> Type) -> (Ty -> Ty -> Type) -> [Ty] -> [Ty] -> [Ty] -> Type
+data Partition f1 f2 ts ts1 ts2 where
+ PNil :: Partition f1 f2 '[] '[] '[]
+ Part1 :: f1 t t1 -> Partition f1 f2 ts ts1 ts2 -> Partition f1 f2 (t : ts) (t1 : ts1) ts2
+ Part2 :: f2 t t2 -> Partition f1 f2 ts ts1 ts2 -> Partition f1 f2 (t : ts) ts1 (t2 : ts2)
+deriving instance (forall t t1. Show (f1 t t1), forall t t2. Show (f2 t t2)) => Show (Partition f1 f2 ts ts1 ts2)
+
+data BuildUp t t' where
+ BuildUp :: SNat n -> STy t -> BuildUp (TArr n t) (TArr (S n) t)
+deriving instance Show (BuildUp t t')
+
+data RedSum t t' where
+ RedSum :: SMTy t -> RedSum t t
+deriving instance Show (RedSum t t')
+
+-- type family Unzip t where
+-- Unzip (TPair a b) = TPair (Unzip a) (Unzip b)
+-- Unzip (TArr n t) = UnzipA n t
+
+-- type family UnzipA n t where
+-- UnzipA n (TPair a b) = TPair (UnzipA n a) (UnzipA n b)
+-- UnzipA n t = TArr n t
+
+-- data Zipping ut t where
+-- ZId :: Zipping t t
+-- ZPair :: Zipping ua a -> Zipping ub b -> Zipping (TPair ua ub) (TPair a b)
+-- ZZip :: Zipping ua (TArr n a) -> Zipping ub (TArr n b) -> Zipping (TPair ua ub) (TArr n (TPair a b))
+-- deriving instance Show (Zipping ut t)