diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-06-10 22:05:41 +0200 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-06-10 22:05:41 +0200 | 
| commit | 19e24848768a80fd971292bf93b9d1769a5118e4 (patch) | |
| tree | 55b52682a6dd58291e22c846b7d0e39f4ae05040 | |
| parent | a31367cc657198237a8ff911c8c78a399d51e2b8 (diff) | |
Make arith code aware of negative strides
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith.hs | 139 | 
1 files changed, 93 insertions, 46 deletions
| diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs index 6ecbbeb..98acfb9 100644 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -3,6 +3,7 @@  {-# LANGUAGE LambdaCase #-}  {-# LANGUAGE ScopedTypeVariables #-}  {-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TupleSections #-}  {-# LANGUAGE TypeApplications #-}  {-# LANGUAGE TypeOperators #-}  {-# LANGUAGE ViewPatterns #-} @@ -13,6 +14,7 @@ import Control.Monad (forM, guard)  import Data.Array.Internal qualified as OI  import Data.Array.Internal.RankedG qualified as RG  import Data.Array.Internal.RankedS qualified as RS +import Data.Bifunctor (second)  import Data.Bits  import Data.Int  import Data.List (sort) @@ -31,14 +33,15 @@ import Data.Array.Mixed.Internal.Arith.Foreign  import Data.Array.Mixed.Internal.Arith.Lists +-- TODO: test all the cases of this thing with various input strides  liftVEltwise1 :: Storable a                => SNat n                -> (VS.Vector a -> VS.Vector a)                -> RS.Array n a -> RS.Array n a  liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec))) -  | Just prefixSz <- stridesDense sh strides = -      let vec' = f (VS.slice offset prefixSz vec) -      in RS.A (RG.A sh (OI.T strides 0 vec')) +  | Just (blockOff, blockSz) <- stridesDense sh offset strides = +      let vec' = f (VS.slice blockOff blockSz vec) +      in RS.A (RG.A sh (OI.T strides (offset - blockOff) vec'))    | otherwise = RS.fromVector sh (f (RS.toVector arr))  -- TODO: test all the cases of this thing with various input strides @@ -51,43 +54,58 @@ liftVEltwise2 SNat f      arr2@(RS.A (RG.A sh2 (OI.T strides2 offset2 vec2)))    | sh1 /= sh2 = error $ "liftVEltwise2: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2    | product sh1 == 0 = arr1  -- if the arrays are empty, just return one of the empty inputs -  | otherwise = case (stridesDense sh1 strides1, stridesDense sh2 strides2) of -      (Just 1, Just 1) ->  -- both are a (potentially replicated) scalar; just apply f to the scalars +  | otherwise = case (stridesDense sh1 offset1 strides1, stridesDense sh2 offset2 strides2) of +      (Just (_, 1), Just (_, 1)) ->  -- both are a (potentially replicated) scalar; just apply f to the scalars          let vec' = f (Left (vec1 VS.! offset1)) (Left (vec2 VS.! offset2))          in RS.A (RG.A sh1 (OI.T strides1 0 vec')) -      (Just 1, Just n) ->  -- scalar * dense -        RS.A (RG.A sh1 (OI.T strides2 0 (f (Left (vec1 VS.! offset1)) (Right (VS.slice offset2 n vec2))))) -      (Just n, Just 1) ->  -- dense * scalar -        RS.A (RG.A sh1 (OI.T strides1 0 (f (Right (VS.slice offset1 n vec1)) (Left (vec2 VS.! offset2))))) -      (Just n, Just m) -        | n == m  -- not sure if this check is necessary +      (Just (_, 1), Just (blockOff, blockSz)) ->  -- scalar * dense +        RS.A (RG.A sh1 (OI.T strides2 (offset2 - blockOff) +                             (f (Left (vec1 VS.! offset1)) (Right (VS.slice blockOff blockSz vec2))))) +      (Just (blockOff, blockSz), Just (_, 1)) ->  -- dense * scalar +        RS.A (RG.A sh1 (OI.T strides1 (offset1 - blockOff) +                             (f (Right (VS.slice blockOff blockSz vec1)) (Left (vec2 VS.! offset2))))) +      (Just (blockOff1, blockSz1), Just (blockOff2, blockSz2)) +        | blockSz1 == blockSz2  -- not sure if this check is necessary, might be implied by the below          , strides1 == strides2          ->  -- dense * dense but the strides match -          RS.A (RG.A sh1 (OI.T strides1 0 (f (Right (VS.slice offset1 n vec1)) (Right (VS.slice offset2 n vec2))))) +          RS.A (RG.A sh1 (OI.T strides1 (offset1 - blockOff1) +                               (f (Right (VS.slice blockOff1 blockSz1 vec1)) (Right (VS.slice blockOff2 blockSz2 vec2)))))        (_, _) ->  -- fallback case          RS.fromVector sh1 (f (Right (RS.toVector arr1)) (Right (RS.toVector arr2))) --- | Given the shape vector and the stride vector, return whether this vector --- of strides uses a dense prefix of its backing array. If so, the number of --- elements in this prefix is returned. +-- | Given shape vector, offset and stride vector, check whether this virtual +-- vector uses a dense subarray of its backing array. If so, the first index +-- and the number of elements in this subarray is returned.  -- This excludes any offset. -stridesDense :: [Int] -> [Int] -> Maybe Int -stridesDense sh _ | any (<= 0) sh = Just 0 -stridesDense sh str = -  -- sort dimensions on their stride, ascending, dropping any zero strides -  case dropWhile ((== 0) . fst) (sort (zip str sh)) of -    [] -> Just 1 -    (1, n) : (unzip -> (str', sh')) -> checkCover n sh' str' -    _ -> Nothing  -- if the smallest stride is not 1, it will never be dense +stridesDense :: [Int] -> Int -> [Int] -> Maybe (Int, Int) +stridesDense sh offset _ | any (<= 0) sh = Just (offset, 0) +stridesDense sh offsetNeg stridesNeg = +  -- First reverse all dimensions with negative stride, so that the first used +  -- value is at 'offset' and the rest is >= offset. +  let (offset, strides) = flipReverseds sh offsetNeg stridesNeg +  in -- sort dimensions on their stride, ascending, dropping any zero strides +     case filter ((/= 0) . fst) (sort (zip strides sh)) of +       [] -> Just (offset, 1) +       (1, n) : pairs -> (offset,) <$> checkCover n pairs +       _ -> Nothing  -- if the smallest stride is not 1, it will never be dense    where      -- Given size of currently densely covered region at beginning of the -    -- array, the remaining shape vector and the corresponding remaining stride -    -- vector, return whether this all together covers a dense prefix of the -    -- array. If it does, return the number of elements in this prefix. -    checkCover :: Int -> [Int] -> [Int] -> Maybe Int -    checkCover block [] [] = Just block -    checkCover block (n : sh') (s : str') = guard (s <= block) >> checkCover (max block (n * s)) sh' str' -    checkCover _ _ _ = error "Orthotope array's shape vector and stride vector have different lengths" +    -- array and the remaining (stride, size) pairs with all strides >=1, +    -- return whether this all together covers a dense prefix of the array. If +    -- it does, return the number of elements in this prefix. +    checkCover :: Int -> [(Int, Int)] -> Maybe Int +    checkCover block [] = Just block +    checkCover block ((s, n) : pairs) = guard (s <= block) >> checkCover (max block (n * s)) pairs + +    -- Given shape, offset and strides, returns new (offset, strides) such that all strides are >=0 +    flipReverseds :: [Int] -> Int -> [Int] -> (Int, [Int]) +    flipReverseds [] off [] = (off, []) +    flipReverseds (n : sh') off (s : str') +      | s >= 0 = second (s :) (flipReverseds sh' off str') +      | otherwise = +          let off' = off + (n - 1) * s +          in second ((-s) :) (flipReverseds sh' off' str') +    flipReverseds _ _ _ = error "flipReverseds: invalid arguments"  {-# NOINLINE vectorOp1 #-}  vectorOp1 :: forall a b. Storable a @@ -141,6 +159,7 @@ vectorOp2 valconv ptrconv fss fsv fvs fvv = \cases            VS.unsafeFreeze outv      | otherwise -> error $ "vectorOp: unequal lengths: " ++ show (VS.length vx) ++ " /= " ++ show (VS.length vy) +-- TODO: test handling of negative strides  -- | Reduce along the inner dimension  {-# NOINLINE vectorRedInnerOp #-}  vectorRedInnerOp :: forall a b n. (Num a, Storable a) @@ -171,18 +190,28 @@ vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T stride            -- replace replicated dimensions with ones            shOnes = zipWith (\n repl -> if repl then 1 else n) sh replDims            ndimsF = length shF  -- > 0, otherwise `last strides == 0` + +          -- reversed dimensions: dimensions with negative stride. Reversal is +          -- irrelevant for a reduction, and indeed the kernel has a +          -- precondition that there are no such dimensions. +          revDims = map (< 0) stridesF +          stridesR = map abs stridesF +          offsetR = offset + sum (zipWith3 (\rev n s -> if rev then (n - 1) * s else 0) revDims shF stridesF) +          -- The *R values give an array with strides all > 0, hence the +          -- left-most element is at offsetR.        in unsafePerformIO $ do -           outv <- VSM.unsafeNew (product (init shF)) -           VSM.unsafeWith outv $ \poutv -> +           outvR <- VSM.unsafeNew (product (init shF)) +           VSM.unsafeWith outvR $ \poutvR ->               VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral shF)) $ \pshF -> -               VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesF)) $ \pstridesF -> -                 VS.unsafeWith (VS.slice offset (VS.length vec - offset) vec) $ \pvec -> -                   fred (fromIntegral ndimsF) pshF pstridesF (ptrconv poutv) (ptrconv pvec) +               VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesR)) $ \pstridesR -> +                 VS.unsafeWith (VS.slice offsetR (VS.length vec - offsetR) vec) $ \pvecR -> +                   fred (fromIntegral ndimsF) pshF pstridesR (ptrconv poutvR) (ptrconv pvecR)             TypeNats.withSomeSNat (fromIntegral (ndimsF - 1)) $ \(SNat :: SNat lenFm1) -> -             RS.stretch (init sh) -               . RS.reshape (init shOnes) -               . RS.fromVector @_ @lenFm1 (init shF) -               <$> VS.unsafeFreeze outv +             RS.stretch (init sh)  -- replicate to original shape +               . RS.reshape (init shOnes)  -- add 1-sized dimensions where the original was replicated +               . RS.rev (map fst (filter snd (zip [0..] revDims)))  -- re-reverse the correct dimensions +               . RS.fromVector @_ @lenFm1 (init shF)  -- the partially-reversed result array +               <$> VS.unsafeFreeze outvR  -- TODO: test this function  -- | Find extremum (minindex ("argmin") or maxindex) in full array @@ -195,7 +224,8 @@ vectorExtremumOp ptrconv fextrem (RS.A (RG.A sh (OI.T strides offset vec)))    | null sh = []    | any (<= 0) sh = error "Extremum (minindex/maxindex): empty array"    -- now the input array is nonempty -  | all (<= 0) strides = 0 <$ sh +  | all (== 0) strides = 0 <$ sh +  -- now there is at least one non-replicated dimension    | otherwise =        let -- replicated dimensions: dimensions with zero stride. The extremum            -- kernel need not concern itself with those (and in fact has a @@ -204,19 +234,30 @@ vectorExtremumOp ptrconv fextrem (RS.A (RG.A sh (OI.T strides offset vec)))            -- filter out replicated dimensions            (shF, stridesF) = unzip $ map fst $ filter (not . snd) (zip (zip sh strides) replDims)            ndimsF = length shF  -- > 0, because not all strides were <=0 + +          -- un-reverse reversed dimensions +          revDims = map (< 0) stridesF +          stridesR = map abs stridesF +          offsetR = offset + sum (zipWith3 (\rev n s -> if rev then (n - 1) * s else 0) revDims shF stridesF) +            -- function to insert zeros in replicated-out dimensions +          insertZeros :: [Bool] -> [Int] -> [Int]            insertZeros [] idx = idx            insertZeros (True : repls) idx = 0 : insertZeros repls idx            insertZeros (False : repls) (i : idx) = i : insertZeros repls idx            insertZeros (_:_) [] = error "unreachable"        in unsafePerformIO $ do -           outv <- VSM.unsafeNew (length shF) -           VSM.unsafeWith outv $ \poutv -> +           outvR <- VSM.unsafeNew (length shF) +           VSM.unsafeWith outvR $ \poutvR ->               VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral shF)) $ \pshF -> -               VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesF)) $ \pstridesF -> -                 VS.unsafeWith (VS.slice offset (VS.length vec - offset) vec) $ \pvec -> -                   fextrem poutv (fromIntegral ndimsF) pshF pstridesF (ptrconv pvec) -           insertZeros replDims . map (fromIntegral @Int64 @Int) . VS.toList <$> VS.unsafeFreeze outv +               VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesR)) $ \pstridesR -> +                 VS.unsafeWith (VS.slice offsetR (VS.length vec - offsetR) vec) $ \pvecR -> +                   fextrem poutvR (fromIntegral ndimsF) pshF pstridesR (ptrconv pvecR) +           insertZeros replDims +             . zipWith3 (\rev n i -> if rev then n - 1 - i else i) revDims shF  -- re-reverse the reversed dimensions +             . map (fromIntegral @Int64 @Int) +             . VS.toList +             <$> VS.unsafeFreeze outvR  vectorDotprodOp :: (Num a, Storable a)                  => (b -> a) @@ -235,10 +276,16 @@ vectorDotprodOp valbackconv ptrconv fred fdot fdotstrided          fromIntegral len1 * (vec1 VS.! offset1) * (vec2 VS.! offset2)        (0, 1) ->  -- replicated scalar * dense          dotScalarVector len1 ptrconv fred (vec1 VS.! offset1) (VS.slice offset2 len1 vec2) +      (0, -1) ->  -- replicated scalar * reversed dense +        dotScalarVector len1 ptrconv fred (vec1 VS.! offset1) (VS.slice (offset2 - (len1 - 1)) len1 vec2)        (1, 0) ->  -- dense * replicated scalar          dotScalarVector len1 ptrconv fred (vec2 VS.! offset2) (VS.slice offset1 len1 vec1) +      (-1, 0) ->  -- reversed dense * replicated scalar +        dotScalarVector len1 ptrconv fred (vec2 VS.! offset2) (VS.slice (offset1 - (len1 - 1)) len1 vec1)        (1, 1) ->  -- dense * dense          dotVectorVector len1 valbackconv ptrconv fdot (VS.slice offset1 len1 vec1) (VS.slice offset2 len1 vec2) +      (-1, -1) ->  -- reversed dense * reversed dense +        dotVectorVector len1 valbackconv ptrconv fdot (VS.slice (offset1 - (len1 - 1)) len1 vec1) (VS.slice (offset2 - (len1 - 1)) len1 vec2)        (_, _) ->  -- fallback case          dotVectorVectorStrided len1 valbackconv ptrconv fdotstrided offset1 stride1 vec1 offset2 stride2 vec2  vectorDotprodOp _ _ _ _ _ _ _ = error "vectorDotprodOp: not one-dimensional?" | 
