aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Ranked/Shape.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Ranked/Shape.hs')
-rw-r--r--src/Data/Array/Nested/Ranked/Shape.hs15
1 files changed, 14 insertions, 1 deletions
diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs
index 6d61bd5..b6bee2e 100644
--- a/src/Data/Array/Nested/Ranked/Shape.hs
+++ b/src/Data/Array/Nested/Ranked/Shape.hs
@@ -36,7 +36,7 @@ import Data.Foldable qualified as Foldable
import Data.Kind (Type)
import Data.Proxy
import Data.Type.Equality
-import GHC.Exts (Int(..), Int#, quotRemInt#, build)
+import GHC.Exts (Int(..), Int#, build, quotRemInt#)
import GHC.Generics (Generic)
import GHC.IsList (IsList)
import GHC.IsList qualified as IsList
@@ -291,6 +291,19 @@ ixrZipWith f (IxR l1) (IxR l2) = IxR $ listrZipWith f l1 l2
ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i
ixrPermutePrefix = coerce (listrPermutePrefix @i)
+-- | Given a multidimensional index, get the corresponding linear
+-- index into the buffer.
+{-# INLINEABLE ixrToLinear #-}
+ixrToLinear :: Num i => IShR m -> IxR m i -> i
+ixrToLinear = \sh i -> go sh i 0
+ where
+ -- Additional argument: index, in the @m - m1@ dimensional array so far,
+ -- of the @m - m1 + n@ dimensional tensor pointed to by the current
+ -- @m - m1@ dimensional index prefix.
+ go :: Num i => IShR m1 -> IxR m1 i -> i -> i
+ go ZSR ZIR a = a
+ go (n :$: sh) (i :.: ix) a = go sh ix (fromIntegral n * a + i)
+
-- * Ranked shapes