summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-03-28 00:08:26 +0100
committerTom Smeding <tom@tomsmeding.com>2025-03-28 00:08:26 +0100
commita3ba3bdc5c2f9606a0b98cdf53183841cca07eac (patch)
tree1f20b4459727eaf78369cfa12226bd9fe4affeae
parenta5135c901d7fec098c5e105db1a03d63876508ff (diff)
Compile: Commutative fold still broken, but sum is vectorised
-rw-r--r--src/Compile.hs52
1 files changed, 24 insertions, 28 deletions
diff --git a/src/Compile.hs b/src/Compile.hs
index 2c88a08..b4261ca 100644
--- a/src/Compile.hs
+++ b/src/Compile.hs
@@ -683,8 +683,8 @@ compile' env = \case
EFold1Inner _ commut efun ex0 earr -> do
let STArr (SS n) t = typeOf earr
- let vecwid = case commut of Commut -> 8 :: Int
- Noncommut -> 1
+ -- let vecwid = case commut of Commut -> 8 :: Int
+ -- Noncommut -> 1
x0name <- compileAssign "foldx0" env ex0
arrname <- compileAssign "foldarr" env earr
@@ -707,37 +707,25 @@ compile' env = \case
ivar <- genName' "i"
jvar <- genName' "j"
- kvar <- if vecwid > 1 then genName' "k" else return ""
+ -- kvar <- if vecwid > 1 then genName' "k" else return ""
accvar <- genName' "tot"
let arreltlit = arrname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++
- (if vecwid > 1 then show vecwid ++ " * " ++ jvar ++ " + " ++ kvar else jvar) ++ "]"
+ ({- if vecwid > 1 then show vecwid ++ " * " ++ jvar ++ " + " ++ kvar else -} jvar) ++ "]"
(funres, funStmts) <- scope $ compile' (Const arreltlit `SCons` Const accvar `SCons` env) efun
((), arreltIncrStmts) <- scope $ incrementVarAlways "foldelt" Increment t arreltlit
emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $
pure (SVarDecl False (repSTy t) accvar (CELit x0name))
<> x0incrStmts -- we're copying x0 here
- <> (let body =
- -- The combination function will consume the array element
- -- and the accumulator. The accumulator is replaced by
- -- what comes out of the function anyway, so that's
- -- fine, but we do need to increment the array element.
- arreltIncrStmts
- <> funStmts
- <> pure (SAsg accvar funres)
- nchunks = CEBinop (CELit lenname) "/" (CELit (show vecwid))
- in if vecwid > 1
- then BList
- [SLoop (repSTy tIx) jvar (CELit "0") nchunks $
- pure $ SLoop (repSTy tIx) kvar (CELit "0") (CELit (show vecwid)) $
- body
- ,SBlock (BList
- [SVarDecl True (repSTy tIx) jvar (CELit "0")
- ,SLoop (repSTy tIx) kvar (CEBinop nchunks "*" (CELit (show vecwid))) (CELit lenname) $
- body])]
- else pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $
- body)
+ <> (pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $
+ -- The combination function will consume the array element
+ -- and the accumulator. The accumulator is replaced by
+ -- what comes out of the function anyway, so that's
+ -- fine, but we do need to increment the array element.
+ arreltIncrStmts
+ <> funStmts
+ <> pure (SAsg accvar funres))
<> pure (SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar))
incrementVarAlways "foldx0" Decrement t x0name
@@ -762,15 +750,23 @@ compile' env = \case
emit $ SVarDecl True (repSTy tIx) lenname
(CELit (argname ++ ".buf->sh[" ++ show (fromSNat n) ++ "]"))
+ let vecwid = 8 :: Int
ivar <- genName' "i"
jvar <- genName' "j"
+ kvar <- genName' "k"
accvar <- genName' "tot"
+ let nchunks = CEBinop (CELit lenname) "/" (CELit (show vecwid))
emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $ BList
-- we have ScalIsNumeric, so it has 0 and (+) in C
- [SVarDecl False (repSTy t) accvar (CELit "0")
- ,SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $
- pure $ SVerbatim $ accvar ++ " += " ++ argname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ jvar ++ "];"
- ,SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar)]
+ [SVerbatim $ repSTy t ++ " " ++ accvar ++ "[" ++ show vecwid ++ "] = {" ++ intercalate "," (replicate vecwid "0") ++ "};"
+ ,SLoop (repSTy tIx) jvar (CELit "0") nchunks $
+ pure $ SLoop (repSTy tIx) kvar (CELit "0") (CELit (show vecwid)) $
+ pure $ SVerbatim $ accvar ++ "[" ++ kvar ++ "] += " ++ argname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ show vecwid ++ " * " ++ jvar ++ " + " ++ kvar ++ "];"
+ ,SLoop (repSTy tIx) kvar (CELit "1") (CELit (show vecwid)) $
+ pure $ SVerbatim $ accvar ++ "[0] += " ++ accvar ++ "[" ++ kvar ++ "];"
+ ,SLoop (repSTy tIx) kvar (CEBinop nchunks "*" (CELit (show vecwid))) (CELit lenname) $
+ pure $ SVerbatim $ accvar ++ "[0] += " ++ argname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ kvar ++ "];"
+ ,SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit (accvar++"[0]"))]
incrementVarAlways "sum" Decrement (typeOf e) argname