diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/Data/Array/Nested/Mixed.hs | 10 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Mixed/Shape.hs | 18 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Ranked/Base.hs | 8 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Ranked/Shape.hs | 15 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Shaped/Base.hs | 8 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Shaped/Shape.hs | 12 |
6 files changed, 58 insertions, 13 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index fc1c108..6b96a15 100644 --- a/src/Data/Array/Nested/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -378,6 +378,9 @@ class Elt a where -- | Given the shape of this array, finalise the vectors into 'XArray's. mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a) + -- | Given the shape of this array, finalise the vectors into 'XArray's. + mvecsUnsafeFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a) + -- | Element types for which we have evidence of the (static part of the) shape -- in a type class constraint. Compare the instance contexts of the instances -- of this class with those of 'Elt': some instances have an additional @@ -479,6 +482,7 @@ instance Storable a => Elt (Primitive a) where VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr) mvecsFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.freeze v + mvecsUnsafeFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.unsafeFreeze v -- [PRIMITIVE ELEMENT TYPES LIST] deriving via Primitive Bool instance Elt Bool @@ -553,6 +557,7 @@ instance (Elt a, Elt b) => Elt (a, b) where mvecsWritePartialLinear proxy i x a mvecsWritePartialLinear proxy i y b mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b + mvecsUnsafeFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsUnsafeFreeze sh a <*> mvecsUnsafeFreeze sh b instance (KnownElt a, KnownElt b) => KnownElt (a, b) where memptyArrayUnsafe sh = M_Tup2 (memptyArrayUnsafe sh) (memptyArrayUnsafe sh) @@ -694,6 +699,7 @@ instance Elt a => Elt (Mixed sh' a) where = mvecsWritePartialLinear proxy idx arr vecs mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (shxAppend sh sh') vecs + mvecsUnsafeFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsUnsafeFreeze (shxAppend sh sh') vecs instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where memptyArrayUnsafe sh = M_Nest sh (memptyArrayUnsafe (shxAppend sh (shxCompleteZeros (knownShX @sh')))) @@ -706,7 +712,7 @@ instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where mvecsReplicate sh example = do vecs <- mvecsUnsafeNew sh example - forM_ (shxEnum sh) $ \idx -> mvecsWrite sh idx example vecs + forM_ [0 .. shxSize sh - 1] $ \idx -> mvecsWriteLinear idx example vecs -- this is a slow case, but the alternative, mvecsUnsafeNew with manual -- writing in a loop, leads to every case being as slow return vecs @@ -772,7 +778,7 @@ mgenerate sh f = case shxEnum sh of when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $ error "Data.Array.Nested mgenerate: generated values do not have equal shapes" mvecsWrite sh idx val vecs - mvecsFreeze sh vecs + mvecsUnsafeFreeze sh vecs -- | An optimized special case of 'mgenerate', where the function results -- are of a primitive type and so there's not need to check that all shapes diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs index c999853..11ef757 100644 --- a/src/Data/Array/Nested/Mixed/Shape.hs +++ b/src/Data/Array/Nested/Mixed/Shape.hs @@ -36,7 +36,7 @@ import Data.Functor.Product import Data.Kind (Constraint, Type) import Data.Monoid (Sum(..)) import Data.Type.Equality -import GHC.Exts (Int(..), Int#, quotRemInt#, withDict, build) +import GHC.Exts (Int(..), Int#, build, quotRemInt#, withDict) import GHC.Generics (Generic) import GHC.IsList (IsList) import GHC.IsList qualified as IsList @@ -284,15 +284,15 @@ ixxZipWith :: (i -> j -> k) -> IxX sh i -> IxX sh j -> IxX sh k ixxZipWith _ ZIX ZIX = ZIX ixxZipWith f (i :.% is) (j :.% js) = f i j :.% ixxZipWith f is js -ixxToLinear :: IShX sh -> IIxX sh -> Int -ixxToLinear = \sh i -> fst (go sh i) +-- | Given a multidimensional index, get the corresponding linear +-- index into the buffer. +{-# INLINEABLE ixxToLinear #-} +ixxToLinear :: Num i => IShX sh -> IxX sh i -> i +ixxToLinear = \sh i -> go sh i 0 where - -- returns (index in subarray, size of subarray) - go :: IShX sh -> IIxX sh -> (Int, Int) - go ZSX ZIX = (0, 1) - go (n :$% sh) (i :.% ix) = - let (lidx, sz) = go sh ix - in (sz * i + lidx, fromSMayNat' n * sz) + go :: Num i => IShX sh -> IxX sh i -> i -> i + go ZSX ZIX a = a + go (n :$% sh) (i :.% ix) a = go sh ix (fromIntegral (fromSMayNat' n) * a + i) -- * Mixed shapes diff --git a/src/Data/Array/Nested/Ranked/Base.hs b/src/Data/Array/Nested/Ranked/Base.hs index ed194a8..97a5f6f 100644 --- a/src/Data/Array/Nested/Ranked/Base.hs +++ b/src/Data/Array/Nested/Ranked/Base.hs @@ -177,6 +177,14 @@ instance Elt a => Elt (Ranked n a) where (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) vecs) + mvecsUnsafeFreeze :: forall sh s. IShX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a)) + mvecsUnsafeFreeze sh vecs = + coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) + @(Mixed sh (Ranked n a)) + <$> mvecsUnsafeFreeze sh + (coerce @(MixedVecs s sh (Ranked n a)) + @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) + vecs) instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where memptyArrayUnsafe :: forall sh. IShX sh -> Mixed sh (Ranked n a) diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs index 6d61bd5..b6bee2e 100644 --- a/src/Data/Array/Nested/Ranked/Shape.hs +++ b/src/Data/Array/Nested/Ranked/Shape.hs @@ -36,7 +36,7 @@ import Data.Foldable qualified as Foldable import Data.Kind (Type) import Data.Proxy import Data.Type.Equality -import GHC.Exts (Int(..), Int#, quotRemInt#, build) +import GHC.Exts (Int(..), Int#, build, quotRemInt#) import GHC.Generics (Generic) import GHC.IsList (IsList) import GHC.IsList qualified as IsList @@ -291,6 +291,19 @@ ixrZipWith f (IxR l1) (IxR l2) = IxR $ listrZipWith f l1 l2 ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i ixrPermutePrefix = coerce (listrPermutePrefix @i) +-- | Given a multidimensional index, get the corresponding linear +-- index into the buffer. +{-# INLINEABLE ixrToLinear #-} +ixrToLinear :: Num i => IShR m -> IxR m i -> i +ixrToLinear = \sh i -> go sh i 0 + where + -- Additional argument: index, in the @m - m1@ dimensional array so far, + -- of the @m - m1 + n@ dimensional tensor pointed to by the current + -- @m - m1@ dimensional index prefix. + go :: Num i => IShR m1 -> IxR m1 i -> i -> i + go ZSR ZIR a = a + go (n :$: sh) (i :.: ix) a = go sh ix (fromIntegral n * a + i) + -- * Ranked shapes diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs index e5dd852..e2ec416 100644 --- a/src/Data/Array/Nested/Shaped/Base.hs +++ b/src/Data/Array/Nested/Shaped/Base.hs @@ -170,6 +170,14 @@ instance Elt a => Elt (Shaped sh a) where (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a)) vecs) + mvecsUnsafeFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a)) + mvecsUnsafeFreeze sh vecs = + coerce @(Mixed sh' (Mixed (MapJust sh) a)) + @(Mixed sh' (Shaped sh a)) + <$> mvecsUnsafeFreeze sh + (coerce @(MixedVecs s sh' (Shaped sh a)) + @(MixedVecs s sh' (Mixed (MapJust sh) a)) + vecs) instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where memptyArrayUnsafe :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a) diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs index 0d90e91..f616946 100644 --- a/src/Data/Array/Nested/Shaped/Shape.hs +++ b/src/Data/Array/Nested/Shaped/Shape.hs @@ -38,7 +38,7 @@ import Data.Kind (Constraint, Type) import Data.Monoid (Sum(..)) import Data.Proxy import Data.Type.Equality -import GHC.Exts (Int(..), Int#, quotRemInt#, withDict, build) +import GHC.Exts (Int(..), Int#, build, quotRemInt#, withDict) import GHC.Generics (Generic) import GHC.IsList (IsList) import GHC.IsList qualified as IsList @@ -301,6 +301,16 @@ ixsZipWith f (i :.$ is) (j :.$ js) = f i j :.$ ixsZipWith f is js ixsPermutePrefix :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i ixsPermutePrefix = coerce (listsPermutePrefix @(Const i)) +-- | Given a multidimensional index, get the corresponding linear +-- index into the buffer. +{-# INLINEABLE ixsToLinear #-} +ixsToLinear :: Num i => ShS sh -> IxS sh i -> i +ixsToLinear = \sh i -> go sh i 0 + where + go :: Num i => ShS sh -> IxS sh i -> i -> i + go ZSS ZIS a = a + go (n :$$ sh) (i :.$ ix) a = go sh ix (fromIntegral (fromSNat' n) * a + i) + -- * Shaped shapes |
