diff options
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith.hs | 16 | 
1 files changed, 16 insertions, 0 deletions
| diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs index 9402766..b1c7031 100644 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -11,6 +11,12 @@ import Data.Array.Internal.RankedS qualified as RS  import Data.Array.Strided qualified as AS  import Data.Array.Strided.Arith +-- for liftVEltwise1 +import Foreign.Storable +import GHC.TypeLits +import Data.Vector.Storable qualified as VS +import Data.Array.Strided.Arith.Internal (stridesDense) +  fromO :: RS.Array n a -> AS.Array n a  fromO (RS.A (RG.A sh (OI.T strides offset vec))) = AS.Array sh strides offset vec @@ -25,3 +31,13 @@ liftO1 f = toO . f . fromO  liftO2 :: (AS.Array n a -> AS.Array n1 b -> AS.Array n2 c)         -> RS.Array n a -> RS.Array n1 b -> RS.Array n2 c  liftO2 f x y = toO (f (fromO x) (fromO y)) + +liftVEltwise1 :: (Storable a, Storable b) +              => SNat n +              -> (VS.Vector a -> VS.Vector b) +              -> RS.Array n a -> RS.Array n b +liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec))) +  | Just (blockOff, blockSz) <- stridesDense sh offset strides = +      let vec' = f (VS.slice blockOff blockSz vec) +      in RS.A (RG.A sh (OI.T strides (offset - blockOff) vec')) +  | otherwise = RS.fromVector sh (f (RS.toVector arr)) | 
