diff options
Diffstat (limited to 'src/Data/Array/Nested')
| -rw-r--r-- | src/Data/Array/Nested/Mixed.hs | 21 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Mixed/Shape.hs | 13 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Ranked.hs | 4 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Ranked/Base.hs | 6 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Shaped.hs | 4 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Shaped/Base.hs | 6 |
6 files changed, 31 insertions, 23 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index d658ed3..cad5c4e 100644 --- a/src/Data/Array/Nested/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -369,11 +369,11 @@ class Elt a where -- | Given the shape of this array, an index and a value, write the value at -- that index in the vectors. - mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s () + mvecsWrite :: Integral i => IShX sh -> IxX sh i -> a -> MixedVecs s sh a -> ST s () -- | Given the shape of this array, an index and a value, write the value at -- that index in the vectors. - mvecsWritePartial :: IShX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () + mvecsWritePartial :: Integral i => IShX (sh ++ sh') -> IxX sh i -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () -- | Given the shape of this array, finalise the vectors into 'XArray's. mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a) @@ -468,8 +468,8 @@ instance Storable a => Elt (Primitive a) where -- TODO: this use of toVector is suboptimal mvecsWritePartial - :: forall sh' sh s. - IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () + :: forall sh' sh s i. Integral i + => IShX (sh ++ sh') -> IxX sh i -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () mvecsWritePartial sh i (M_Primitive sh' arr) (MV_Primitive v) = do let arrsh = X.shape (ssxFromShX sh') arr offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh)) @@ -678,8 +678,8 @@ instance Elt a => Elt (Mixed sh' a) where mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs - mvecsWritePartial :: forall sh1 sh2 s. - IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) + mvecsWritePartial :: forall sh1 sh2 s i. Integral i + => IShX (sh1 ++ sh2) -> IxX sh1 i -> Mixed sh2 (Mixed sh' a) -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a) -> ST s () mvecsWritePartial sh12 idx (M_Nest _ arr) (MV_Nest sh' vecs) @@ -728,8 +728,10 @@ msize = shxSize . mshape -- the entire hierarchy (after distributing out tuples) must be a rectangular -- array. The type of 'mgenerate' allows this requirement to be broken very -- easily, hence the runtime check. -mgenerate :: forall sh a. KnownElt a => IShX sh -> (IIxX sh -> a) -> Mixed sh a -mgenerate sh f = case shxEnum sh of +{-# INLINEABLE mgenerate #-} +mgenerate :: forall sh a i. (KnownElt a, Integral i) + => IShX sh -> (IxX sh i -> a) -> Mixed sh a +mgenerate sh f = case shxEnum' sh of [] -> memptyArrayUnsafe sh firstidx : restidxs -> let firstelem = f (ixxZero' sh) @@ -751,8 +753,7 @@ mgenerate sh f = case shxEnum sh of -- | An optimized special case of `mgenerate', where the function results -- are of a primitive type and so there's not need to verify the shapes --- of them all are equal. This is also generalized to aribitrary @Num@ index --- type compared to @mgenerate@. +-- of them all are equal. {-# INLINE mgeneratePrim #-} mgeneratePrim :: forall sh a i. (PrimElt a, Num i) => IShX sh -> (IxX sh i -> a) -> Mixed sh a diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs index ed03310..f755703 100644 --- a/src/Data/Array/Nested/Mixed/Shape.hs +++ b/src/Data/Array/Nested/Mixed/Shape.hs @@ -231,11 +231,13 @@ ixxLength (IxX l) = listxLength l ixxRank :: IxX sh i -> SNat (Rank sh) ixxRank (IxX l) = listxRank l -ixxZero :: StaticShX sh -> IIxX sh +{-# INLINEABLE ixxZero #-} +ixxZero :: Num i => StaticShX sh -> IxX sh i ixxZero ZKX = ZIX ixxZero (_ :!% ssh) = 0 :.% ixxZero ssh -ixxZero' :: IShX sh -> IIxX sh +{-# INLINEABLE ixxZero' #-} +ixxZero' :: Num i => IShX sh -> IxX sh i ixxZero' ZSX = ZIX ixxZero' (_ :$% sh) = 0 :.% ixxZero' sh @@ -301,15 +303,16 @@ ixxFromLinear = \sh -> -- give this function arity 1 so that suffixes is shared outrange sh i = error $ "ixxFromLinear: out of range (" ++ show i ++ " in array of shape " ++ show sh ++ ")" -ixxToLinear :: IShX sh -> IIxX sh -> Int +{-# INLINEABLE ixxToLinear #-} +ixxToLinear :: Integral i => IShX sh -> IxX sh i -> Int ixxToLinear = \sh i -> fst (go sh i) where -- returns (index in subarray, size of subarray) - go :: IShX sh -> IIxX sh -> (Int, Int) + go :: Integral i => IShX sh -> IxX sh i -> (Int, Int) go ZSX ZIX = (0, 1) go (n :$% sh) (i :.% ix) = let (lidx, sz) = go sh ix - in (sz * i + lidx, fromSMayNat' n * sz) + in (sz * fromIntegral i + lidx, fromSMayNat' n * sz) -- * Mixed shapes diff --git a/src/Data/Array/Nested/Ranked.hs b/src/Data/Array/Nested/Ranked.hs index 37925fb..22ca117 100644 --- a/src/Data/Array/Nested/Ranked.hs +++ b/src/Data/Array/Nested/Ranked.hs @@ -62,7 +62,9 @@ rindexPartial (Ranked arr) idx = -- | __WARNING__: All values returned from the function must have equal shape. -- See the documentation of 'mgenerate' for more details. -rgenerate :: forall n a. KnownElt a => IShR n -> (IIxR n -> a) -> Ranked n a +{-# INLINEABLE rgenerate #-} +rgenerate :: forall n a i. (KnownElt a, Integral i) + => IShR n -> (IxR n i -> a) -> Ranked n a rgenerate sh f | sn@SNat <- shrRank sh , Dict <- lemKnownReplicate sn diff --git a/src/Data/Array/Nested/Ranked/Base.hs b/src/Data/Array/Nested/Ranked/Base.hs index 11a8ffb..04f2ea2 100644 --- a/src/Data/Array/Nested/Ranked/Base.hs +++ b/src/Data/Array/Nested/Ranked/Base.hs @@ -149,14 +149,14 @@ instance Elt a => Elt (Ranked n a) where marrayStrides (M_Ranked arr) = marrayStrides arr - mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s () + mvecsWrite :: forall sh s i. Integral i => IShX sh -> IxX sh i -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s () mvecsWrite sh idx (Ranked arr) vecs = mvecsWrite sh idx arr (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) vecs) - mvecsWritePartial :: forall sh sh' s. - IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a) + mvecsWritePartial :: forall sh sh' s i. Integral i + => IShX (sh ++ sh') -> IxX sh i -> Mixed sh' (Ranked n a) -> MixedVecs s (sh ++ sh') (Ranked n a) -> ST s () mvecsWritePartial sh idx arr vecs = diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs index 075549d..e842c18 100644 --- a/src/Data/Array/Nested/Shaped.hs +++ b/src/Data/Array/Nested/Shaped.hs @@ -69,7 +69,9 @@ sindexPartial sarr@(Shaped arr) idx = -- | __WARNING__: All values returned from the function must have equal shape. -- See the documentation of 'mgenerate' for more details. -sgenerate :: forall sh a. KnownElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a +{-# INLINEABLE sgenerate #-} +sgenerate :: forall sh a i. (KnownElt a, Integral i) + => ShS sh -> (IxS sh i -> a) -> Shaped sh a sgenerate sh f = Shaped (mgenerate (shxFromShS sh) (f . ixsFromIxX sh)) {-# INLINE sgeneratePrim #-} diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs index 98f1241..342a296 100644 --- a/src/Data/Array/Nested/Shaped/Base.hs +++ b/src/Data/Array/Nested/Shaped/Base.hs @@ -142,14 +142,14 @@ instance Elt a => Elt (Shaped sh a) where marrayStrides (M_Shaped arr) = marrayStrides arr - mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s () + mvecsWrite :: forall sh' s i. Integral i => IShX sh' -> IxX sh' i -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s () mvecsWrite sh idx (Shaped arr) vecs = mvecsWrite sh idx arr (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a)) vecs) - mvecsWritePartial :: forall sh1 sh2 s. - IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) + mvecsWritePartial :: forall sh1 sh2 s i. Integral i + => IShX (sh1 ++ sh2) -> IxX sh1 i -> Mixed sh2 (Shaped sh a) -> MixedVecs s (sh1 ++ sh2) (Shaped sh a) -> ST s () mvecsWritePartial sh idx arr vecs = |
