diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-03-25 17:09:20 +0100 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-25 17:09:20 +0100 | 
| commit | a78ddeaa5d34fa8b6fa52eee42977cc46e8c36a5 (patch) | |
| tree | 49bb90253cf3af73b2c27042c8ca98c98f05220b | |
| parent | 575a218d1b23b454fcdcf2b6ad0018fdc32b64b6 (diff) | |
Dotprod: Optimise reversed and replicated dimensions
| -rw-r--r-- | cbits/arith.c | 62 | ||||
| -rw-r--r-- | ops/Data/Array/Strided/Arith/Internal.hs | 97 | 
2 files changed, 105 insertions, 54 deletions
diff --git a/cbits/arith.c b/cbits/arith.c index b574d54..3659f6c 100644 --- a/cbits/arith.c +++ b/cbits/arith.c @@ -326,7 +326,7 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {  // Walk a orthotope-style strided array, except for the inner dimension. The  // body is run for every "inner vector".  // Provides idx, outlinidx, arrlinidx. -#define TARRAY_WALK_NOINNER(again_label_name, rank, shape, strides, body) \ +#define TARRAY_WALK_NOINNER(again_label_name, rank, shape, strides, ...) \    do { \      i64 idx[(rank) /* - 1 */]; /* Note: [zero-length VLA] */ \      memset(idx, 0, ((rank) - 1) * sizeof(idx[0])); \ @@ -334,7 +334,7 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {      i64 outlinidx = 0; \    again_label_name: \      { \ -      body \ +      __VA_ARGS__ \      } \      for (i64 dim = (rank) - 2; dim >= 0; dim--) { \        if (++idx[dim] < (shape)[dim]) { \ @@ -351,7 +351,7 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {  // inner dimension. The arrays must have the same shape, but may have different  // strides. The body is run for every pair of "inner vectors".  // Provides idx, outlinidx, arrlinidx1, arrlinidx2. -#define TARRAY_WALK2_NOINNER(again_label_name, rank, shape, strides1, strides2, body) \ +#define TARRAY_WALK2_NOINNER(again_label_name, rank, shape, strides1, strides2, ...) \    do { \      i64 idx[(rank) /* - 1 */]; /* Note: [zero-length VLA] */ \      memset(idx, 0, ((rank) - 1) * sizeof(idx[0])); \ @@ -359,7 +359,7 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {      i64 outlinidx = 0; \    again_label_name: \      { \ -      body \ +      __VA_ARGS__ \      } \      for (i64 dim = (rank) - 2; dim >= 0; dim--) { \        if (++idx[dim] < (shape)[dim]) { \ @@ -514,45 +514,30 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {      }); \    } -#define DOTPROD_STRIDED_OP(typ) \ -  typ oxarop_dotprod_ ## typ ## _strided(i64 length, i64 stride1, const typ *arr1, i64 stride2, const typ *arr2) { \ -    if (length < MANUAL_VECT_WID) { \ -      typ res = 0; \ -      for (i64 i = 0; i < length; i++) res += arr1[stride1 * i] * arr2[stride2 * i]; \ -      return res; \ -    } else { \ -      typ accum[MANUAL_VECT_WID]; \ -      for (i64 j = 0; j < MANUAL_VECT_WID; j++) accum[j] = arr1[stride1 * j] * arr2[stride2 * j]; \ -      for (i64 i = 1; i < length / MANUAL_VECT_WID; i++) \ -        for (i64 j = 0; j < MANUAL_VECT_WID; j++) \ -          accum[j] += arr1[stride1 * (MANUAL_VECT_WID * i + j)] * arr2[stride2 * (MANUAL_VECT_WID * i + j)]; \ -      typ res = accum[0]; \ -      for (i64 j = 1; j < MANUAL_VECT_WID; j++) res += accum[j]; \ -      for (i64 i = length / MANUAL_VECT_WID * MANUAL_VECT_WID; i < length; i++) \ -        res += arr1[stride1 * i] * arr2[stride2 * i]; \ -      return res; \ -    } \ -  } -  // Reduces along the innermost dimension.  // 'out' will be filled densely in linearisation order.  #define DOTPROD_INNER_OP(typ) \    void oxarop_dotprodinner_ ## typ(i64 rank, const i64 *shape, typ *restrict out, const i64 *strides1, const typ *arr1, const i64 *strides2, const typ *arr2) { \      TIME_START(tm); \ -    if (strides1[rank - 1] == 1 && strides2[rank - 1] == 1) { \ -      TARRAY_WALK2_NOINNER(again1, rank, shape, strides1, strides2, { \ -        out[outlinidx] = oxarop_dotprod_ ## typ ## _strided(shape[rank - 1], 1, arr1 + arrlinidx1, 1, arr2 + arrlinidx2); \ -      }); \ -    } else if (strides1[rank - 1] == -1 && strides2[rank - 1] == -1) { \ -      TARRAY_WALK2_NOINNER(again2, rank, shape, strides1, strides2, { \ -        const i64 len = shape[rank - 1]; \ -        out[outlinidx] = oxarop_dotprod_ ## typ ## _strided(len, 1, arr1 + arrlinidx1 - (len - 1), 1, arr2 + arrlinidx2 - (len - 1)); \ -      }); \ -    } else { \ -      TARRAY_WALK2_NOINNER(again3, rank, shape, strides1, strides2, { \ -        out[outlinidx] = oxarop_dotprod_ ## typ ## _strided(shape[rank - 1], strides1[rank - 1], arr1 + arrlinidx1, strides2[rank - 1], arr2 + arrlinidx2); \ -      }); \ -    } \ +    TARRAY_WALK2_NOINNER(again3, rank, shape, strides1, strides2, { \ +      const i64 length = shape[rank - 1], stride1 = strides1[rank - 1], stride2 = strides2[rank - 1]; \ +      if (length < MANUAL_VECT_WID) { \ +        typ res = 0; \ +        for (i64 i = 0; i < length; i++) res += arr1[arrlinidx1 + stride1 * i] * arr2[arrlinidx2 + stride2 * i]; \ +        out[outlinidx] = res; \ +      } else { \ +        typ accum[MANUAL_VECT_WID]; \ +        for (i64 j = 0; j < MANUAL_VECT_WID; j++) accum[j] = arr1[arrlinidx1 + stride1 * j] * arr2[arrlinidx2 + stride2 * j]; \ +        for (i64 i = 1; i < length / MANUAL_VECT_WID; i++) \ +          for (i64 j = 0; j < MANUAL_VECT_WID; j++) \ +            accum[j] += arr1[arrlinidx1 + stride1 * (MANUAL_VECT_WID * i + j)] * arr2[arrlinidx2 + stride2 * (MANUAL_VECT_WID * i + j)]; \ +        typ res = accum[0]; \ +        for (i64 j = 1; j < MANUAL_VECT_WID; j++) res += accum[j]; \ +        for (i64 i = length / MANUAL_VECT_WID * MANUAL_VECT_WID; i < length; i++) \ +          res += arr1[arrlinidx1 + stride1 * i] * arr2[arrlinidx2 + stride2 * i]; \ +        out[outlinidx] = res; \ +      } \ +    }); \      stats_record_binary(sbi_dotprod, rank, shape, strides1, strides2, TIME_END(tm)); \    } @@ -774,7 +759,6 @@ enum redop_tag_t {    ENTRY_REDUCEFULL_OPS(typ) \    EXTREMUM_OP(min, <, typ) \    EXTREMUM_OP(max, >, typ) \ -  DOTPROD_STRIDED_OP(typ) \    DOTPROD_INNER_OP(typ)  NUM_TYPES_XLIST  #undef X diff --git a/ops/Data/Array/Strided/Arith/Internal.hs b/ops/Data/Array/Strided/Arith/Internal.hs index a74e43d..313d72f 100644 --- a/ops/Data/Array/Strided/Arith/Internal.hs +++ b/ops/Data/Array/Strided/Arith/Internal.hs @@ -18,7 +18,7 @@ import Control.Monad  import Data.Bifunctor (second)  import Data.Bits  import Data.Int -import Data.List (sort) +import Data.List (sort, zip4)  import Data.Proxy  import Data.Type.Equality  import qualified Data.Vector.Storable as VS @@ -184,7 +184,7 @@ unreplicateStrides (Array sh strides offset vec) =  simplifyArray :: Array n a                -> (forall n'. KnownNat n' -              => Array n' a  -- U +                          => Array n' a  -- U                            -- Product of sizes of the unreplicated dimensions                            -> Int                            -- Convert index in U back to index into original @@ -218,6 +218,64 @@ simplifyArray array k              | otherwise ->                  arrayRevDims (init revDims) (Array (init (arrShape array)) (init (rereplicate (strides' ++ [0]))) offset' vec')) +-- | The two input arrays must have the same shape. +simplifyArray2 :: Array n a -> Array n a +               -> (forall n'. KnownNat n' +                           => Array n' a  -- U1 +                           -> Array n' a  -- U2 (same shape as U1) +                           -- Product of sizes of the dimensions that are +                           -- replicated in neither input +                           -> Int +                           -- Convert index in U{1,2} back to index into original +                           -- arrays. Dimensions that are replicated in both +                           -- inputs get 0. +                           -> ([Int] -> [Int]) +                           -- Given a new array of the same shape as U1 (& U2), +                           -- convert it back to the original shape and +                           -- iteration order. +                           -> (Array n' a -> Array n a) +                           -- Do the same except without the INNER dimension. +                           -- This throws an error if the inner dimension had +                           -- stride 0 in both inputs. +                           -> (Array (n' - 1) a -> Array (n - 1) a) +                           -> r) +               -> r +simplifyArray2 arr1@(Array sh _ _ _) arr2@(Array sh2 _ _ _) k +  | sh /= sh2 = error "simplifyArray2: Unequal shapes" + +  | let revDims = zipWith (\s1 s2 -> s1 < 0 && s2 < 0) (arrStrides arr1) (arrStrides arr2) +  , Array _ strides1 offset1 vec1 <- arrayRevDims revDims arr1 +  , Array _ strides2 offset2 vec2 <- arrayRevDims revDims arr2 + +  , let replDims = zipWith (\s1 s2 -> s1 == 0 && s2 == 0) strides1 strides2 +  , let (shF, strides1F, strides2F) = unzip3 [(n, s1, s2) | (n, s1, s2, False) <- zip4 sh strides1 strides2 replDims] + +  , let reinsertZeros (False : zeros) (s : strides') = s : reinsertZeros zeros strides' +        reinsertZeros (True : zeros) strides' = 0 : reinsertZeros zeros strides' +        reinsertZeros [] [] = [] +        reinsertZeros (False : _) [] = error $ "simplifyArray2: Internal error: reply strides too short" +        reinsertZeros [] (_:_) = error $ "simplifyArray2: Internal error: reply strides too long" + +  , let unrepSize = product [n | (n, True) <- zip sh replDims] + +  = TypeNats.withSomeSNat (fromIntegral (length shF)) $ \(SNat :: SNat lenshF) -> +    k @lenshF +      (Array shF strides1F offset1 vec1) +      (Array shF strides2F offset2 vec2) +      unrepSize +      (\idx -> zipWith3 (\b n i -> if b then n - 1 - i else i) +                        revDims sh (reinsertZeros replDims idx)) +      (\(Array sh' strides' offset' vec') -> +         if sh' /= shF then error $ "simplifyArray2: Internal error: reply shape wrong (reply " ++ show sh' ++ ", unreplicated " ++ show shF ++ ")" +         else arrayRevDims revDims (Array sh (reinsertZeros replDims strides') offset' vec')) +      (\(Array sh' strides' offset' vec') -> +         if | sh' /= init shF -> +                error $ "simplifyArray2: Internal error: reply shape wrong (reply " ++ show sh' ++ ", unreplicated " ++ show shF ++ ")" +            | last replDims -> +                error $ "simplifyArray2: Internal error: reduction reply handler used while inner dimension was unreplicated" +            | otherwise -> +                arrayRevDims (init revDims) (Array (init sh) (reinsertZeros (init replDims) strides') offset' vec')) +  {-# NOINLINE wrapUnary #-}  wrapUnary :: forall a b n. Storable a            => SNat n @@ -418,19 +476,28 @@ vectorDotprodInnerOp sn@SNat valconv ptrconv fmul fscale fred fdotinner          (vectorRedInnerOp sn valconv ptrconv fscale fred arr1)          (Array (init sh2) (init strides2) offset2 vec2)    -- now there is useful dotprod work along the inner dimension -  | otherwise = unsafePerformIO $ do -      let inrank = fromSNat' sn + 1 -      outv <- VSM.unsafeNew (product (init sh1)) -      VSM.unsafeWith outv $ \poutv -> -        VS.unsafeWith (VS.fromListN inrank (map fromIntegral sh1)) $ \psh -> -        VS.unsafeWith (VS.fromListN inrank (map fromIntegral strides1)) $ \pstrides1 -> -        VS.unsafeWith vec1 $ \pvec1 -> -        VS.unsafeWith (VS.fromListN inrank (map fromIntegral strides2)) $ \pstrides2 -> -        VS.unsafeWith vec2 $ \pvec2 -> -          fdotinner (fromIntegral @Int @Int64 inrank) psh (ptrconv poutv) -                    pstrides1 (ptrconv pvec1 `plusPtr` (sizeOf (undefined :: a) * offset1)) -                    pstrides2 (ptrconv pvec2 `plusPtr` (sizeOf (undefined :: a) * offset2)) -      arrayFromVector @_ @n (init sh1) <$> VS.unsafeFreeze outv +  | otherwise = +      simplifyArray2 arr1 arr2 $ \(Array sh' strides1' offset1' vec1' :: Array n' a) (Array _ strides2' offset2' vec2') _ _ _ restore -> +      unsafePerformIO $ do +        let inrank = length sh' +        outv <- VSM.unsafeNew (product (init sh')) +        VSM.unsafeWith outv $ \poutv -> +          VS.unsafeWith (VS.fromListN inrank (map fromIntegral sh')) $ \psh -> +          VS.unsafeWith (VS.fromListN inrank (map fromIntegral strides1')) $ \pstrides1 -> +          VS.unsafeWith vec1' $ \pvec1 -> +          VS.unsafeWith (VS.fromListN inrank (map fromIntegral strides2')) $ \pstrides2 -> +          VS.unsafeWith vec2' $ \pvec2 -> +            fdotinner (fromIntegral @Int @Int64 inrank) psh (ptrconv poutv) +                      pstrides1 (ptrconv pvec1 `plusPtr` (sizeOf (undefined :: a) * offset1')) +                      pstrides2 (ptrconv pvec2 `plusPtr` (sizeOf (undefined :: a) * offset2')) +        TypeNats.withSomeSNat (fromIntegral (inrank - 1)) $ \(SNat :: SNat n'm1) -> do +          (Dict :: Dict (1 <= n')) <- case cmpNat (natSing @1) (natSing @n') of +                                        LTI -> pure Dict +                                        EQI -> pure Dict +                                        GTI -> error "impossible"  -- because `last strides1 /= 0` +          case sameNat (natSing @(n' - 1)) (natSing @n'm1) of +            Just Refl -> restore . arrayFromVector (init sh') <$> VS.unsafeFreeze outv +            Nothing -> error "impossible"  mulWithInt :: Num a => a -> Int -> a  mulWithInt a i = a * fromIntegral i  | 
