diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-03-20 13:22:33 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-20 13:22:33 +0100 |
commit | 62724776675488a82f3f372aeb537d97ad91c791 (patch) | |
tree | c071891d473729f68f73f3201bf4f5cfcd199c45 | |
parent | 8d01c5d7d6fba8d7afef1d7bd19d9f3991982032 (diff) |
Compatibility liftVEltwise1 (TODO remove)
-rw-r--r-- | src/Data/Array/Mixed/Internal/Arith.hs | 16 |
1 files changed, 16 insertions, 0 deletions
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs index 9402766..b1c7031 100644 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -11,6 +11,12 @@ import Data.Array.Internal.RankedS qualified as RS import Data.Array.Strided qualified as AS import Data.Array.Strided.Arith +-- for liftVEltwise1 +import Foreign.Storable +import GHC.TypeLits +import Data.Vector.Storable qualified as VS +import Data.Array.Strided.Arith.Internal (stridesDense) + fromO :: RS.Array n a -> AS.Array n a fromO (RS.A (RG.A sh (OI.T strides offset vec))) = AS.Array sh strides offset vec @@ -25,3 +31,13 @@ liftO1 f = toO . f . fromO liftO2 :: (AS.Array n a -> AS.Array n1 b -> AS.Array n2 c) -> RS.Array n a -> RS.Array n1 b -> RS.Array n2 c liftO2 f x y = toO (f (fromO x) (fromO y)) + +liftVEltwise1 :: (Storable a, Storable b) + => SNat n + -> (VS.Vector a -> VS.Vector b) + -> RS.Array n a -> RS.Array n b +liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec))) + | Just (blockOff, blockSz) <- stridesDense sh offset strides = + let vec' = f (VS.slice blockOff blockSz vec) + in RS.A (RG.A sh (OI.T strides (offset - blockOff) vec')) + | otherwise = RS.fromVector sh (f (RS.toVector arr)) |