diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-03-20 13:22:33 +0100 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-20 13:22:33 +0100 | 
| commit | 62724776675488a82f3f372aeb537d97ad91c791 (patch) | |
| tree | c071891d473729f68f73f3201bf4f5cfcd199c45 | |
| parent | 8d01c5d7d6fba8d7afef1d7bd19d9f3991982032 (diff) | |
Compatibility liftVEltwise1 (TODO remove)
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith.hs | 16 | 
1 files changed, 16 insertions, 0 deletions
| diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs index 9402766..b1c7031 100644 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -11,6 +11,12 @@ import Data.Array.Internal.RankedS qualified as RS  import Data.Array.Strided qualified as AS  import Data.Array.Strided.Arith +-- for liftVEltwise1 +import Foreign.Storable +import GHC.TypeLits +import Data.Vector.Storable qualified as VS +import Data.Array.Strided.Arith.Internal (stridesDense) +  fromO :: RS.Array n a -> AS.Array n a  fromO (RS.A (RG.A sh (OI.T strides offset vec))) = AS.Array sh strides offset vec @@ -25,3 +31,13 @@ liftO1 f = toO . f . fromO  liftO2 :: (AS.Array n a -> AS.Array n1 b -> AS.Array n2 c)         -> RS.Array n a -> RS.Array n1 b -> RS.Array n2 c  liftO2 f x y = toO (f (fromO x) (fromO y)) + +liftVEltwise1 :: (Storable a, Storable b) +              => SNat n +              -> (VS.Vector a -> VS.Vector b) +              -> RS.Array n a -> RS.Array n b +liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec))) +  | Just (blockOff, blockSz) <- stridesDense sh offset strides = +      let vec' = f (VS.slice blockOff blockSz vec) +      in RS.A (RG.A sh (OI.T strides (offset - blockOff) vec')) +  | otherwise = RS.fromVector sh (f (RS.toVector arr)) | 
