From 19e24848768a80fd971292bf93b9d1769a5118e4 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Mon, 10 Jun 2024 22:05:41 +0200 Subject: Make arith code aware of negative strides --- src/Data/Array/Mixed/Internal/Arith.hs | 139 ++++++++++++++++++++++----------- 1 file changed, 93 insertions(+), 46 deletions(-) (limited to 'src/Data/Array/Mixed/Internal') diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs index 6ecbbeb..98acfb9 100644 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -3,6 +3,7 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE ViewPatterns #-} @@ -13,6 +14,7 @@ import Control.Monad (forM, guard) import Data.Array.Internal qualified as OI import Data.Array.Internal.RankedG qualified as RG import Data.Array.Internal.RankedS qualified as RS +import Data.Bifunctor (second) import Data.Bits import Data.Int import Data.List (sort) @@ -31,14 +33,15 @@ import Data.Array.Mixed.Internal.Arith.Foreign import Data.Array.Mixed.Internal.Arith.Lists +-- TODO: test all the cases of this thing with various input strides liftVEltwise1 :: Storable a => SNat n -> (VS.Vector a -> VS.Vector a) -> RS.Array n a -> RS.Array n a liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec))) - | Just prefixSz <- stridesDense sh strides = - let vec' = f (VS.slice offset prefixSz vec) - in RS.A (RG.A sh (OI.T strides 0 vec')) + | Just (blockOff, blockSz) <- stridesDense sh offset strides = + let vec' = f (VS.slice blockOff blockSz vec) + in RS.A (RG.A sh (OI.T strides (offset - blockOff) vec')) | otherwise = RS.fromVector sh (f (RS.toVector arr)) -- TODO: test all the cases of this thing with various input strides @@ -51,43 +54,58 @@ liftVEltwise2 SNat f arr2@(RS.A (RG.A sh2 (OI.T strides2 offset2 vec2))) | sh1 /= sh2 = error $ "liftVEltwise2: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2 | product sh1 == 0 = arr1 -- if the arrays are empty, just return one of the empty inputs - | otherwise = case (stridesDense sh1 strides1, stridesDense sh2 strides2) of - (Just 1, Just 1) -> -- both are a (potentially replicated) scalar; just apply f to the scalars + | otherwise = case (stridesDense sh1 offset1 strides1, stridesDense sh2 offset2 strides2) of + (Just (_, 1), Just (_, 1)) -> -- both are a (potentially replicated) scalar; just apply f to the scalars let vec' = f (Left (vec1 VS.! offset1)) (Left (vec2 VS.! offset2)) in RS.A (RG.A sh1 (OI.T strides1 0 vec')) - (Just 1, Just n) -> -- scalar * dense - RS.A (RG.A sh1 (OI.T strides2 0 (f (Left (vec1 VS.! offset1)) (Right (VS.slice offset2 n vec2))))) - (Just n, Just 1) -> -- dense * scalar - RS.A (RG.A sh1 (OI.T strides1 0 (f (Right (VS.slice offset1 n vec1)) (Left (vec2 VS.! offset2))))) - (Just n, Just m) - | n == m -- not sure if this check is necessary + (Just (_, 1), Just (blockOff, blockSz)) -> -- scalar * dense + RS.A (RG.A sh1 (OI.T strides2 (offset2 - blockOff) + (f (Left (vec1 VS.! offset1)) (Right (VS.slice blockOff blockSz vec2))))) + (Just (blockOff, blockSz), Just (_, 1)) -> -- dense * scalar + RS.A (RG.A sh1 (OI.T strides1 (offset1 - blockOff) + (f (Right (VS.slice blockOff blockSz vec1)) (Left (vec2 VS.! offset2))))) + (Just (blockOff1, blockSz1), Just (blockOff2, blockSz2)) + | blockSz1 == blockSz2 -- not sure if this check is necessary, might be implied by the below , strides1 == strides2 -> -- dense * dense but the strides match - RS.A (RG.A sh1 (OI.T strides1 0 (f (Right (VS.slice offset1 n vec1)) (Right (VS.slice offset2 n vec2))))) + RS.A (RG.A sh1 (OI.T strides1 (offset1 - blockOff1) + (f (Right (VS.slice blockOff1 blockSz1 vec1)) (Right (VS.slice blockOff2 blockSz2 vec2))))) (_, _) -> -- fallback case RS.fromVector sh1 (f (Right (RS.toVector arr1)) (Right (RS.toVector arr2))) --- | Given the shape vector and the stride vector, return whether this vector --- of strides uses a dense prefix of its backing array. If so, the number of --- elements in this prefix is returned. +-- | Given shape vector, offset and stride vector, check whether this virtual +-- vector uses a dense subarray of its backing array. If so, the first index +-- and the number of elements in this subarray is returned. -- This excludes any offset. -stridesDense :: [Int] -> [Int] -> Maybe Int -stridesDense sh _ | any (<= 0) sh = Just 0 -stridesDense sh str = - -- sort dimensions on their stride, ascending, dropping any zero strides - case dropWhile ((== 0) . fst) (sort (zip str sh)) of - [] -> Just 1 - (1, n) : (unzip -> (str', sh')) -> checkCover n sh' str' - _ -> Nothing -- if the smallest stride is not 1, it will never be dense +stridesDense :: [Int] -> Int -> [Int] -> Maybe (Int, Int) +stridesDense sh offset _ | any (<= 0) sh = Just (offset, 0) +stridesDense sh offsetNeg stridesNeg = + -- First reverse all dimensions with negative stride, so that the first used + -- value is at 'offset' and the rest is >= offset. + let (offset, strides) = flipReverseds sh offsetNeg stridesNeg + in -- sort dimensions on their stride, ascending, dropping any zero strides + case filter ((/= 0) . fst) (sort (zip strides sh)) of + [] -> Just (offset, 1) + (1, n) : pairs -> (offset,) <$> checkCover n pairs + _ -> Nothing -- if the smallest stride is not 1, it will never be dense where -- Given size of currently densely covered region at beginning of the - -- array, the remaining shape vector and the corresponding remaining stride - -- vector, return whether this all together covers a dense prefix of the - -- array. If it does, return the number of elements in this prefix. - checkCover :: Int -> [Int] -> [Int] -> Maybe Int - checkCover block [] [] = Just block - checkCover block (n : sh') (s : str') = guard (s <= block) >> checkCover (max block (n * s)) sh' str' - checkCover _ _ _ = error "Orthotope array's shape vector and stride vector have different lengths" + -- array and the remaining (stride, size) pairs with all strides >=1, + -- return whether this all together covers a dense prefix of the array. If + -- it does, return the number of elements in this prefix. + checkCover :: Int -> [(Int, Int)] -> Maybe Int + checkCover block [] = Just block + checkCover block ((s, n) : pairs) = guard (s <= block) >> checkCover (max block (n * s)) pairs + + -- Given shape, offset and strides, returns new (offset, strides) such that all strides are >=0 + flipReverseds :: [Int] -> Int -> [Int] -> (Int, [Int]) + flipReverseds [] off [] = (off, []) + flipReverseds (n : sh') off (s : str') + | s >= 0 = second (s :) (flipReverseds sh' off str') + | otherwise = + let off' = off + (n - 1) * s + in second ((-s) :) (flipReverseds sh' off' str') + flipReverseds _ _ _ = error "flipReverseds: invalid arguments" {-# NOINLINE vectorOp1 #-} vectorOp1 :: forall a b. Storable a @@ -141,6 +159,7 @@ vectorOp2 valconv ptrconv fss fsv fvs fvv = \cases VS.unsafeFreeze outv | otherwise -> error $ "vectorOp: unequal lengths: " ++ show (VS.length vx) ++ " /= " ++ show (VS.length vy) +-- TODO: test handling of negative strides -- | Reduce along the inner dimension {-# NOINLINE vectorRedInnerOp #-} vectorRedInnerOp :: forall a b n. (Num a, Storable a) @@ -171,18 +190,28 @@ vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T stride -- replace replicated dimensions with ones shOnes = zipWith (\n repl -> if repl then 1 else n) sh replDims ndimsF = length shF -- > 0, otherwise `last strides == 0` + + -- reversed dimensions: dimensions with negative stride. Reversal is + -- irrelevant for a reduction, and indeed the kernel has a + -- precondition that there are no such dimensions. + revDims = map (< 0) stridesF + stridesR = map abs stridesF + offsetR = offset + sum (zipWith3 (\rev n s -> if rev then (n - 1) * s else 0) revDims shF stridesF) + -- The *R values give an array with strides all > 0, hence the + -- left-most element is at offsetR. in unsafePerformIO $ do - outv <- VSM.unsafeNew (product (init shF)) - VSM.unsafeWith outv $ \poutv -> + outvR <- VSM.unsafeNew (product (init shF)) + VSM.unsafeWith outvR $ \poutvR -> VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral shF)) $ \pshF -> - VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesF)) $ \pstridesF -> - VS.unsafeWith (VS.slice offset (VS.length vec - offset) vec) $ \pvec -> - fred (fromIntegral ndimsF) pshF pstridesF (ptrconv poutv) (ptrconv pvec) + VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesR)) $ \pstridesR -> + VS.unsafeWith (VS.slice offsetR (VS.length vec - offsetR) vec) $ \pvecR -> + fred (fromIntegral ndimsF) pshF pstridesR (ptrconv poutvR) (ptrconv pvecR) TypeNats.withSomeSNat (fromIntegral (ndimsF - 1)) $ \(SNat :: SNat lenFm1) -> - RS.stretch (init sh) - . RS.reshape (init shOnes) - . RS.fromVector @_ @lenFm1 (init shF) - <$> VS.unsafeFreeze outv + RS.stretch (init sh) -- replicate to original shape + . RS.reshape (init shOnes) -- add 1-sized dimensions where the original was replicated + . RS.rev (map fst (filter snd (zip [0..] revDims))) -- re-reverse the correct dimensions + . RS.fromVector @_ @lenFm1 (init shF) -- the partially-reversed result array + <$> VS.unsafeFreeze outvR -- TODO: test this function -- | Find extremum (minindex ("argmin") or maxindex) in full array @@ -195,7 +224,8 @@ vectorExtremumOp ptrconv fextrem (RS.A (RG.A sh (OI.T strides offset vec))) | null sh = [] | any (<= 0) sh = error "Extremum (minindex/maxindex): empty array" -- now the input array is nonempty - | all (<= 0) strides = 0 <$ sh + | all (== 0) strides = 0 <$ sh + -- now there is at least one non-replicated dimension | otherwise = let -- replicated dimensions: dimensions with zero stride. The extremum -- kernel need not concern itself with those (and in fact has a @@ -204,19 +234,30 @@ vectorExtremumOp ptrconv fextrem (RS.A (RG.A sh (OI.T strides offset vec))) -- filter out replicated dimensions (shF, stridesF) = unzip $ map fst $ filter (not . snd) (zip (zip sh strides) replDims) ndimsF = length shF -- > 0, because not all strides were <=0 + + -- un-reverse reversed dimensions + revDims = map (< 0) stridesF + stridesR = map abs stridesF + offsetR = offset + sum (zipWith3 (\rev n s -> if rev then (n - 1) * s else 0) revDims shF stridesF) + -- function to insert zeros in replicated-out dimensions + insertZeros :: [Bool] -> [Int] -> [Int] insertZeros [] idx = idx insertZeros (True : repls) idx = 0 : insertZeros repls idx insertZeros (False : repls) (i : idx) = i : insertZeros repls idx insertZeros (_:_) [] = error "unreachable" in unsafePerformIO $ do - outv <- VSM.unsafeNew (length shF) - VSM.unsafeWith outv $ \poutv -> + outvR <- VSM.unsafeNew (length shF) + VSM.unsafeWith outvR $ \poutvR -> VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral shF)) $ \pshF -> - VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesF)) $ \pstridesF -> - VS.unsafeWith (VS.slice offset (VS.length vec - offset) vec) $ \pvec -> - fextrem poutv (fromIntegral ndimsF) pshF pstridesF (ptrconv pvec) - insertZeros replDims . map (fromIntegral @Int64 @Int) . VS.toList <$> VS.unsafeFreeze outv + VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesR)) $ \pstridesR -> + VS.unsafeWith (VS.slice offsetR (VS.length vec - offsetR) vec) $ \pvecR -> + fextrem poutvR (fromIntegral ndimsF) pshF pstridesR (ptrconv pvecR) + insertZeros replDims + . zipWith3 (\rev n i -> if rev then n - 1 - i else i) revDims shF -- re-reverse the reversed dimensions + . map (fromIntegral @Int64 @Int) + . VS.toList + <$> VS.unsafeFreeze outvR vectorDotprodOp :: (Num a, Storable a) => (b -> a) @@ -235,10 +276,16 @@ vectorDotprodOp valbackconv ptrconv fred fdot fdotstrided fromIntegral len1 * (vec1 VS.! offset1) * (vec2 VS.! offset2) (0, 1) -> -- replicated scalar * dense dotScalarVector len1 ptrconv fred (vec1 VS.! offset1) (VS.slice offset2 len1 vec2) + (0, -1) -> -- replicated scalar * reversed dense + dotScalarVector len1 ptrconv fred (vec1 VS.! offset1) (VS.slice (offset2 - (len1 - 1)) len1 vec2) (1, 0) -> -- dense * replicated scalar dotScalarVector len1 ptrconv fred (vec2 VS.! offset2) (VS.slice offset1 len1 vec1) + (-1, 0) -> -- reversed dense * replicated scalar + dotScalarVector len1 ptrconv fred (vec2 VS.! offset2) (VS.slice (offset1 - (len1 - 1)) len1 vec1) (1, 1) -> -- dense * dense dotVectorVector len1 valbackconv ptrconv fdot (VS.slice offset1 len1 vec1) (VS.slice offset2 len1 vec2) + (-1, -1) -> -- reversed dense * reversed dense + dotVectorVector len1 valbackconv ptrconv fdot (VS.slice (offset1 - (len1 - 1)) len1 vec1) (VS.slice (offset2 - (len1 - 1)) len1 vec2) (_, _) -> -- fallback case dotVectorVectorStrided len1 valbackconv ptrconv fdotstrided offset1 stride1 vec1 offset2 stride2 vec2 vectorDotprodOp _ _ _ _ _ _ _ = error "vectorDotprodOp: not one-dimensional?" -- cgit v1.2.3-70-g09d2