diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-10-30 20:09:10 +0100 |
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-10-30 20:09:10 +0100 |
| commit | 654b13d0de961788ed600e8eeb6c9fbbd736439e (patch) | |
| tree | 4821c1082aa2fcd04ea6ecb0e07f8d3bdbe643ed /src/Compile.hs | |
| parent | 4c9ae47dd5bbd27b1acb6dc5d4a55657ac1f026f (diff) | |
Diffstat (limited to 'src/Compile.hs')
| -rw-r--r-- | src/Compile.hs | 114 |
1 files changed, 113 insertions, 1 deletions
diff --git a/src/Compile.hs b/src/Compile.hs index 4e81c6a..f2063ee 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -34,6 +34,7 @@ import Foreign import GHC.Exts (int2Word#, addr2Int#) import GHC.Num (integerFromWord#) import GHC.Ptr (Ptr(..)) +import GHC.Stack (HasCallStack) import Numeric (showHex) import System.IO (hPutStrLn, stderr) import System.IO.Error (mkIOError, userErrorType) @@ -939,6 +940,117 @@ compile' env = \case [("buf", CEProj (CELit arrname) "buf") ,("sh", CELit ("{" ++ intercalate ", " [printCExpr 0 e "" | e <- indexTupleComponents dim shname] ++ "}"))]) + EFold1InnerD1 _ commut efun ex0 earr -> do + let STArr (SS n) t = typeOf earr + STPair _ bty = typeOf efun + + x0name <- compileAssign "foldd1x0" env ex0 + arrname <- compileAssign "foldd1arr" env earr + + zeroRefcountCheck (typeOf earr) "fold1iD1" arrname + + lenname <- genName' "n" + emit $ SVarDecl True (repSTy tIx) lenname + (CELit (arrname ++ ".sh[" ++ show (fromSNat n) ++ "]")) + + shsz1name <- genName' "shszN" + emit $ SVarDecl True (repSTy tIx) shsz1name (compileArrShapeSize n arrname) -- take init of arr's shape + shsz2name <- genName' "shszSN" + emit $ SVarDecl True (repSTy tIx) shsz2name (CEBinop (CELit shsz1name) "*" (CELit lenname)) + + resname <- allocArray "foldd1" Malloc "foldd1res" n t (Just (CELit shsz1name)) (compileArrShapeComponents n arrname) + storesname <- allocArray "foldd1" Malloc "foldd1stores" (SS n) bty (Just (CELit shsz2name)) (compileArrShapeComponents (SS n) arrname) + + ((), x0incrStmts) <- scope $ incrementVarAlways "foldd1x0" Increment t x0name + + ivar <- genName' "i" + jvar <- genName' "j" + + accvar <- genName' "tot" + let eltidx = lenname ++ " * " ++ ivar ++ " + " ++ jvar + arreltlit = arrname ++ ".buf->xs[" ++ eltidx ++ "]" + (funres, funStmts) <- scope $ compile' (Const arreltlit `SCons` Const accvar `SCons` env) efun + funresvar <- genName' "res" + ((), arreltIncrStmts) <- scope $ incrementVarAlways "foldd1elt" Increment t arreltlit + + emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsz1name) $ + pure (SVarDecl False (repSTy t) accvar (CELit x0name)) + <> x0incrStmts -- we're copying x0 here + <> (pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $ + -- The combination function will consume the array element + -- and the accumulator. The accumulator is replaced by + -- what comes out of the function anyway, so that's + -- fine, but we do need to increment the array element. + arreltIncrStmts + <> funStmts + <> pure (SVarDecl True (repSTy (typeOf efun)) funresvar funres) + <> pure (SAsg accvar (CEProj (CELit funresvar) "a")) + <> pure (SAsg (storesname ++ ".buf->xs[" ++ eltidx ++ "]") (CEProj (CELit funresvar) "b"))) + <> pure (SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar)) + + incrementVarAlways "foldd1x0" Decrement t x0name + incrementVarAlways "foldd1arr" Decrement (typeOf earr) arrname + + strname <- emitStruct (STPair (STArr n t) (STArr (SS n) bty)) + return (CEStruct strname [("a", CELit resname), ("b", CELit storesname)]) + + EFold1InnerD2 _ commut efun estores ectg -> do + let STArr n t2 = typeOf ectg + STArr _ bty = typeOf estores + + storesname <- compileAssign "foldd2stores" env estores + ctgname <- compileAssign "foldd2ctg" env ectg + + zeroRefcountCheck (typeOf ectg) "fold1iD2" ctgname + + lenname <- genName' "n" + emit $ SVarDecl True (repSTy tIx) lenname + (CELit (storesname ++ ".sh[" ++ show (fromSNat n) ++ "]")) + + shsz1name <- genName' "shszN" + emit $ SVarDecl True (repSTy tIx) shsz1name (compileArrShapeSize n storesname) -- take init of the shape + shsz2name <- genName' "shszSN" + emit $ SVarDecl True (repSTy tIx) shsz2name (CEBinop (CELit shsz1name) "*" (CELit lenname)) + + x0ctgname <- allocArray "foldd2" Malloc "foldd2x0ctg" n t2 (Just (CELit shsz1name)) (compileArrShapeComponents n storesname) + outctgname <- allocArray "foldd2" Malloc "foldd2outctg" (SS n) t2 (Just (CELit shsz2name)) (compileArrShapeComponents (SS n) storesname) + + ivar <- genName' "i" + jvar <- genName' "j" + + accvar <- genName' "acc" + let eltidx = lenname ++ " * " ++ ivar ++ " + " ++ lenname ++ "-1 - " ++ jvar + storeseltlit = storesname ++ ".buf->xs[" ++ eltidx ++ "]" + ctgeltlit = ctgname ++ ".buf->xs[" ++ ivar ++ "]" + (funres, funStmts) <- scope $ compile' (Const accvar `SCons` Const storeseltlit `SCons` env) efun + funresvar <- genName' "res" + ((), storeseltIncrStmts) <- scope $ incrementVarAlways "foldd2selt" Increment bty storeseltlit + ((), ctgeltIncrStmts) <- scope $ incrementVarAlways "foldd2celt" Increment bty ctgeltlit + + emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsz1name) $ + pure (SVarDecl False (repSTy t2) accvar (CELit ctgeltlit)) + <> ctgeltIncrStmts + -- we need to loop in reverse here, but we let jvar run in the + -- forward direction so that we can use SLoop. Note jvar is + -- reversed in eltidx above + <> (pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $ + -- The combination function will consume the accumulator + -- and the stores element. The accumulator is replaced by + -- what comes out of the function anyway, so that's + -- fine, but we do need to increment the stores element. + storeseltIncrStmts + <> funStmts + <> pure (SVarDecl True (repSTy (typeOf efun)) funresvar funres) + <> pure (SAsg accvar (CEProj (CELit funresvar) "a")) + <> pure (SAsg (outctgname ++ ".buf->xs[" ++ eltidx ++ "]") (CEProj (CELit funresvar) "b"))) + <> pure (SAsg (x0ctgname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar)) + + incrementVarAlways "foldd2stores" Decrement (STArr (SS n) bty) storesname + incrementVarAlways "foldd2ctg" Decrement (STArr n t2) ctgname + + strname <- emitStruct (STPair (STArr n t2) (STArr (SS n) t2)) + return (CEStruct strname [("a", CELit x0ctgname), ("b", CELit outctgname)]) + EConst _ t x -> return $ CELit $ compileScal True t x EIdx0 _ e -> do @@ -1311,7 +1423,7 @@ data AllocMethod = Malloc | Calloc deriving (Show) -- | The shape must have the outer dimension at the head (and the inner dimension on the right). -allocArray :: String -> AllocMethod -> String -> SNat n -> STy t -> Maybe CExpr -> [CExpr] -> CompM String +allocArray :: HasCallStack => String -> AllocMethod -> String -> SNat n -> STy t -> Maybe CExpr -> [CExpr] -> CompM String allocArray marker method nameBase rank eltty mshsz shape = do when (length shape /= fromSNat rank) $ error "allocArray: shape does not match rank" |
