diff options
| author | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2025-12-09 10:34:03 +0100 |
|---|---|---|
| committer | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2025-12-09 10:34:03 +0100 |
| commit | ab020a0ece9383f04412964b9fc09d17874d3383 (patch) | |
| tree | fd4593aa2eae379eb02c8dfba5a8481b92914fdb /src/Data/Array | |
| parent | 3594dd9855efbcbfd8c1e62037e8c8a7ece93411 (diff) | |
Generalize ix?ToLinear and speed it up a bit
Diffstat (limited to 'src/Data/Array')
| -rw-r--r-- | src/Data/Array/Nested/Mixed/Shape.hs | 15 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Ranked/Shape.hs | 15 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Shaped/Shape.hs | 12 |
3 files changed, 32 insertions, 10 deletions
diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs index 145ea5f..11ef757 100644 --- a/src/Data/Array/Nested/Mixed/Shape.hs +++ b/src/Data/Array/Nested/Mixed/Shape.hs @@ -284,16 +284,15 @@ ixxZipWith :: (i -> j -> k) -> IxX sh i -> IxX sh j -> IxX sh k ixxZipWith _ ZIX ZIX = ZIX ixxZipWith f (i :.% is) (j :.% js) = f i j :.% ixxZipWith f is js +-- | Given a multidimensional index, get the corresponding linear +-- index into the buffer. {-# INLINEABLE ixxToLinear #-} -ixxToLinear :: IShX sh -> IIxX sh -> Int -ixxToLinear = \sh i -> fst (go sh i) +ixxToLinear :: Num i => IShX sh -> IxX sh i -> i +ixxToLinear = \sh i -> go sh i 0 where - -- returns (index in subarray, size of subarray) - go :: IShX sh -> IIxX sh -> (Int, Int) - go ZSX ZIX = (0, 1) - go (n :$% sh) (i :.% ix) = - let (lidx, sz) = go sh ix - in (sz * i + lidx, fromSMayNat' n * sz) + go :: Num i => IShX sh -> IxX sh i -> i -> i + go ZSX ZIX a = a + go (n :$% sh) (i :.% ix) a = go sh ix (fromIntegral (fromSMayNat' n) * a + i) -- * Mixed shapes diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs index 6d61bd5..b6bee2e 100644 --- a/src/Data/Array/Nested/Ranked/Shape.hs +++ b/src/Data/Array/Nested/Ranked/Shape.hs @@ -36,7 +36,7 @@ import Data.Foldable qualified as Foldable import Data.Kind (Type) import Data.Proxy import Data.Type.Equality -import GHC.Exts (Int(..), Int#, quotRemInt#, build) +import GHC.Exts (Int(..), Int#, build, quotRemInt#) import GHC.Generics (Generic) import GHC.IsList (IsList) import GHC.IsList qualified as IsList @@ -291,6 +291,19 @@ ixrZipWith f (IxR l1) (IxR l2) = IxR $ listrZipWith f l1 l2 ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i ixrPermutePrefix = coerce (listrPermutePrefix @i) +-- | Given a multidimensional index, get the corresponding linear +-- index into the buffer. +{-# INLINEABLE ixrToLinear #-} +ixrToLinear :: Num i => IShR m -> IxR m i -> i +ixrToLinear = \sh i -> go sh i 0 + where + -- Additional argument: index, in the @m - m1@ dimensional array so far, + -- of the @m - m1 + n@ dimensional tensor pointed to by the current + -- @m - m1@ dimensional index prefix. + go :: Num i => IShR m1 -> IxR m1 i -> i -> i + go ZSR ZIR a = a + go (n :$: sh) (i :.: ix) a = go sh ix (fromIntegral n * a + i) + -- * Ranked shapes diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs index 0d90e91..f616946 100644 --- a/src/Data/Array/Nested/Shaped/Shape.hs +++ b/src/Data/Array/Nested/Shaped/Shape.hs @@ -38,7 +38,7 @@ import Data.Kind (Constraint, Type) import Data.Monoid (Sum(..)) import Data.Proxy import Data.Type.Equality -import GHC.Exts (Int(..), Int#, quotRemInt#, withDict, build) +import GHC.Exts (Int(..), Int#, build, quotRemInt#, withDict) import GHC.Generics (Generic) import GHC.IsList (IsList) import GHC.IsList qualified as IsList @@ -301,6 +301,16 @@ ixsZipWith f (i :.$ is) (j :.$ js) = f i j :.$ ixsZipWith f is js ixsPermutePrefix :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i ixsPermutePrefix = coerce (listsPermutePrefix @(Const i)) +-- | Given a multidimensional index, get the corresponding linear +-- index into the buffer. +{-# INLINEABLE ixsToLinear #-} +ixsToLinear :: Num i => ShS sh -> IxS sh i -> i +ixsToLinear = \sh i -> go sh i 0 + where + go :: Num i => ShS sh -> IxS sh i -> i -> i + go ZSS ZIS a = a + go (n :$$ sh) (i :.$ ix) a = go sh ix (fromIntegral (fromSNat' n) * a + i) + -- * Shaped shapes |
