diff options
| | | |
|---|---|---|
| author | Tom Smeding <tom@tomsmeding.com> | 2025-03-27 23:49:24 +0100 |
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-27 23:49:24 +0100 |
| commit | a5135c901d7fec098c5e105db1a03d63876508ff (patch) | |
| tree | c08d71be92ae90ff9ada5ee93ca294a682ae8605 | |
| parent | cc5ec97d19d998926669d0b86bef1fb4e3da3030 (diff) | |
Compile: vectorise commutative folds
| -rw-r--r-- | src/Compile.hs | 37 | 
1 file changed, 28 insertions, 9 deletions
diff --git a/src/Compile.hs b/src/Compile.hs index 908b304..2c88a08 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -682,6 +682,10 @@ compile' env = \case    EFold1Inner _ commut efun ex0 earr -> do      let STArr (SS n) t = typeOf earr + +    let vecwid = case commut of Commut -> 8 :: Int +                                Noncommut -> 1 +      x0name <- compileAssign "foldx0" env ex0      arrname <- compileAssign "foldarr" env earr @@ -703,22 +707,37 @@ compile' env = \case      ivar <- genName' "i"      jvar <- genName' "j" +    kvar <- if vecwid > 1 then genName' "k" else return "" +      accvar <- genName' "tot" -    let arreltlit = arrname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ jvar ++ "];" +    let arreltlit = arrname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ +                    (if vecwid > 1 then show vecwid ++ " * " ++ jvar ++ " + " ++ kvar else jvar) ++ "]"      (funres, funStmts) <- scope $ compile' (Const arreltlit `SCons` Const accvar `SCons` env) efun      ((), arreltIncrStmts) <- scope $ incrementVarAlways "foldelt" Increment t arreltlit      emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $               pure (SVarDecl False (repSTy t) accvar (CELit x0name))               <> x0incrStmts  -- we're copying x0 here -             <> pure (SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $ -                        -- The combination function will consume the array element -                        -- and the accumulator. The accumulator is replaced by -                        -- what comes out of the function anyway, so that's -                        -- fine, but we do need to increment the array element. -                        arreltIncrStmts -                        <> funStmts -                        <> pure (SAsg accvar funres)) +             <> (let body = +                       -- The combination function will consume the array element +                       -- and the accumulator. 
The accumulator is replaced by +                       -- what comes out of the function anyway, so that's +                       -- fine, but we do need to increment the array element. +                       arreltIncrStmts +                       <> funStmts +                       <> pure (SAsg accvar funres) +                     nchunks = CEBinop (CELit lenname) "/" (CELit (show vecwid)) +                 in if vecwid > 1 +                      then BList +                             [SLoop (repSTy tIx) jvar (CELit "0") nchunks $ +                                pure $ SLoop (repSTy tIx) kvar (CELit "0") (CELit (show vecwid)) $ +                                  body +                             ,SBlock (BList +                                [SVarDecl True (repSTy tIx) jvar (CELit "0") +                                ,SLoop (repSTy tIx) kvar (CEBinop nchunks "*" (CELit (show vecwid))) (CELit lenname) $ +                                   body])] +                      else pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $ +                             body)               <> pure (SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar))      incrementVarAlways "foldx0" Decrement t x0name  | 
