From a78ddeaa5d34fa8b6fa52eee42977cc46e8c36a5 Mon Sep 17 00:00:00 2001
From: Tom Smeding <tom@tomsmeding.com>
Date: Tue, 25 Mar 2025 17:09:20 +0100
Subject: Dotprod: Optimise reversed and replicated dimensions

---
 ops/Data/Array/Strided/Arith/Internal.hs | 97 +++++++++++++++++++++++++++-----
 1 file changed, 82 insertions(+), 15 deletions(-)

(limited to 'ops/Data/Array/Strided/Arith/Internal.hs')

diff --git a/ops/Data/Array/Strided/Arith/Internal.hs b/ops/Data/Array/Strided/Arith/Internal.hs
index a74e43d..313d72f 100644
--- a/ops/Data/Array/Strided/Arith/Internal.hs
+++ b/ops/Data/Array/Strided/Arith/Internal.hs
@@ -18,7 +18,7 @@ import Control.Monad
 import Data.Bifunctor (second)
 import Data.Bits
 import Data.Int
-import Data.List (sort)
+import Data.List (sort, zip4)
 import Data.Proxy
 import Data.Type.Equality
 import qualified Data.Vector.Storable as VS
@@ -184,7 +184,7 @@ unreplicateStrides (Array sh strides offset vec) =
 
 simplifyArray :: Array n a
               -> (forall n'. KnownNat n'
-              => Array n' a  -- U
+                          => Array n' a  -- U
                           -- Product of sizes of the unreplicated dimensions
                           -> Int
                           -- Convert index in U back to index into original
@@ -218,6 +218,64 @@ simplifyArray array k
             | otherwise ->
                 arrayRevDims (init revDims) (Array (init (arrShape array)) (init (rereplicate (strides' ++ [0]))) offset' vec'))
 
+-- | The two input arrays must have the same shape.
+simplifyArray2 :: Array n a -> Array n a
+               -> (forall n'. KnownNat n'
+                           => Array n' a  -- U1
+                           -> Array n' a  -- U2 (same shape as U1)
+                           -- Product of sizes of the dimensions that are
+                           -- replicated in neither input
+                           -> Int
+                           -- Convert index in U{1,2} back to index into original
+                           -- arrays. Dimensions that are replicated in both
+                           -- inputs get 0.
+                           -> ([Int] -> [Int])
+                           -- Given a new array of the same shape as U1 (& U2),
+                           -- convert it back to the original shape and
+                           -- iteration order.
+                           -> (Array n' a -> Array n a)
+                           -- Do the same except without the INNER dimension.
+                           -- This throws an error if the inner dimension had
+                           -- stride 0 in both inputs.
+                           -> (Array (n' - 1) a -> Array (n - 1) a)
+                           -> r)
+               -> r
+simplifyArray2 arr1@(Array sh _ _ _) arr2@(Array sh2 _ _ _) k
+  | sh /= sh2 = error "simplifyArray2: Unequal shapes"
+
+  | let revDims = zipWith (\s1 s2 -> s1 < 0 && s2 < 0) (arrStrides arr1) (arrStrides arr2)
+  , Array _ strides1 offset1 vec1 <- arrayRevDims revDims arr1
+  , Array _ strides2 offset2 vec2 <- arrayRevDims revDims arr2
+
+  , let replDims = zipWith (\s1 s2 -> s1 == 0 && s2 == 0) strides1 strides2
+  , let (shF, strides1F, strides2F) = unzip3 [(n, s1, s2) | (n, s1, s2, False) <- zip4 sh strides1 strides2 replDims]
+
+  , let reinsertZeros (False : zeros) (s : strides') = s : reinsertZeros zeros strides'
+        reinsertZeros (True : zeros) strides' = 0 : reinsertZeros zeros strides'
+        reinsertZeros [] [] = []
+        reinsertZeros (False : _) [] = error $ "simplifyArray2: Internal error: reply strides too short"
+        reinsertZeros [] (_:_) = error $ "simplifyArray2: Internal error: reply strides too long"
+
+  , let unrepSize = product [n | (n, True) <- zip sh replDims]
+
+  = TypeNats.withSomeSNat (fromIntegral (length shF)) $ \(SNat :: SNat lenshF) ->
+    k @lenshF
+      (Array shF strides1F offset1 vec1)
+      (Array shF strides2F offset2 vec2)
+      unrepSize
+      (\idx -> zipWith3 (\b n i -> if b then n - 1 - i else i)
+                        revDims sh (reinsertZeros replDims idx))
+      (\(Array sh' strides' offset' vec') ->
+         if sh' /= shF then error $ "simplifyArray2: Internal error: reply shape wrong (reply " ++ show sh' ++ ", unreplicated " ++ show shF ++ ")"
+         else arrayRevDims revDims (Array sh (reinsertZeros replDims strides') offset' vec'))
+      (\(Array sh' strides' offset' vec') ->
+         if | sh' /= init shF ->
+                error $ "simplifyArray2: Internal error: reply shape wrong (reply " ++ show sh' ++ ", unreplicated " ++ show shF ++ ")"
+            | last replDims ->
+                error $ "simplifyArray2: Internal error: reduction reply handler used while inner dimension was unreplicated"
+            | otherwise ->
+                arrayRevDims (init revDims) (Array (init sh) (reinsertZeros (init replDims) strides') offset' vec'))
+
 {-# NOINLINE wrapUnary #-}
 wrapUnary :: forall a b n. Storable a
           => SNat n
@@ -418,19 +476,28 @@ vectorDotprodInnerOp sn@SNat valconv ptrconv fmul fscale fred fdotinner
         (vectorRedInnerOp sn valconv ptrconv fscale fred arr1)
         (Array (init sh2) (init strides2) offset2 vec2)
   -- now there is useful dotprod work along the inner dimension
-  | otherwise = unsafePerformIO $ do
-      let inrank = fromSNat' sn + 1
-      outv <- VSM.unsafeNew (product (init sh1))
-      VSM.unsafeWith outv $ \poutv ->
-        VS.unsafeWith (VS.fromListN inrank (map fromIntegral sh1)) $ \psh ->
-        VS.unsafeWith (VS.fromListN inrank (map fromIntegral strides1)) $ \pstrides1 ->
-        VS.unsafeWith vec1 $ \pvec1 ->
-        VS.unsafeWith (VS.fromListN inrank (map fromIntegral strides2)) $ \pstrides2 ->
-        VS.unsafeWith vec2 $ \pvec2 ->
-          fdotinner (fromIntegral @Int @Int64 inrank) psh (ptrconv poutv)
-                    pstrides1 (ptrconv pvec1 `plusPtr` (sizeOf (undefined :: a) * offset1))
-                    pstrides2 (ptrconv pvec2 `plusPtr` (sizeOf (undefined :: a) * offset2))
-      arrayFromVector @_ @n (init sh1) <$> VS.unsafeFreeze outv
+  | otherwise =
+      simplifyArray2 arr1 arr2 $ \(Array sh' strides1' offset1' vec1' :: Array n' a) (Array _ strides2' offset2' vec2') _ _ _ restore ->
+      unsafePerformIO $ do
+        let inrank = length sh'
+        outv <- VSM.unsafeNew (product (init sh'))
+        VSM.unsafeWith outv $ \poutv ->
+          VS.unsafeWith (VS.fromListN inrank (map fromIntegral sh')) $ \psh ->
+          VS.unsafeWith (VS.fromListN inrank (map fromIntegral strides1')) $ \pstrides1 ->
+          VS.unsafeWith vec1' $ \pvec1 ->
+          VS.unsafeWith (VS.fromListN inrank (map fromIntegral strides2')) $ \pstrides2 ->
+          VS.unsafeWith vec2' $ \pvec2 ->
+            fdotinner (fromIntegral @Int @Int64 inrank) psh (ptrconv poutv)
+                      pstrides1 (ptrconv pvec1 `plusPtr` (sizeOf (undefined :: a) * offset1'))
+                      pstrides2 (ptrconv pvec2 `plusPtr` (sizeOf (undefined :: a) * offset2'))
+        TypeNats.withSomeSNat (fromIntegral (inrank - 1)) $ \(SNat :: SNat n'm1) -> do
+          (Dict :: Dict (1 <= n')) <- case cmpNat (natSing @1) (natSing @n') of
+                                        LTI -> pure Dict
+                                        EQI -> pure Dict
+                                        GTI -> error "impossible"  -- because `last strides1 /= 0`
+          case sameNat (natSing @(n' - 1)) (natSing @n'm1) of
+            Just Refl -> restore . arrayFromVector (init sh') <$> VS.unsafeFreeze outv
+            Nothing -> error "impossible"
 
 mulWithInt :: Num a => a -> Int -> a
 mulWithInt a i = a * fromIntegral i
-- 
cgit v1.2.3-70-g09d2