diff options
Diffstat (limited to 'ops/Data/Array/Strided/Arith/Internal.hs')
-rw-r--r-- | ops/Data/Array/Strided/Arith/Internal.hs | 97 |
1 files changed, 82 insertions, 15 deletions
diff --git a/ops/Data/Array/Strided/Arith/Internal.hs b/ops/Data/Array/Strided/Arith/Internal.hs index a74e43d..313d72f 100644 --- a/ops/Data/Array/Strided/Arith/Internal.hs +++ b/ops/Data/Array/Strided/Arith/Internal.hs @@ -18,7 +18,7 @@ import Control.Monad import Data.Bifunctor (second) import Data.Bits import Data.Int -import Data.List (sort) +import Data.List (sort, zip4) import Data.Proxy import Data.Type.Equality import qualified Data.Vector.Storable as VS @@ -184,7 +184,7 @@ unreplicateStrides (Array sh strides offset vec) = simplifyArray :: Array n a -> (forall n'. KnownNat n' - => Array n' a -- U + => Array n' a -- U -- Product of sizes of the unreplicated dimensions -> Int -- Convert index in U back to index into original @@ -218,6 +218,64 @@ simplifyArray array k | otherwise -> arrayRevDims (init revDims) (Array (init (arrShape array)) (init (rereplicate (strides' ++ [0]))) offset' vec')) +-- | The two input arrays must have the same shape. +simplifyArray2 :: Array n a -> Array n a + -> (forall n'. KnownNat n' + => Array n' a -- U1 + -> Array n' a -- U2 (same shape as U1) + -- Product of sizes of the dimensions that are + -- replicated in neither input + -> Int + -- Convert index in U{1,2} back to index into original + -- arrays. Dimensions that are replicated in both + -- inputs get 0. + -> ([Int] -> [Int]) + -- Given a new array of the same shape as U1 (& U2), + -- convert it back to the original shape and + -- iteration order. + -> (Array n' a -> Array n a) + -- Do the same except without the INNER dimension. + -- This throws an error if the inner dimension had + -- stride 0 in both inputs. + -> (Array (n' - 1) a -> Array (n - 1) a) + -> r) + -> r +simplifyArray2 arr1@(Array sh _ _ _) arr2@(Array sh2 _ _ _) k + | sh /= sh2 = error "simplifyArray2: Unequal shapes" + + | let revDims = zipWith (\s1 s2 -> s1 < 0 && s2 < 0) (arrStrides arr1) (arrStrides arr2) + , Array _ strides1 offset1 vec1 <- arrayRevDims revDims arr1 + , Array _ strides2 offset2 vec2 <- arrayRevDims revDims arr2 + + , let replDims = zipWith (\s1 s2 -> s1 == 0 && s2 == 0) strides1 strides2 + , let (shF, strides1F, strides2F) = unzip3 [(n, s1, s2) | (n, s1, s2, False) <- zip4 sh strides1 strides2 replDims] + + , let reinsertZeros (False : zeros) (s : strides') = s : reinsertZeros zeros strides' + reinsertZeros (True : zeros) strides' = 0 : reinsertZeros zeros strides' + reinsertZeros [] [] = [] + reinsertZeros (False : _) [] = error $ "simplifyArray2: Internal error: reply strides too short" + reinsertZeros [] (_:_) = error $ "simplifyArray2: Internal error: reply strides too long" + + , let unrepSize = product [n | (n, True) <- zip sh replDims] + + = TypeNats.withSomeSNat (fromIntegral (length shF)) $ \(SNat :: SNat lenshF) -> + k @lenshF + (Array shF strides1F offset1 vec1) + (Array shF strides2F offset2 vec2) + unrepSize + (\idx -> zipWith3 (\b n i -> if b then n - 1 - i else i) + revDims sh (reinsertZeros replDims idx)) + (\(Array sh' strides' offset' vec') -> + if sh' /= shF then error $ "simplifyArray2: Internal error: reply shape wrong (reply " ++ show sh' ++ ", unreplicated " ++ show shF ++ ")" + else arrayRevDims revDims (Array sh (reinsertZeros replDims strides') offset' vec')) + (\(Array sh' strides' offset' vec') -> + if | sh' /= init shF -> + error $ "simplifyArray2: Internal error: reply shape wrong (reply " ++ show sh' ++ ", unreplicated " ++ show shF ++ ")" + | last replDims -> + error $ "simplifyArray2: Internal error: reduction reply handler used while inner dimension was unreplicated" + | otherwise -> + arrayRevDims (init revDims) (Array (init sh) (reinsertZeros (init replDims) strides') offset' vec')) + {-# NOINLINE wrapUnary #-} wrapUnary :: forall a b n. Storable a => SNat n @@ -418,19 +476,28 @@ vectorDotprodInnerOp sn@SNat valconv ptrconv fmul fscale fred fdotinner (vectorRedInnerOp sn valconv ptrconv fscale fred arr1) (Array (init sh2) (init strides2) offset2 vec2) -- now there is useful dotprod work along the inner dimension - | otherwise = unsafePerformIO $ do - let inrank = fromSNat' sn + 1 - outv <- VSM.unsafeNew (product (init sh1)) - VSM.unsafeWith outv $ \poutv -> - VS.unsafeWith (VS.fromListN inrank (map fromIntegral sh1)) $ \psh -> - VS.unsafeWith (VS.fromListN inrank (map fromIntegral strides1)) $ \pstrides1 -> - VS.unsafeWith vec1 $ \pvec1 -> - VS.unsafeWith (VS.fromListN inrank (map fromIntegral strides2)) $ \pstrides2 -> - VS.unsafeWith vec2 $ \pvec2 -> - fdotinner (fromIntegral @Int @Int64 inrank) psh (ptrconv poutv) - pstrides1 (ptrconv pvec1 `plusPtr` (sizeOf (undefined :: a) * offset1)) - pstrides2 (ptrconv pvec2 `plusPtr` (sizeOf (undefined :: a) * offset2)) - arrayFromVector @_ @n (init sh1) <$> VS.unsafeFreeze outv + | otherwise = + simplifyArray2 arr1 arr2 $ \(Array sh' strides1' offset1' vec1' :: Array n' a) (Array _ strides2' offset2' vec2') _ _ _ restore -> + unsafePerformIO $ do + let inrank = length sh' + outv <- VSM.unsafeNew (product (init sh')) + VSM.unsafeWith outv $ \poutv -> + VS.unsafeWith (VS.fromListN inrank (map fromIntegral sh')) $ \psh -> + VS.unsafeWith (VS.fromListN inrank (map fromIntegral strides1')) $ \pstrides1 -> + VS.unsafeWith vec1' $ \pvec1 -> + VS.unsafeWith (VS.fromListN inrank (map fromIntegral strides2')) $ \pstrides2 -> + VS.unsafeWith vec2' $ \pvec2 -> + fdotinner (fromIntegral @Int @Int64 inrank) psh (ptrconv poutv) + pstrides1 (ptrconv pvec1 `plusPtr` (sizeOf (undefined :: a) * offset1')) + pstrides2 (ptrconv pvec2 `plusPtr` (sizeOf (undefined :: a) * offset2')) + TypeNats.withSomeSNat (fromIntegral (inrank - 1)) $ \(SNat :: SNat n'm1) -> do + (Dict :: Dict (1 <= n')) <- case cmpNat (natSing @1) (natSing @n') of + LTI -> pure Dict + EQI -> pure Dict + GTI -> error "impossible" -- because `last strides1 /= 0` + case sameNat (natSing @(n' - 1)) (natSing @n'm1) of + Just Refl -> restore . arrayFromVector (init sh') <$> VS.unsafeFreeze outv + Nothing -> error "impossible" mulWithInt :: Num a => a -> Int -> a mulWithInt a i = a * fromIntegral i |