-rw-r--r--  cbits/arith.c                             62
-rw-r--r--  ops/Data/Array/Strided/Arith/Internal.hs  97
2 files changed, 105 insertions, 54 deletions
diff --git a/cbits/arith.c b/cbits/arith.c
index b574d54..3659f6c 100644
--- a/cbits/arith.c
+++ b/cbits/arith.c
@@ -326,7 +326,7 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
// Walk an orthotope-style strided array, except for the inner dimension. The
// body is run for every "inner vector".
// Provides idx, outlinidx, arrlinidx.
-#define TARRAY_WALK_NOINNER(again_label_name, rank, shape, strides, body) \
+#define TARRAY_WALK_NOINNER(again_label_name, rank, shape, strides, ...) \
do { \
i64 idx[(rank) /* - 1 */]; /* Note: [zero-length VLA] */ \
memset(idx, 0, ((rank) - 1) * sizeof(idx[0])); \
@@ -334,7 +334,7 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
i64 outlinidx = 0; \
again_label_name: \
{ \
- body \
+ __VA_ARGS__ \
} \
for (i64 dim = (rank) - 2; dim >= 0; dim--) { \
if (++idx[dim] < (shape)[dim]) { \
@@ -351,7 +351,7 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
// inner dimension. The arrays must have the same shape, but may have different
// strides. The body is run for every pair of "inner vectors".
// Provides idx, outlinidx, arrlinidx1, arrlinidx2.
-#define TARRAY_WALK2_NOINNER(again_label_name, rank, shape, strides1, strides2, body) \
+#define TARRAY_WALK2_NOINNER(again_label_name, rank, shape, strides1, strides2, ...) \
do { \
i64 idx[(rank) /* - 1 */]; /* Note: [zero-length VLA] */ \
memset(idx, 0, ((rank) - 1) * sizeof(idx[0])); \
@@ -359,7 +359,7 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
i64 outlinidx = 0; \
again_label_name: \
{ \
- body \
+ __VA_ARGS__ \
} \
for (i64 dim = (rank) - 2; dim >= 0; dim--) { \
if (++idx[dim] < (shape)[dim]) { \
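
Both walker macros now take the loop body as `...`/__VA_ARGS__ instead of a single named `body` parameter. The body is a brace-enclosed block, and braces, unlike parentheses, do not protect commas during macro-argument splitting, so any top-level comma in the block used to split it into multiple macro arguments and break the expansion. A hypothetical call that only works with the variadic form (`sum` and `arr` are assumed to be in scope; `arrlinidx` is provided by the macro):

/* Illustration only, not part of the patch. The comma in the declaration
 * below would previously have been treated as a macro-argument separator;
 * with `...`, all trailing arguments are collected and __VA_ARGS__ pastes
 * them back with the comma intact. */
TARRAY_WALK_NOINNER(again_demo, rank, shape, strides, {
  i64 lo = 0, hi = (shape)[(rank) - 1];  /* top-level comma, now fine */
  for (i64 i = lo; i < hi; i++)
    sum += arr[arrlinidx + (strides)[(rank) - 1] * i];
});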
@@ -514,45 +514,30 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
}); \
}
-#define DOTPROD_STRIDED_OP(typ) \
- typ oxarop_dotprod_ ## typ ## _strided(i64 length, i64 stride1, const typ *arr1, i64 stride2, const typ *arr2) { \
- if (length < MANUAL_VECT_WID) { \
- typ res = 0; \
- for (i64 i = 0; i < length; i++) res += arr1[stride1 * i] * arr2[stride2 * i]; \
- return res; \
- } else { \
- typ accum[MANUAL_VECT_WID]; \
- for (i64 j = 0; j < MANUAL_VECT_WID; j++) accum[j] = arr1[stride1 * j] * arr2[stride2 * j]; \
- for (i64 i = 1; i < length / MANUAL_VECT_WID; i++) \
- for (i64 j = 0; j < MANUAL_VECT_WID; j++) \
- accum[j] += arr1[stride1 * (MANUAL_VECT_WID * i + j)] * arr2[stride2 * (MANUAL_VECT_WID * i + j)]; \
- typ res = accum[0]; \
- for (i64 j = 1; j < MANUAL_VECT_WID; j++) res += accum[j]; \
- for (i64 i = length / MANUAL_VECT_WID * MANUAL_VECT_WID; i < length; i++) \
- res += arr1[stride1 * i] * arr2[stride2 * i]; \
- return res; \
- } \
- }
-
// Reduces along the innermost dimension.
// 'out' will be filled densely in linearisation order.
#define DOTPROD_INNER_OP(typ) \
void oxarop_dotprodinner_ ## typ(i64 rank, const i64 *shape, typ *restrict out, const i64 *strides1, const typ *arr1, const i64 *strides2, const typ *arr2) { \
TIME_START(tm); \
- if (strides1[rank - 1] == 1 && strides2[rank - 1] == 1) { \
- TARRAY_WALK2_NOINNER(again1, rank, shape, strides1, strides2, { \
- out[outlinidx] = oxarop_dotprod_ ## typ ## _strided(shape[rank - 1], 1, arr1 + arrlinidx1, 1, arr2 + arrlinidx2); \
- }); \
- } else if (strides1[rank - 1] == -1 && strides2[rank - 1] == -1) { \
- TARRAY_WALK2_NOINNER(again2, rank, shape, strides1, strides2, { \
- const i64 len = shape[rank - 1]; \
- out[outlinidx] = oxarop_dotprod_ ## typ ## _strided(len, 1, arr1 + arrlinidx1 - (len - 1), 1, arr2 + arrlinidx2 - (len - 1)); \
- }); \
- } else { \
- TARRAY_WALK2_NOINNER(again3, rank, shape, strides1, strides2, { \
- out[outlinidx] = oxarop_dotprod_ ## typ ## _strided(shape[rank - 1], strides1[rank - 1], arr1 + arrlinidx1, strides2[rank - 1], arr2 + arrlinidx2); \
- }); \
- } \
+ TARRAY_WALK2_NOINNER(again3, rank, shape, strides1, strides2, { \
+ const i64 length = shape[rank - 1], stride1 = strides1[rank - 1], stride2 = strides2[rank - 1]; \
+ if (length < MANUAL_VECT_WID) { \
+ typ res = 0; \
+ for (i64 i = 0; i < length; i++) res += arr1[arrlinidx1 + stride1 * i] * arr2[arrlinidx2 + stride2 * i]; \
+ out[outlinidx] = res; \
+ } else { \
+ typ accum[MANUAL_VECT_WID]; \
+ for (i64 j = 0; j < MANUAL_VECT_WID; j++) accum[j] = arr1[arrlinidx1 + stride1 * j] * arr2[arrlinidx2 + stride2 * j]; \
+ for (i64 i = 1; i < length / MANUAL_VECT_WID; i++) \
+ for (i64 j = 0; j < MANUAL_VECT_WID; j++) \
+ accum[j] += arr1[arrlinidx1 + stride1 * (MANUAL_VECT_WID * i + j)] * arr2[arrlinidx2 + stride2 * (MANUAL_VECT_WID * i + j)]; \
+ typ res = accum[0]; \
+ for (i64 j = 1; j < MANUAL_VECT_WID; j++) res += accum[j]; \
+ for (i64 i = length / MANUAL_VECT_WID * MANUAL_VECT_WID; i < length; i++) \
+ res += arr1[arrlinidx1 + stride1 * i] * arr2[arrlinidx2 + stride2 * i]; \
+ out[outlinidx] = res; \
+ } \
+ }); \
stats_record_binary(sbi_dotprod, rank, shape, strides1, strides2, TIME_END(tm)); \
}
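
The three stride special cases previously dispatched to the standalone oxarop_dotprod_*_strided helper; the fused version reads the inner stride inside the walk, so a single code path covers unit, reversed, and general strides. The accumulator-splitting pattern itself is unchanged. For reference, here it is written out as a plain function for double, with MANUAL_VECT_WID assumed to be 8 and the file's i64 typedef assumed in scope:

/* Reference-only sketch of the pattern DOTPROD_INNER_OP now inlines per
 * inner vector. MANUAL_VECT_WID independent accumulators break the
 * loop-carried dependency on a single sum, letting the compiler keep
 * several multiply-adds in flight and vectorise the j-loop. */
#define MANUAL_VECT_WID 8   /* assumed; see the real definition in this file */
static double dotprod_strided_double(
    i64 length, i64 stride1, const double *arr1, i64 stride2, const double *arr2) {
  if (length < MANUAL_VECT_WID) {        /* too short: plain scalar loop */
    double res = 0;
    for (i64 i = 0; i < length; i++) res += arr1[stride1 * i] * arr2[stride2 * i];
    return res;
  }
  double accum[MANUAL_VECT_WID];         /* one partial sum per lane */
  for (i64 j = 0; j < MANUAL_VECT_WID; j++)
    accum[j] = arr1[stride1 * j] * arr2[stride2 * j];
  for (i64 i = 1; i < length / MANUAL_VECT_WID; i++)
    for (i64 j = 0; j < MANUAL_VECT_WID; j++)
      accum[j] += arr1[stride1 * (MANUAL_VECT_WID * i + j)]
                * arr2[stride2 * (MANUAL_VECT_WID * i + j)];
  double res = accum[0];                 /* fold the lanes */
  for (i64 j = 1; j < MANUAL_VECT_WID; j++) res += accum[j];
  for (i64 i = length / MANUAL_VECT_WID * MANUAL_VECT_WID; i < length; i++)
    res += arr1[stride1 * i] * arr2[stride2 * i];   /* scalar tail */
  return res;
}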
@@ -774,7 +759,6 @@ enum redop_tag_t {
ENTRY_REDUCEFULL_OPS(typ) \
EXTREMUM_OP(min, <, typ) \
EXTREMUM_OP(max, >, typ) \
- DOTPROD_STRIDED_OP(typ) \
DOTPROD_INNER_OP(typ)
NUM_TYPES_XLIST
#undef X
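
Registration of the per-type instances uses the X-macro idiom: NUM_TYPES_XLIST invokes X once per element type, so dropping DOTPROD_STRIDED_OP from the X body deletes the now-unused strided helper for every type at once. A minimal sketch of the idiom, with hypothetical names:

/* X-macro sketch; names are hypothetical, not this file's actual list. */
#define DEMO_TYPES_XLIST X(i32) X(i64) X(double)
#define X(typ) typ demo_sum_ ## typ(typ a, typ b) { return a + b; }
DEMO_TYPES_XLIST   /* one definition of demo_sum_<typ> per listed type */
#undef X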
diff --git a/ops/Data/Array/Strided/Arith/Internal.hs b/ops/Data/Array/Strided/Arith/Internal.hs
index a74e43d..313d72f 100644
--- a/ops/Data/Array/Strided/Arith/Internal.hs
+++ b/ops/Data/Array/Strided/Arith/Internal.hs
@@ -18,7 +18,7 @@ import Control.Monad
import Data.Bifunctor (second)
import Data.Bits
import Data.Int
-import Data.List (sort)
+import Data.List (sort, zip4)
import Data.Proxy
import Data.Type.Equality
import qualified Data.Vector.Storable as VS
@@ -184,7 +184,7 @@ unreplicateStrides (Array sh strides offset vec) =
simplifyArray :: Array n a
-> (forall n'. KnownNat n'
- => Array n' a -- U
+ => Array n' a -- U
-- Product of sizes of the unreplicated dimensions
-> Int
-- Convert index in U back to index into original
@@ -218,6 +218,64 @@ simplifyArray array k
| otherwise ->
arrayRevDims (init revDims) (Array (init (arrShape array)) (init (rereplicate (strides' ++ [0]))) offset' vec'))
+-- | The two input arrays must have the same shape.
+simplifyArray2 :: Array n a -> Array n a
+ -> (forall n'. KnownNat n'
+ => Array n' a -- U1
+ -> Array n' a -- U2 (same shape as U1)
+ -- Product of sizes of the dimensions that are
+ -- replicated in neither input
+ -> Int
+ -- Convert index in U{1,2} back to index into original
+ -- arrays. Dimensions that are replicated in both
+ -- inputs get 0.
+ -> ([Int] -> [Int])
+ -- Given a new array of the same shape as U1 (& U2),
+ -- convert it back to the original shape and
+ -- iteration order.
+ -> (Array n' a -> Array n a)
+ -- Do the same except without the INNER dimension.
+ -- This throws an error if the inner dimension had
+ -- stride 0 in both inputs.
+ -> (Array (n' - 1) a -> Array (n - 1) a)
+ -> r)
+ -> r
+simplifyArray2 arr1@(Array sh _ _ _) arr2@(Array sh2 _ _ _) k
+ | sh /= sh2 = error "simplifyArray2: Unequal shapes"
+
+ | let revDims = zipWith (\s1 s2 -> s1 < 0 && s2 < 0) (arrStrides arr1) (arrStrides arr2)
+ , Array _ strides1 offset1 vec1 <- arrayRevDims revDims arr1
+ , Array _ strides2 offset2 vec2 <- arrayRevDims revDims arr2
+
+ , let replDims = zipWith (\s1 s2 -> s1 == 0 && s2 == 0) strides1 strides2
+ , let (shF, strides1F, strides2F) = unzip3 [(n, s1, s2) | (n, s1, s2, False) <- zip4 sh strides1 strides2 replDims]
+
+ , let reinsertZeros (False : zeros) (s : strides') = s : reinsertZeros zeros strides'
+ reinsertZeros (True : zeros) strides' = 0 : reinsertZeros zeros strides'
+ reinsertZeros [] [] = []
+ reinsertZeros (False : _) [] = error $ "simplifyArray2: Internal error: reply strides too short"
+ reinsertZeros [] (_:_) = error $ "simplifyArray2: Internal error: reply strides too long"
+
+ , let unrepSize = product [n | (n, True) <- zip sh replDims]
+
+ = TypeNats.withSomeSNat (fromIntegral (length shF)) $ \(SNat :: SNat lenshF) ->
+ k @lenshF
+ (Array shF strides1F offset1 vec1)
+ (Array shF strides2F offset2 vec2)
+ unrepSize
+ (\idx -> zipWith3 (\b n i -> if b then n - 1 - i else i)
+ revDims sh (reinsertZeros replDims idx))
+ (\(Array sh' strides' offset' vec') ->
+ if sh' /= shF then error $ "simplifyArray2: Internal error: reply shape wrong (reply " ++ show sh' ++ ", unreplicated " ++ show shF ++ ")"
+ else arrayRevDims revDims (Array sh (reinsertZeros replDims strides') offset' vec'))
+ (\(Array sh' strides' offset' vec') ->
+ if | sh' /= init shF ->
+ error $ "simplifyArray2: Internal error: reply shape wrong (reply " ++ show sh' ++ ", unreplicated " ++ show shF ++ ")"
+ | last replDims ->
+ error $ "simplifyArray2: Internal error: reduction reply handler used while inner dimension was unreplicated"
+ | otherwise ->
+ arrayRevDims (init revDims) (Array (init sh) (reinsertZeros (init replDims) strides') offset' vec'))
+
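
For intuition, a small hypothetical use of the new simplifyArray2 (Array fields as elsewhere in this module: shape, strides, offset, vector). A dimension whose stride is 0 in both inputs is dropped before the real work and reinserted by the reply handler:

-- Hypothetical example, not part of the patch. Both inputs have shape
-- [2,3] and the outer dimension has stride 0 in both, so it is dropped:
--
--   arr1 = Array [2,3] [0,1] 0 vec1
--   arr2 = Array [2,3] [0,2] 0 vec2
--
-- simplifyArray2 arr1 arr2 $ \u1 u2 unrepSize toOrigIdx restore restoreRed ->
--   -- u1 = Array [3] [1] 0 vec1;  u2 = Array [3] [2] 0 vec2
--   -- unrepSize = 2             (product of the dropped dimension sizes)
--   -- toOrigIdx [i] = [0, i]    (dropped dims index as 0; none reversed)
--   -- restore reinserts the stride-0 dimension, turning a [3]-shaped
--   -- reply back into a [2,3]-shaped array; restoreRed is also usable
--   -- here, since it only errors when the inner dimension had stride 0
--   -- in both inputs, which is not the case in this example.
--   ...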
{-# NOINLINE wrapUnary #-}
wrapUnary :: forall a b n. Storable a
=> SNat n
@@ -418,19 +476,28 @@ vectorDotprodInnerOp sn@SNat valconv ptrconv fmul fscale fred fdotinner
(vectorRedInnerOp sn valconv ptrconv fscale fred arr1)
(Array (init sh2) (init strides2) offset2 vec2)
-- now there is useful dotprod work along the inner dimension
- | otherwise = unsafePerformIO $ do
- let inrank = fromSNat' sn + 1
- outv <- VSM.unsafeNew (product (init sh1))
- VSM.unsafeWith outv $ \poutv ->
- VS.unsafeWith (VS.fromListN inrank (map fromIntegral sh1)) $ \psh ->
- VS.unsafeWith (VS.fromListN inrank (map fromIntegral strides1)) $ \pstrides1 ->
- VS.unsafeWith vec1 $ \pvec1 ->
- VS.unsafeWith (VS.fromListN inrank (map fromIntegral strides2)) $ \pstrides2 ->
- VS.unsafeWith vec2 $ \pvec2 ->
- fdotinner (fromIntegral @Int @Int64 inrank) psh (ptrconv poutv)
- pstrides1 (ptrconv pvec1 `plusPtr` (sizeOf (undefined :: a) * offset1))
- pstrides2 (ptrconv pvec2 `plusPtr` (sizeOf (undefined :: a) * offset2))
- arrayFromVector @_ @n (init sh1) <$> VS.unsafeFreeze outv
+ | otherwise =
+ simplifyArray2 arr1 arr2 $ \(Array sh' strides1' offset1' vec1' :: Array n' a) (Array _ strides2' offset2' vec2') _ _ _ restore ->
+ unsafePerformIO $ do
+ let inrank = length sh'
+ outv <- VSM.unsafeNew (product (init sh'))
+ VSM.unsafeWith outv $ \poutv ->
+ VS.unsafeWith (VS.fromListN inrank (map fromIntegral sh')) $ \psh ->
+ VS.unsafeWith (VS.fromListN inrank (map fromIntegral strides1')) $ \pstrides1 ->
+ VS.unsafeWith vec1' $ \pvec1 ->
+ VS.unsafeWith (VS.fromListN inrank (map fromIntegral strides2')) $ \pstrides2 ->
+ VS.unsafeWith vec2' $ \pvec2 ->
+ fdotinner (fromIntegral @Int @Int64 inrank) psh (ptrconv poutv)
+ pstrides1 (ptrconv pvec1 `plusPtr` (sizeOf (undefined :: a) * offset1'))
+ pstrides2 (ptrconv pvec2 `plusPtr` (sizeOf (undefined :: a) * offset2'))
+ TypeNats.withSomeSNat (fromIntegral (inrank - 1)) $ \(SNat :: SNat n'm1) -> do
+ (Dict :: Dict (1 <= n')) <- case cmpNat (natSing @1) (natSing @n') of
+ LTI -> pure Dict
+ EQI -> pure Dict
+ GTI -> error "impossible" -- because `last strides1 /= 0`
+ case sameNat (natSing @(n' - 1)) (natSing @n'm1) of
+ Just Refl -> restore . arrayFromVector (init sh') <$> VS.unsafeFreeze outv
+ Nothing -> error "impossible"
mulWithInt :: Num a => a -> Int -> a
mulWithInt a i = a * fromIntegral i
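
The singleton gymnastics at the end of vectorDotprodInnerOp recover value-level evidence for type-level facts. The cmpNat step in isolation, as a sketch assuming this module's existing imports (Dict as already used above):

-- Sketch only. Recover evidence that 1 <= n from a runtime comparison of
-- singleton naturals; cmpNat returns an OrderingI carrying the proof.
proveOneLE :: forall n. KnownNat n => Maybe (Dict (1 <= n))
proveOneLE = case cmpNat (natSing @1) (natSing @n) of
  LTI -> Just Dict   -- Compare 1 n ~ 'LT, hence 1 <= n
  EQI -> Just Dict   -- n ~ 1
  GTI -> Nothing     -- only possible when n ~ 0

In the code above the GTI branch really is unreachable: this case of vectorDotprodInnerOp is only entered once the stride-0 inner-dimension cases have been peeled off (`last strides1 /= 0`, per the comment), so simplifyArray2 keeps the inner dimension and n' >= 1.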