diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-03-28 00:08:26 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-28 00:08:26 +0100 |
commit | a3ba3bdc5c2f9606a0b98cdf53183841cca07eac (patch) | |
tree | 1f20b4459727eaf78369cfa12226bd9fe4affeae /src/Compile.hs | |
parent | a5135c901d7fec098c5e105db1a03d63876508ff (diff) |
Compile: Commutative fold still broken, but sum is vectorised
Diffstat (limited to 'src/Compile.hs')
-rw-r--r-- | src/Compile.hs | 52 |
1 file changed, 24 insertions, 28 deletions
diff --git a/src/Compile.hs b/src/Compile.hs index 2c88a08..b4261ca 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -683,8 +683,8 @@ compile' env = \case EFold1Inner _ commut efun ex0 earr -> do let STArr (SS n) t = typeOf earr - let vecwid = case commut of Commut -> 8 :: Int - Noncommut -> 1 + -- let vecwid = case commut of Commut -> 8 :: Int + -- Noncommut -> 1 x0name <- compileAssign "foldx0" env ex0 arrname <- compileAssign "foldarr" env earr @@ -707,37 +707,25 @@ compile' env = \case ivar <- genName' "i" jvar <- genName' "j" - kvar <- if vecwid > 1 then genName' "k" else return "" + -- kvar <- if vecwid > 1 then genName' "k" else return "" accvar <- genName' "tot" let arreltlit = arrname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ - (if vecwid > 1 then show vecwid ++ " * " ++ jvar ++ " + " ++ kvar else jvar) ++ "]" + ({- if vecwid > 1 then show vecwid ++ " * " ++ jvar ++ " + " ++ kvar else -} jvar) ++ "]" (funres, funStmts) <- scope $ compile' (Const arreltlit `SCons` Const accvar `SCons` env) efun ((), arreltIncrStmts) <- scope $ incrementVarAlways "foldelt" Increment t arreltlit emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $ pure (SVarDecl False (repSTy t) accvar (CELit x0name)) <> x0incrStmts -- we're copying x0 here - <> (let body = - -- The combination function will consume the array element - -- and the accumulator. The accumulator is replaced by - -- what comes out of the function anyway, so that's - -- fine, but we do need to increment the array element. 
- arreltIncrStmts - <> funStmts - <> pure (SAsg accvar funres) - nchunks = CEBinop (CELit lenname) "/" (CELit (show vecwid)) - in if vecwid > 1 - then BList - [SLoop (repSTy tIx) jvar (CELit "0") nchunks $ - pure $ SLoop (repSTy tIx) kvar (CELit "0") (CELit (show vecwid)) $ - body - ,SBlock (BList - [SVarDecl True (repSTy tIx) jvar (CELit "0") - ,SLoop (repSTy tIx) kvar (CEBinop nchunks "*" (CELit (show vecwid))) (CELit lenname) $ - body])] - else pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $ - body) + <> (pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $ + -- The combination function will consume the array element + -- and the accumulator. The accumulator is replaced by + -- what comes out of the function anyway, so that's + -- fine, but we do need to increment the array element. + arreltIncrStmts + <> funStmts + <> pure (SAsg accvar funres)) <> pure (SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar)) incrementVarAlways "foldx0" Decrement t x0name @@ -762,15 +750,23 @@ compile' env = \case emit $ SVarDecl True (repSTy tIx) lenname (CELit (argname ++ ".buf->sh[" ++ show (fromSNat n) ++ "]")) + let vecwid = 8 :: Int ivar <- genName' "i" jvar <- genName' "j" + kvar <- genName' "k" accvar <- genName' "tot" + let nchunks = CEBinop (CELit lenname) "/" (CELit (show vecwid)) emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $ BList -- we have ScalIsNumeric, so it has 0 and (+) in C - [SVarDecl False (repSTy t) accvar (CELit "0") - ,SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $ - pure $ SVerbatim $ accvar ++ " += " ++ argname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ jvar ++ "];" - ,SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar)] + [SVerbatim $ repSTy t ++ " " ++ accvar ++ "[" ++ show vecwid ++ "] = {" ++ intercalate "," (replicate vecwid "0") ++ "};" + ,SLoop (repSTy tIx) jvar (CELit "0") nchunks $ + pure $ SLoop (repSTy tIx) kvar (CELit "0") (CELit (show vecwid)) $ + pure $ SVerbatim $ accvar 
++ "[" ++ kvar ++ "] += " ++ argname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ show vecwid ++ " * " ++ jvar ++ " + " ++ kvar ++ "];" + ,SLoop (repSTy tIx) kvar (CELit "1") (CELit (show vecwid)) $ + pure $ SVerbatim $ accvar ++ "[0] += " ++ accvar ++ "[" ++ kvar ++ "];" + ,SLoop (repSTy tIx) kvar (CEBinop nchunks "*" (CELit (show vecwid))) (CELit lenname) $ + pure $ SVerbatim $ accvar ++ "[0] += " ++ argname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ kvar ++ "];" + ,SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit (accvar++"[0]"))] incrementVarAlways "sum" Decrement (typeOf e) argname |