aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMikolaj Konarski <mikolaj.konarski@funktory.com>2025-12-01 01:44:05 +0100
committerMikolaj Konarski <mikolaj.konarski@funktory.com>2025-12-01 16:06:57 +0100
commit9560d0f26420409afd2230fb7e5e111eafcced06 (patch)
tree5f64f9adfc521143bdf86be3128364aee4679c55
parenta06c6416bab1639e5c3bd99b3c10de4dcf6c32f9 (diff)
Expose the unfolding of the indexing operations
-rw-r--r--src/Data/Array/Nested/Mixed.hs2
-rw-r--r--src/Data/Array/Nested/Ranked.hs2
-rw-r--r--src/Data/Array/Nested/Shaped.hs2
-rw-r--r--src/Data/Array/XArray.hs10
-rw-r--r--src/Data/Vector/Generic/Checked.hs11
5 files changed, 22 insertions, 5 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs
index 6d4ffd6..e3aa7a1 100644
--- a/src/Data/Array/Nested/Mixed.hs
+++ b/src/Data/Array/Nested/Mixed.hs
@@ -399,7 +399,9 @@ class Elt a => KnownElt a where
-- Arrays of scalars are basically just arrays of scalars.
instance Storable a => Elt (Primitive a) where
mshape (M_Primitive sh _) = sh
+ {-# INLINEABLE mindex #-}
mindex (M_Primitive _ a) i = Primitive (X.index a i)
+ {-# INLINEABLE mindexPartial #-}
mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx i sh) (X.indexPartial a i)
mscalar (Primitive x) = M_Primitive ZSX (X.scalar x)
mfromListOuterSN sn l@(arr1 :| _) =
diff --git a/src/Data/Array/Nested/Ranked.hs b/src/Data/Array/Nested/Ranked.hs
index 2fbfdd8..bf35cc4 100644
--- a/src/Data/Array/Nested/Ranked.hs
+++ b/src/Data/Array/Nested/Ranked.hs
@@ -49,9 +49,11 @@ remptyArray = mtoRanked (memptyArray ZSX)
rsize :: Elt a => Ranked n a -> Int
rsize = shrSize . rshape
+{-# INLINEABLE rindex #-}
rindex :: Elt a => Ranked n a -> IIxR n -> a
rindex (Ranked arr) idx = mindex arr (ixxFromIxR idx)
+{-# INLINEABLE rindexPartial #-}
rindexPartial :: forall n m a. Elt a => Ranked (n + m) a -> IIxR n -> Ranked m a
rindexPartial (Ranked arr) idx =
Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing)
diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs
index 8957549..82dfc91 100644
--- a/src/Data/Array/Nested/Shaped.hs
+++ b/src/Data/Array/Nested/Shaped.hs
@@ -52,6 +52,7 @@ srank = shsRank . sshape
ssize :: Elt a => Shaped sh a -> Int
ssize = shsSize . sshape
+{-# INLINEABLE sindex #-}
sindex :: Elt a => Shaped sh a -> IIxS sh -> a
sindex (Shaped arr) idx = mindex arr (ixxFromIxS idx)
@@ -59,6 +60,7 @@ shsTakeIx :: Proxy sh' -> ShS (sh ++ sh') -> IIxS sh -> ShS sh
shsTakeIx _ _ ZIS = ZSS
shsTakeIx p sh (_ :.$ idx) = case sh of n :$$ sh' -> n :$$ shsTakeIx p sh' idx
+{-# INLINEABLE sindexPartial #-}
sindexPartial :: forall sh1 sh2 a. Elt a => Shaped (sh1 ++ sh2) a -> IIxS sh1 -> Shaped sh2 a
sindexPartial sarr@(Shaped arr) idx =
Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2)
diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs
index 0f87168..ee83654 100644
--- a/src/Data/Array/XArray.hs
+++ b/src/Data/Array/XArray.hs
@@ -114,10 +114,20 @@ generate sh f = fromVector sh $ VS.generate (shxSize sh) (f . ixxFromLinear sh)
-- XArray . S.fromVector (shxShapeL sh)
-- <$> VS.generateM (shxSize sh) (f . ixxFromLinear sh)
+{-# INLINEABLE indexPartial #-}
indexPartial :: Storable a => XArray (sh ++ sh') a -> IIxX sh -> XArray sh' a
indexPartial (XArray arr) ZIX = XArray arr
indexPartial (XArray arr) (i :.% idx) = indexPartial (XArray (S.index arr i)) idx
+{- Strangely, this increases allocation and there's no noticeable speedup:
+indexPartial (XArray (ORS.A (ORG.A sh t))) ix =
+ let linear = OI.offset t + sum (zipWith (*) (ixxToList ix) (OI.strides t))
+ len = ixxLength ix
+ in XArray (ORS.A (ORG.A (drop len sh)
+ OI.T{ OI.strides = drop len (OI.strides t)
+ , OI.offset = linear
+ , OI.values = OI.values t })) -}
+{-# INLINEABLE index #-}
index :: forall sh a. Storable a => XArray sh a -> IIxX sh -> a
index (XArray (ORS.A (ORG.A _ t))) i =
OI.values t VS.! (OI.offset t + sum (zipWith (*) (toList i) (OI.strides t)))
diff --git a/src/Data/Vector/Generic/Checked.hs b/src/Data/Vector/Generic/Checked.hs
index d173bbf..d8aaaae 100644
--- a/src/Data/Vector/Generic/Checked.hs
+++ b/src/Data/Vector/Generic/Checked.hs
@@ -1,13 +1,14 @@
{-# LANGUAGE CPP #-}
+{-# LANGUAGE ImportQualifiedPost #-}
module Data.Vector.Generic.Checked (
fromListNChecked,
) where
-import qualified Data.Stream.Monadic as Stream
-import qualified Data.Vector.Fusion.Bundle.Monadic as VBM
-import qualified Data.Vector.Fusion.Bundle.Size as VBS
-import qualified Data.Vector.Fusion.Util as VFU
-import qualified Data.Vector.Generic as VG
+import Data.Stream.Monadic qualified as Stream
+import Data.Vector.Fusion.Bundle.Monadic qualified as VBM
+import Data.Vector.Fusion.Bundle.Size qualified as VBS
+import Data.Vector.Fusion.Util qualified as VFU
+import Data.Vector.Generic qualified as VG
-- for INLINE_FUSED and INLINE_INNER
#include "vector.h"