aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-06-10 22:05:41 +0200
committerTom Smeding <tom@tomsmeding.com>2024-06-10 22:05:41 +0200
commit19e24848768a80fd971292bf93b9d1769a5118e4 (patch)
tree55b52682a6dd58291e22c846b7d0e39f4ae05040
parenta31367cc657198237a8ff911c8c78a399d51e2b8 (diff)
Make arith code aware of negative strides
-rw-r--r--src/Data/Array/Mixed/Internal/Arith.hs139
1 files changed, 93 insertions, 46 deletions
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs
index 6ecbbeb..98acfb9 100644
--- a/src/Data/Array/Mixed/Internal/Arith.hs
+++ b/src/Data/Array/Mixed/Internal/Arith.hs
@@ -3,6 +3,7 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
+{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ViewPatterns #-}
@@ -13,6 +14,7 @@ import Control.Monad (forM, guard)
import Data.Array.Internal qualified as OI
import Data.Array.Internal.RankedG qualified as RG
import Data.Array.Internal.RankedS qualified as RS
+import Data.Bifunctor (second)
import Data.Bits
import Data.Int
import Data.List (sort)
@@ -31,14 +33,15 @@ import Data.Array.Mixed.Internal.Arith.Foreign
import Data.Array.Mixed.Internal.Arith.Lists
+-- TODO: test all the cases of this thing with various input strides
liftVEltwise1 :: Storable a
=> SNat n
-> (VS.Vector a -> VS.Vector a)
-> RS.Array n a -> RS.Array n a
liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec)))
- | Just prefixSz <- stridesDense sh strides =
- let vec' = f (VS.slice offset prefixSz vec)
- in RS.A (RG.A sh (OI.T strides 0 vec'))
+ | Just (blockOff, blockSz) <- stridesDense sh offset strides =
+ let vec' = f (VS.slice blockOff blockSz vec)
+ in RS.A (RG.A sh (OI.T strides (offset - blockOff) vec'))
| otherwise = RS.fromVector sh (f (RS.toVector arr))
-- TODO: test all the cases of this thing with various input strides
@@ -51,43 +54,58 @@ liftVEltwise2 SNat f
arr2@(RS.A (RG.A sh2 (OI.T strides2 offset2 vec2)))
| sh1 /= sh2 = error $ "liftVEltwise2: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2
| product sh1 == 0 = arr1 -- if the arrays are empty, just return one of the empty inputs
- | otherwise = case (stridesDense sh1 strides1, stridesDense sh2 strides2) of
- (Just 1, Just 1) -> -- both are a (potentially replicated) scalar; just apply f to the scalars
+ | otherwise = case (stridesDense sh1 offset1 strides1, stridesDense sh2 offset2 strides2) of
+ (Just (_, 1), Just (_, 1)) -> -- both are a (potentially replicated) scalar; just apply f to the scalars
let vec' = f (Left (vec1 VS.! offset1)) (Left (vec2 VS.! offset2))
in RS.A (RG.A sh1 (OI.T strides1 0 vec'))
- (Just 1, Just n) -> -- scalar * dense
- RS.A (RG.A sh1 (OI.T strides2 0 (f (Left (vec1 VS.! offset1)) (Right (VS.slice offset2 n vec2)))))
- (Just n, Just 1) -> -- dense * scalar
- RS.A (RG.A sh1 (OI.T strides1 0 (f (Right (VS.slice offset1 n vec1)) (Left (vec2 VS.! offset2)))))
- (Just n, Just m)
- | n == m -- not sure if this check is necessary
+ (Just (_, 1), Just (blockOff, blockSz)) -> -- scalar * dense
+ RS.A (RG.A sh1 (OI.T strides2 (offset2 - blockOff)
+ (f (Left (vec1 VS.! offset1)) (Right (VS.slice blockOff blockSz vec2)))))
+ (Just (blockOff, blockSz), Just (_, 1)) -> -- dense * scalar
+ RS.A (RG.A sh1 (OI.T strides1 (offset1 - blockOff)
+ (f (Right (VS.slice blockOff blockSz vec1)) (Left (vec2 VS.! offset2)))))
+ (Just (blockOff1, blockSz1), Just (blockOff2, blockSz2))
+ | blockSz1 == blockSz2 -- not sure if this check is necessary, might be implied by the below
, strides1 == strides2
-> -- dense * dense but the strides match
- RS.A (RG.A sh1 (OI.T strides1 0 (f (Right (VS.slice offset1 n vec1)) (Right (VS.slice offset2 n vec2)))))
+ RS.A (RG.A sh1 (OI.T strides1 (offset1 - blockOff1)
+ (f (Right (VS.slice blockOff1 blockSz1 vec1)) (Right (VS.slice blockOff2 blockSz2 vec2)))))
(_, _) -> -- fallback case
RS.fromVector sh1 (f (Right (RS.toVector arr1)) (Right (RS.toVector arr2)))
--- | Given the shape vector and the stride vector, return whether this vector
--- of strides uses a dense prefix of its backing array. If so, the number of
--- elements in this prefix is returned.
+-- | Given shape vector, offset and stride vector, check whether this virtual
+-- vector uses a dense subarray of its backing array. If so, the first index
+-- and the number of elements in this subarray is returned.
-- This excludes any offset.
-stridesDense :: [Int] -> [Int] -> Maybe Int
-stridesDense sh _ | any (<= 0) sh = Just 0
-stridesDense sh str =
- -- sort dimensions on their stride, ascending, dropping any zero strides
- case dropWhile ((== 0) . fst) (sort (zip str sh)) of
- [] -> Just 1
- (1, n) : (unzip -> (str', sh')) -> checkCover n sh' str'
- _ -> Nothing -- if the smallest stride is not 1, it will never be dense
+stridesDense :: [Int] -> Int -> [Int] -> Maybe (Int, Int)
+stridesDense sh offset _ | any (<= 0) sh = Just (offset, 0)
+stridesDense sh offsetNeg stridesNeg =
+ -- First reverse all dimensions with negative stride, so that the first used
+ -- value is at 'offset' and the rest is >= offset.
+ let (offset, strides) = flipReverseds sh offsetNeg stridesNeg
+ in -- sort dimensions on their stride, ascending, dropping any zero strides
+ case filter ((/= 0) . fst) (sort (zip strides sh)) of
+ [] -> Just (offset, 1)
+ (1, n) : pairs -> (offset,) <$> checkCover n pairs
+ _ -> Nothing -- if the smallest stride is not 1, it will never be dense
where
-- Given size of currently densely covered region at beginning of the
- -- array, the remaining shape vector and the corresponding remaining stride
- -- vector, return whether this all together covers a dense prefix of the
- -- array. If it does, return the number of elements in this prefix.
- checkCover :: Int -> [Int] -> [Int] -> Maybe Int
- checkCover block [] [] = Just block
- checkCover block (n : sh') (s : str') = guard (s <= block) >> checkCover (max block (n * s)) sh' str'
- checkCover _ _ _ = error "Orthotope array's shape vector and stride vector have different lengths"
+ -- array and the remaining (stride, size) pairs with all strides >=1,
+ -- return whether this all together covers a dense prefix of the array. If
+ -- it does, return the number of elements in this prefix.
+ checkCover :: Int -> [(Int, Int)] -> Maybe Int
+ checkCover block [] = Just block
+ checkCover block ((s, n) : pairs) = guard (s <= block) >> checkCover (max block (n * s)) pairs
+
+ -- Given shape, offset and strides, returns new (offset, strides) such that all strides are >=0
+ flipReverseds :: [Int] -> Int -> [Int] -> (Int, [Int])
+ flipReverseds [] off [] = (off, [])
+ flipReverseds (n : sh') off (s : str')
+ | s >= 0 = second (s :) (flipReverseds sh' off str')
+ | otherwise =
+ let off' = off + (n - 1) * s
+ in second ((-s) :) (flipReverseds sh' off' str')
+ flipReverseds _ _ _ = error "flipReverseds: invalid arguments"
{-# NOINLINE vectorOp1 #-}
vectorOp1 :: forall a b. Storable a
@@ -141,6 +159,7 @@ vectorOp2 valconv ptrconv fss fsv fvs fvv = \cases
VS.unsafeFreeze outv
| otherwise -> error $ "vectorOp: unequal lengths: " ++ show (VS.length vx) ++ " /= " ++ show (VS.length vy)
+-- TODO: test handling of negative strides
-- | Reduce along the inner dimension
{-# NOINLINE vectorRedInnerOp #-}
vectorRedInnerOp :: forall a b n. (Num a, Storable a)
@@ -171,18 +190,28 @@ vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T stride
-- replace replicated dimensions with ones
shOnes = zipWith (\n repl -> if repl then 1 else n) sh replDims
ndimsF = length shF -- > 0, otherwise `last strides == 0`
+
+ -- reversed dimensions: dimensions with negative stride. Reversal is
+ -- irrelevant for a reduction, and indeed the kernel has a
+ -- precondition that there are no such dimensions.
+ revDims = map (< 0) stridesF
+ stridesR = map abs stridesF
+ offsetR = offset + sum (zipWith3 (\rev n s -> if rev then (n - 1) * s else 0) revDims shF stridesF)
+ -- The *R values give an array with strides all > 0, hence the
+ -- left-most element is at offsetR.
in unsafePerformIO $ do
- outv <- VSM.unsafeNew (product (init shF))
- VSM.unsafeWith outv $ \poutv ->
+ outvR <- VSM.unsafeNew (product (init shF))
+ VSM.unsafeWith outvR $ \poutvR ->
VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral shF)) $ \pshF ->
- VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesF)) $ \pstridesF ->
- VS.unsafeWith (VS.slice offset (VS.length vec - offset) vec) $ \pvec ->
- fred (fromIntegral ndimsF) pshF pstridesF (ptrconv poutv) (ptrconv pvec)
+ VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesR)) $ \pstridesR ->
+ VS.unsafeWith (VS.slice offsetR (VS.length vec - offsetR) vec) $ \pvecR ->
+ fred (fromIntegral ndimsF) pshF pstridesR (ptrconv poutvR) (ptrconv pvecR)
TypeNats.withSomeSNat (fromIntegral (ndimsF - 1)) $ \(SNat :: SNat lenFm1) ->
- RS.stretch (init sh)
- . RS.reshape (init shOnes)
- . RS.fromVector @_ @lenFm1 (init shF)
- <$> VS.unsafeFreeze outv
+ RS.stretch (init sh) -- replicate to original shape
+ . RS.reshape (init shOnes) -- add 1-sized dimensions where the original was replicated
+ . RS.rev (map fst (filter snd (zip [0..] revDims))) -- re-reverse the correct dimensions
+ . RS.fromVector @_ @lenFm1 (init shF) -- the partially-reversed result array
+ <$> VS.unsafeFreeze outvR
-- TODO: test this function
-- | Find extremum (minindex ("argmin") or maxindex) in full array
@@ -195,7 +224,8 @@ vectorExtremumOp ptrconv fextrem (RS.A (RG.A sh (OI.T strides offset vec)))
| null sh = []
| any (<= 0) sh = error "Extremum (minindex/maxindex): empty array"
-- now the input array is nonempty
- | all (<= 0) strides = 0 <$ sh
+ | all (== 0) strides = 0 <$ sh
+ -- now there is at least one non-replicated dimension
| otherwise =
let -- replicated dimensions: dimensions with zero stride. The extremum
-- kernel need not concern itself with those (and in fact has a
@@ -204,19 +234,30 @@ vectorExtremumOp ptrconv fextrem (RS.A (RG.A sh (OI.T strides offset vec)))
-- filter out replicated dimensions
(shF, stridesF) = unzip $ map fst $ filter (not . snd) (zip (zip sh strides) replDims)
ndimsF = length shF -- > 0, because not all strides were <=0
+
+ -- un-reverse reversed dimensions
+ revDims = map (< 0) stridesF
+ stridesR = map abs stridesF
+ offsetR = offset + sum (zipWith3 (\rev n s -> if rev then (n - 1) * s else 0) revDims shF stridesF)
+
-- function to insert zeros in replicated-out dimensions
+ insertZeros :: [Bool] -> [Int] -> [Int]
insertZeros [] idx = idx
insertZeros (True : repls) idx = 0 : insertZeros repls idx
insertZeros (False : repls) (i : idx) = i : insertZeros repls idx
insertZeros (_:_) [] = error "unreachable"
in unsafePerformIO $ do
- outv <- VSM.unsafeNew (length shF)
- VSM.unsafeWith outv $ \poutv ->
+ outvR <- VSM.unsafeNew (length shF)
+ VSM.unsafeWith outvR $ \poutvR ->
VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral shF)) $ \pshF ->
- VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesF)) $ \pstridesF ->
- VS.unsafeWith (VS.slice offset (VS.length vec - offset) vec) $ \pvec ->
- fextrem poutv (fromIntegral ndimsF) pshF pstridesF (ptrconv pvec)
- insertZeros replDims . map (fromIntegral @Int64 @Int) . VS.toList <$> VS.unsafeFreeze outv
+ VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesR)) $ \pstridesR ->
+ VS.unsafeWith (VS.slice offsetR (VS.length vec - offsetR) vec) $ \pvecR ->
+ fextrem poutvR (fromIntegral ndimsF) pshF pstridesR (ptrconv pvecR)
+ insertZeros replDims
+ . zipWith3 (\rev n i -> if rev then n - 1 - i else i) revDims shF -- re-reverse the reversed dimensions
+ . map (fromIntegral @Int64 @Int)
+ . VS.toList
+ <$> VS.unsafeFreeze outvR
vectorDotprodOp :: (Num a, Storable a)
=> (b -> a)
@@ -235,10 +276,16 @@ vectorDotprodOp valbackconv ptrconv fred fdot fdotstrided
fromIntegral len1 * (vec1 VS.! offset1) * (vec2 VS.! offset2)
(0, 1) -> -- replicated scalar * dense
dotScalarVector len1 ptrconv fred (vec1 VS.! offset1) (VS.slice offset2 len1 vec2)
+ (0, -1) -> -- replicated scalar * reversed dense
+ dotScalarVector len1 ptrconv fred (vec1 VS.! offset1) (VS.slice (offset2 - (len1 - 1)) len1 vec2)
(1, 0) -> -- dense * replicated scalar
dotScalarVector len1 ptrconv fred (vec2 VS.! offset2) (VS.slice offset1 len1 vec1)
+ (-1, 0) -> -- reversed dense * replicated scalar
+ dotScalarVector len1 ptrconv fred (vec2 VS.! offset2) (VS.slice (offset1 - (len1 - 1)) len1 vec1)
(1, 1) -> -- dense * dense
dotVectorVector len1 valbackconv ptrconv fdot (VS.slice offset1 len1 vec1) (VS.slice offset2 len1 vec2)
+ (-1, -1) -> -- reversed dense * reversed dense
+ dotVectorVector len1 valbackconv ptrconv fdot (VS.slice (offset1 - (len1 - 1)) len1 vec1) (VS.slice (offset2 - (len1 - 1)) len1 vec2)
(_, _) -> -- fallback case
dotVectorVectorStrided len1 valbackconv ptrconv fdotstrided offset1 stride1 vec1 offset2 stride2 vec2
vectorDotprodOp _ _ _ _ _ _ _ = error "vectorDotprodOp: not one-dimensional?"