aboutsummaryrefslogtreecommitdiff
path: root/src/Data
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data')
-rw-r--r--src/Data/Array/Nested/Mixed.hs21
-rw-r--r--src/Data/Array/Nested/Mixed/Shape.hs13
-rw-r--r--src/Data/Array/Nested/Ranked.hs4
-rw-r--r--src/Data/Array/Nested/Ranked/Base.hs6
-rw-r--r--src/Data/Array/Nested/Shaped.hs4
-rw-r--r--src/Data/Array/Nested/Shaped/Base.hs6
6 files changed, 31 insertions, 23 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs
index d658ed3..cad5c4e 100644
--- a/src/Data/Array/Nested/Mixed.hs
+++ b/src/Data/Array/Nested/Mixed.hs
@@ -369,11 +369,11 @@ class Elt a where
-- | Given the shape of this array, an index and a value, write the value at
-- that index in the vectors.
- mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s ()
+ mvecsWrite :: Integral i => IShX sh -> IxX sh i -> a -> MixedVecs s sh a -> ST s ()
-- | Given the shape of this array, an index and a value, write the value at
-- that index in the vectors.
- mvecsWritePartial :: IShX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s ()
+ mvecsWritePartial :: Integral i => IShX (sh ++ sh') -> IxX sh i -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s ()
-- | Given the shape of this array, finalise the vectors into 'XArray's.
mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a)
@@ -468,8 +468,8 @@ instance Storable a => Elt (Primitive a) where
-- TODO: this use of toVector is suboptimal
mvecsWritePartial
- :: forall sh' sh s.
- IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s ()
+ :: forall sh' sh s i. Integral i
+ => IShX (sh ++ sh') -> IxX sh i -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s ()
mvecsWritePartial sh i (M_Primitive sh' arr) (MV_Primitive v) = do
let arrsh = X.shape (ssxFromShX sh') arr
offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh))
@@ -678,8 +678,8 @@ instance Elt a => Elt (Mixed sh' a) where
mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs
- mvecsWritePartial :: forall sh1 sh2 s.
- IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)
+ mvecsWritePartial :: forall sh1 sh2 s i. Integral i
+ => IShX (sh1 ++ sh2) -> IxX sh1 i -> Mixed sh2 (Mixed sh' a)
-> MixedVecs s (sh1 ++ sh2) (Mixed sh' a)
-> ST s ()
mvecsWritePartial sh12 idx (M_Nest _ arr) (MV_Nest sh' vecs)
@@ -728,8 +728,10 @@ msize = shxSize . mshape
-- the entire hierarchy (after distributing out tuples) must be a rectangular
-- array. The type of 'mgenerate' allows this requirement to be broken very
-- easily, hence the runtime check.
-mgenerate :: forall sh a. KnownElt a => IShX sh -> (IIxX sh -> a) -> Mixed sh a
-mgenerate sh f = case shxEnum sh of
+{-# INLINEABLE mgenerate #-}
+mgenerate :: forall sh a i. (KnownElt a, Integral i)
+ => IShX sh -> (IxX sh i -> a) -> Mixed sh a
+mgenerate sh f = case shxEnum' sh of
[] -> memptyArrayUnsafe sh
firstidx : restidxs ->
let firstelem = f (ixxZero' sh)
@@ -751,8 +753,7 @@ mgenerate sh f = case shxEnum sh of
-- | An optimized special case of `mgenerate', where the function results
-- are of a primitive type and so there's not need to verify the shapes
--- of them all are equal. This is also generalized to aribitrary @Num@ index
--- type compared to @mgenerate@.
+-- of them all are equal.
{-# INLINE mgeneratePrim #-}
mgeneratePrim :: forall sh a i. (PrimElt a, Num i)
=> IShX sh -> (IxX sh i -> a) -> Mixed sh a
diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs
index ed03310..f755703 100644
--- a/src/Data/Array/Nested/Mixed/Shape.hs
+++ b/src/Data/Array/Nested/Mixed/Shape.hs
@@ -231,11 +231,13 @@ ixxLength (IxX l) = listxLength l
ixxRank :: IxX sh i -> SNat (Rank sh)
ixxRank (IxX l) = listxRank l
-ixxZero :: StaticShX sh -> IIxX sh
+{-# INLINEABLE ixxZero #-}
+ixxZero :: Num i => StaticShX sh -> IxX sh i
ixxZero ZKX = ZIX
ixxZero (_ :!% ssh) = 0 :.% ixxZero ssh
-ixxZero' :: IShX sh -> IIxX sh
+{-# INLINEABLE ixxZero' #-}
+ixxZero' :: Num i => IShX sh -> IxX sh i
ixxZero' ZSX = ZIX
ixxZero' (_ :$% sh) = 0 :.% ixxZero' sh
@@ -301,15 +303,16 @@ ixxFromLinear = \sh -> -- give this function arity 1 so that suffixes is shared
outrange sh i = error $ "ixxFromLinear: out of range (" ++ show i ++
" in array of shape " ++ show sh ++ ")"
-ixxToLinear :: IShX sh -> IIxX sh -> Int
+{-# INLINEABLE ixxToLinear #-}
+ixxToLinear :: Integral i => IShX sh -> IxX sh i -> Int
ixxToLinear = \sh i -> fst (go sh i)
where
-- returns (index in subarray, size of subarray)
- go :: IShX sh -> IIxX sh -> (Int, Int)
+ go :: Integral i => IShX sh -> IxX sh i -> (Int, Int)
go ZSX ZIX = (0, 1)
go (n :$% sh) (i :.% ix) =
let (lidx, sz) = go sh ix
- in (sz * i + lidx, fromSMayNat' n * sz)
+ in (sz * fromIntegral i + lidx, fromSMayNat' n * sz)
-- * Mixed shapes
diff --git a/src/Data/Array/Nested/Ranked.hs b/src/Data/Array/Nested/Ranked.hs
index 37925fb..22ca117 100644
--- a/src/Data/Array/Nested/Ranked.hs
+++ b/src/Data/Array/Nested/Ranked.hs
@@ -62,7 +62,9 @@ rindexPartial (Ranked arr) idx =
-- | __WARNING__: All values returned from the function must have equal shape.
-- See the documentation of 'mgenerate' for more details.
-rgenerate :: forall n a. KnownElt a => IShR n -> (IIxR n -> a) -> Ranked n a
+{-# INLINEABLE rgenerate #-}
+rgenerate :: forall n a i. (KnownElt a, Integral i)
+ => IShR n -> (IxR n i -> a) -> Ranked n a
rgenerate sh f
| sn@SNat <- shrRank sh
, Dict <- lemKnownReplicate sn
diff --git a/src/Data/Array/Nested/Ranked/Base.hs b/src/Data/Array/Nested/Ranked/Base.hs
index 11a8ffb..04f2ea2 100644
--- a/src/Data/Array/Nested/Ranked/Base.hs
+++ b/src/Data/Array/Nested/Ranked/Base.hs
@@ -149,14 +149,14 @@ instance Elt a => Elt (Ranked n a) where
marrayStrides (M_Ranked arr) = marrayStrides arr
- mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s ()
+ mvecsWrite :: forall sh s i. Integral i => IShX sh -> IxX sh i -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s ()
mvecsWrite sh idx (Ranked arr) vecs =
mvecsWrite sh idx arr
(coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
vecs)
- mvecsWritePartial :: forall sh sh' s.
- IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a)
+ mvecsWritePartial :: forall sh sh' s i. Integral i
+ => IShX (sh ++ sh') -> IxX sh i -> Mixed sh' (Ranked n a)
-> MixedVecs s (sh ++ sh') (Ranked n a)
-> ST s ()
mvecsWritePartial sh idx arr vecs =
diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs
index 075549d..e842c18 100644
--- a/src/Data/Array/Nested/Shaped.hs
+++ b/src/Data/Array/Nested/Shaped.hs
@@ -69,7 +69,9 @@ sindexPartial sarr@(Shaped arr) idx =
-- | __WARNING__: All values returned from the function must have equal shape.
-- See the documentation of 'mgenerate' for more details.
-sgenerate :: forall sh a. KnownElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a
+{-# INLINEABLE sgenerate #-}
+sgenerate :: forall sh a i. (KnownElt a, Integral i)
+ => ShS sh -> (IxS sh i -> a) -> Shaped sh a
sgenerate sh f = Shaped (mgenerate (shxFromShS sh) (f . ixsFromIxX sh))
{-# INLINE sgeneratePrim #-}
diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs
index 98f1241..342a296 100644
--- a/src/Data/Array/Nested/Shaped/Base.hs
+++ b/src/Data/Array/Nested/Shaped/Base.hs
@@ -142,14 +142,14 @@ instance Elt a => Elt (Shaped sh a) where
marrayStrides (M_Shaped arr) = marrayStrides arr
- mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s ()
+ mvecsWrite :: forall sh' s i. Integral i => IShX sh' -> IxX sh' i -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s ()
mvecsWrite sh idx (Shaped arr) vecs =
mvecsWrite sh idx arr
(coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a))
vecs)
- mvecsWritePartial :: forall sh1 sh2 s.
- IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a)
+ mvecsWritePartial :: forall sh1 sh2 s i. Integral i
+ => IShX (sh1 ++ sh2) -> IxX sh1 i -> Mixed sh2 (Shaped sh a)
-> MixedVecs s (sh1 ++ sh2) (Shaped sh a)
-> ST s ()
mvecsWritePartial sh idx arr vecs =