aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Arith.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Arith.hs')
-rw-r--r--src/Data/Array/Arith.hs43
1 files changed, 43 insertions, 0 deletions
diff --git a/src/Data/Array/Arith.hs b/src/Data/Array/Arith.hs
new file mode 100644
index 0000000..1eae737
--- /dev/null
+++ b/src/Data/Array/Arith.hs
@@ -0,0 +1,43 @@
+{-# LANGUAGE ImportQualifiedPost #-}
+module Data.Array.Arith (
+ module Data.Array.Arith,
+ module Data.Array.Strided.Arith,
+) where
+
+import Data.Array.Internal qualified as OI
+import Data.Array.Internal.RankedG qualified as RG
+import Data.Array.Internal.RankedS qualified as RS
+
+import Data.Array.Strided qualified as AS
+import Data.Array.Strided.Arith
+
+-- for liftVEltwise1
+import Data.Array.Strided.Arith.Internal (stridesDense)
+import Data.Vector.Storable qualified as VS
+import Foreign.Storable
+import GHC.TypeLits
+
+
+fromO :: RS.Array n a -> AS.Array n a
+fromO (RS.A (RG.A sh (OI.T strides offset vec))) = AS.Array sh strides offset vec
+
+toO :: AS.Array n a -> RS.Array n a
+toO (AS.Array sh strides offset vec) = RS.A (RG.A sh (OI.T strides offset vec))
+
+liftO1 :: (AS.Array n a -> AS.Array n' b)
+ -> RS.Array n a -> RS.Array n' b
+liftO1 f = toO . f . fromO
+
+liftO2 :: (AS.Array n a -> AS.Array n1 b -> AS.Array n2 c)
+ -> RS.Array n a -> RS.Array n1 b -> RS.Array n2 c
+liftO2 f x y = toO (f (fromO x) (fromO y))
+
+liftVEltwise1 :: (Storable a, Storable b)
+ => SNat n
+ -> (VS.Vector a -> VS.Vector b)
+ -> RS.Array n a -> RS.Array n b
+liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec)))
+ | Just (blockOff, blockSz) <- stridesDense sh offset strides =
+ let vec' = f (VS.slice blockOff blockSz vec)
+ in RS.A (RG.A sh (OI.T strides (offset - blockOff) vec'))
+ | otherwise = RS.fromVector sh (f (RS.toVector arr))