aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-03-20 13:22:33 +0100
committerTom Smeding <tom@tomsmeding.com>2025-03-20 13:22:33 +0100
commit62724776675488a82f3f372aeb537d97ad91c791 (patch)
treec071891d473729f68f73f3201bf4f5cfcd199c45
parent8d01c5d7d6fba8d7afef1d7bd19d9f3991982032 (diff)
Compatibility liftVEltwise1 (TODO remove)
-rw-r--r--src/Data/Array/Mixed/Internal/Arith.hs16
1 files changed, 16 insertions, 0 deletions
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs
index 9402766..b1c7031 100644
--- a/src/Data/Array/Mixed/Internal/Arith.hs
+++ b/src/Data/Array/Mixed/Internal/Arith.hs
@@ -11,6 +11,12 @@ import Data.Array.Internal.RankedS qualified as RS
import Data.Array.Strided qualified as AS
import Data.Array.Strided.Arith
+-- for liftVEltwise1
+import Foreign.Storable
+import GHC.TypeLits
+import Data.Vector.Storable qualified as VS
+import Data.Array.Strided.Arith.Internal (stridesDense)
+
fromO :: RS.Array n a -> AS.Array n a
fromO (RS.A (RG.A sh (OI.T strides offset vec))) = AS.Array sh strides offset vec
@@ -25,3 +31,13 @@ liftO1 f = toO . f . fromO
liftO2 :: (AS.Array n a -> AS.Array n1 b -> AS.Array n2 c)
-> RS.Array n a -> RS.Array n1 b -> RS.Array n2 c
liftO2 f x y = toO (f (fromO x) (fromO y))
+
+liftVEltwise1 :: (Storable a, Storable b)
+ => SNat n
+ -> (VS.Vector a -> VS.Vector b)
+ -> RS.Array n a -> RS.Array n b
+liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec)))
+ | Just (blockOff, blockSz) <- stridesDense sh offset strides =
+ let vec' = f (VS.slice blockOff blockSz vec)
+ in RS.A (RG.A sh (OI.T strides (offset - blockOff) vec'))
+ | otherwise = RS.fromVector sh (f (RS.toVector arr))