aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array
diff options
context:
space:
mode:
authorMikolaj Konarski <mikolaj.konarski@funktory.com>2025-12-09 10:34:03 +0100
committerMikolaj Konarski <mikolaj.konarski@funktory.com>2025-12-09 10:34:03 +0100
commitab020a0ece9383f04412964b9fc09d17874d3383 (patch)
treefd4593aa2eae379eb02c8dfba5a8481b92914fdb /src/Data/Array
parent3594dd9855efbcbfd8c1e62037e8c8a7ece93411 (diff)
Generalize ix?ToLinear and speed it up a bit
Diffstat (limited to 'src/Data/Array')
-rw-r--r--src/Data/Array/Nested/Mixed/Shape.hs15
-rw-r--r--src/Data/Array/Nested/Ranked/Shape.hs15
-rw-r--r--src/Data/Array/Nested/Shaped/Shape.hs12
3 files changed, 32 insertions, 10 deletions
diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs
index 145ea5f..11ef757 100644
--- a/src/Data/Array/Nested/Mixed/Shape.hs
+++ b/src/Data/Array/Nested/Mixed/Shape.hs
@@ -284,16 +284,15 @@ ixxZipWith :: (i -> j -> k) -> IxX sh i -> IxX sh j -> IxX sh k
ixxZipWith _ ZIX ZIX = ZIX
ixxZipWith f (i :.% is) (j :.% js) = f i j :.% ixxZipWith f is js
+-- | Given a multidimensional index, get the corresponding linear
+-- index into the buffer.
{-# INLINEABLE ixxToLinear #-}
-ixxToLinear :: IShX sh -> IIxX sh -> Int
-ixxToLinear = \sh i -> fst (go sh i)
+ixxToLinear :: Num i => IShX sh -> IxX sh i -> i
+ixxToLinear = \sh i -> go sh i 0
where
- -- returns (index in subarray, size of subarray)
- go :: IShX sh -> IIxX sh -> (Int, Int)
- go ZSX ZIX = (0, 1)
- go (n :$% sh) (i :.% ix) =
- let (lidx, sz) = go sh ix
- in (sz * i + lidx, fromSMayNat' n * sz)
+ go :: Num i => IShX sh -> IxX sh i -> i -> i
+ go ZSX ZIX a = a
+ go (n :$% sh) (i :.% ix) a = go sh ix (fromIntegral (fromSMayNat' n) * a + i)
-- * Mixed shapes
diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs
index 6d61bd5..b6bee2e 100644
--- a/src/Data/Array/Nested/Ranked/Shape.hs
+++ b/src/Data/Array/Nested/Ranked/Shape.hs
@@ -36,7 +36,7 @@ import Data.Foldable qualified as Foldable
import Data.Kind (Type)
import Data.Proxy
import Data.Type.Equality
-import GHC.Exts (Int(..), Int#, quotRemInt#, build)
+import GHC.Exts (Int(..), Int#, build, quotRemInt#)
import GHC.Generics (Generic)
import GHC.IsList (IsList)
import GHC.IsList qualified as IsList
@@ -291,6 +291,19 @@ ixrZipWith f (IxR l1) (IxR l2) = IxR $ listrZipWith f l1 l2
ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i
ixrPermutePrefix = coerce (listrPermutePrefix @i)
+-- | Given a multidimensional index, get the corresponding linear
+-- index into the buffer.
+{-# INLINEABLE ixrToLinear #-}
+ixrToLinear :: Num i => IShR m -> IxR m i -> i
+ixrToLinear = \sh i -> go sh i 0
+ where
+ -- Additional argument: index, in the @m - m1@ dimensional array so far,
+ -- of the @m - m1 + n@ dimensional tensor pointed to by the current
+ -- @m - m1@ dimensional index prefix.
+ go :: Num i => IShR m1 -> IxR m1 i -> i -> i
+ go ZSR ZIR a = a
+ go (n :$: sh) (i :.: ix) a = go sh ix (fromIntegral n * a + i)
+
-- * Ranked shapes
diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs
index 0d90e91..f616946 100644
--- a/src/Data/Array/Nested/Shaped/Shape.hs
+++ b/src/Data/Array/Nested/Shaped/Shape.hs
@@ -38,7 +38,7 @@ import Data.Kind (Constraint, Type)
import Data.Monoid (Sum(..))
import Data.Proxy
import Data.Type.Equality
-import GHC.Exts (Int(..), Int#, quotRemInt#, withDict, build)
+import GHC.Exts (Int(..), Int#, build, quotRemInt#, withDict)
import GHC.Generics (Generic)
import GHC.IsList (IsList)
import GHC.IsList qualified as IsList
@@ -301,6 +301,16 @@ ixsZipWith f (i :.$ is) (j :.$ js) = f i j :.$ ixsZipWith f is js
ixsPermutePrefix :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i
ixsPermutePrefix = coerce (listsPermutePrefix @(Const i))
+-- | Given a multidimensional index, get the corresponding linear
+-- index into the buffer.
+{-# INLINEABLE ixsToLinear #-}
+ixsToLinear :: Num i => ShS sh -> IxS sh i -> i
+ixsToLinear = \sh i -> go sh i 0
+ where
+ go :: Num i => ShS sh -> IxS sh i -> i -> i
+ go ZSS ZIS a = a
+ go (n :$$ sh) (i :.$ ix) a = go sh ix (fromIntegral (fromSNat' n) * a + i)
+
-- * Shaped shapes