summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-03-27 23:49:24 +0100
committerTom Smeding <tom@tomsmeding.com>2025-03-27 23:49:24 +0100
commita5135c901d7fec098c5e105db1a03d63876508ff (patch)
treec08d71be92ae90ff9ada5ee93ca294a682ae8605
parentcc5ec97d19d998926669d0b86bef1fb4e3da3030 (diff)
Compile: vectorise commutative folds
-rw-r--r--src/Compile.hs37
1 files changed, 28 insertions, 9 deletions
diff --git a/src/Compile.hs b/src/Compile.hs
index 908b304..2c88a08 100644
--- a/src/Compile.hs
+++ b/src/Compile.hs
@@ -682,6 +682,10 @@ compile' env = \case
EFold1Inner _ commut efun ex0 earr -> do
let STArr (SS n) t = typeOf earr
+
+ let vecwid = case commut of Commut -> 8 :: Int
+ Noncommut -> 1
+
x0name <- compileAssign "foldx0" env ex0
arrname <- compileAssign "foldarr" env earr
@@ -703,22 +707,37 @@ compile' env = \case
ivar <- genName' "i"
jvar <- genName' "j"
+ kvar <- if vecwid > 1 then genName' "k" else return ""
+
accvar <- genName' "tot"
- let arreltlit = arrname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ jvar ++ "];"
+ let arreltlit = arrname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++
+ (if vecwid > 1 then show vecwid ++ " * " ++ jvar ++ " + " ++ kvar else jvar) ++ "]"
(funres, funStmts) <- scope $ compile' (Const arreltlit `SCons` Const accvar `SCons` env) efun
((), arreltIncrStmts) <- scope $ incrementVarAlways "foldelt" Increment t arreltlit
emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $
pure (SVarDecl False (repSTy t) accvar (CELit x0name))
<> x0incrStmts -- we're copying x0 here
- <> pure (SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $
- -- The combination function will consume the array element
- -- and the accumulator. The accumulator is replaced by
- -- what comes out of the function anyway, so that's
- -- fine, but we do need to increment the array element.
- arreltIncrStmts
- <> funStmts
- <> pure (SAsg accvar funres))
+ <> (let body =
+ -- The combination function will consume the array element
+ -- and the accumulator. The accumulator is replaced by
+ -- what comes out of the function anyway, so that's
+ -- fine, but we do need to increment the array element.
+ arreltIncrStmts
+ <> funStmts
+ <> pure (SAsg accvar funres)
+ nchunks = CEBinop (CELit lenname) "/" (CELit (show vecwid))
+ in if vecwid > 1
+ then BList
+ [SLoop (repSTy tIx) jvar (CELit "0") nchunks $
+ pure $ SLoop (repSTy tIx) kvar (CELit "0") (CELit (show vecwid)) $
+ body
+ ,SBlock (BList
+ [SVarDecl True (repSTy tIx) jvar (CELit "0")
+ ,SLoop (repSTy tIx) kvar (CEBinop nchunks "*" (CELit (show vecwid))) (CELit lenname) $
+ body])]
+ else pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $
+ body)
<> pure (SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar))
incrementVarAlways "foldx0" Decrement t x0name