diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-03-27 23:49:24 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-27 23:49:24 +0100 |
commit | a5135c901d7fec098c5e105db1a03d63876508ff (patch) | |
tree | c08d71be92ae90ff9ada5ee93ca294a682ae8605 | |
parent | cc5ec97d19d998926669d0b86bef1fb4e3da3030 (diff) |
Compile: vectorise commutative folds
-rw-r--r-- | src/Compile.hs | 37 |
1 file changed, 28 insertions, 9 deletions
diff --git a/src/Compile.hs b/src/Compile.hs index 908b304..2c88a08 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -682,6 +682,10 @@ compile' env = \case EFold1Inner _ commut efun ex0 earr -> do let STArr (SS n) t = typeOf earr + + let vecwid = case commut of Commut -> 8 :: Int + Noncommut -> 1 + x0name <- compileAssign "foldx0" env ex0 arrname <- compileAssign "foldarr" env earr @@ -703,22 +707,37 @@ compile' env = \case ivar <- genName' "i" jvar <- genName' "j" + kvar <- if vecwid > 1 then genName' "k" else return "" + accvar <- genName' "tot" - let arreltlit = arrname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ jvar ++ "];" + let arreltlit = arrname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ + (if vecwid > 1 then show vecwid ++ " * " ++ jvar ++ " + " ++ kvar else jvar) ++ "]" (funres, funStmts) <- scope $ compile' (Const arreltlit `SCons` Const accvar `SCons` env) efun ((), arreltIncrStmts) <- scope $ incrementVarAlways "foldelt" Increment t arreltlit emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $ pure (SVarDecl False (repSTy t) accvar (CELit x0name)) <> x0incrStmts -- we're copying x0 here - <> pure (SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $ - -- The combination function will consume the array element - -- and the accumulator. The accumulator is replaced by - -- what comes out of the function anyway, so that's - -- fine, but we do need to increment the array element. - arreltIncrStmts - <> funStmts - <> pure (SAsg accvar funres)) + <> (let body = + -- The combination function will consume the array element + -- and the accumulator. The accumulator is replaced by + -- what comes out of the function anyway, so that's + -- fine, but we do need to increment the array element. 
+ arreltIncrStmts + <> funStmts + <> pure (SAsg accvar funres) + nchunks = CEBinop (CELit lenname) "/" (CELit (show vecwid)) + in if vecwid > 1 + then BList + [SLoop (repSTy tIx) jvar (CELit "0") nchunks $ + pure $ SLoop (repSTy tIx) kvar (CELit "0") (CELit (show vecwid)) $ + body + ,SBlock (BList + [SVarDecl True (repSTy tIx) jvar (CELit "0") + ,SLoop (repSTy tIx) kvar (CEBinop nchunks "*" (CELit (show vecwid))) (CELit lenname) $ + body])] + else pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $ + body) <> pure (SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar)) incrementVarAlways "foldx0" Decrement t x0name |