diff options
Diffstat (limited to 'src/Data/Array/Nested/Ranked')
| -rw-r--r-- | src/Data/Array/Nested/Ranked/Base.hs | 8 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Ranked/Shape.hs | 15 |
2 files changed, 22 insertions, 1 deletions
diff --git a/src/Data/Array/Nested/Ranked/Base.hs b/src/Data/Array/Nested/Ranked/Base.hs index ed194a8..97a5f6f 100644 --- a/src/Data/Array/Nested/Ranked/Base.hs +++ b/src/Data/Array/Nested/Ranked/Base.hs @@ -177,6 +177,14 @@ instance Elt a => Elt (Ranked n a) where (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) vecs) + mvecsUnsafeFreeze :: forall sh s. IShX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a)) + mvecsUnsafeFreeze sh vecs = + coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) + @(Mixed sh (Ranked n a)) + <$> mvecsUnsafeFreeze sh + (coerce @(MixedVecs s sh (Ranked n a)) + @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) + vecs) instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where memptyArrayUnsafe :: forall sh. IShX sh -> Mixed sh (Ranked n a) diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs index 6d61bd5..b6bee2e 100644 --- a/src/Data/Array/Nested/Ranked/Shape.hs +++ b/src/Data/Array/Nested/Ranked/Shape.hs @@ -36,7 +36,7 @@ import Data.Foldable qualified as Foldable import Data.Kind (Type) import Data.Proxy import Data.Type.Equality -import GHC.Exts (Int(..), Int#, quotRemInt#, build) +import GHC.Exts (Int(..), Int#, build, quotRemInt#) import GHC.Generics (Generic) import GHC.IsList (IsList) import GHC.IsList qualified as IsList @@ -291,6 +291,19 @@ ixrZipWith f (IxR l1) (IxR l2) = IxR $ listrZipWith f l1 l2 ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i ixrPermutePrefix = coerce (listrPermutePrefix @i) +-- | Given a multidimensional index, get the corresponding linear +-- index into the buffer. +{-# INLINEABLE ixrToLinear #-} +ixrToLinear :: Num i => IShR m -> IxR m i -> i +ixrToLinear = \sh i -> go sh i 0 + where + -- Additional argument: index, in the @m - m1@ dimensional array so far, + -- of the @m - m1 + n@ dimensional tensor pointed to by the current + -- @m - m1@ dimensional index prefix. + go :: Num i => IShR m1 -> IxR m1 i -> i -> i + go ZSR ZIR a = a + go (n :$: sh) (i :.: ix) a = go sh ix (fromIntegral n * a + i) + -- * Ranked shapes |
