diff options
| author | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2025-12-02 14:19:04 +0100 |
|---|---|---|
| committer | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2025-12-02 14:19:04 +0100 |
| commit | a7b64fe342524e82194d73af852b5f2f1bc5bab3 (patch) | |
| tree | 73fb2b9484d33e7400e92a67388227a75adf56c9 /src/Data/Array/Nested/Mixed.hs | |
| parent | 9f47aa6a2bcd772388a5d5150ca7254e4eb95bc2 (diff) | |
Generalize also mgenerate to potentially avoid @fmap fromIntegral@mgenerate-integral
Diffstat (limited to 'src/Data/Array/Nested/Mixed.hs')
| -rw-r--r-- | src/Data/Array/Nested/Mixed.hs | 21 |
1 files changed, 11 insertions, 10 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index d658ed3..cad5c4e 100644 --- a/src/Data/Array/Nested/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -369,11 +369,11 @@ class Elt a where -- | Given the shape of this array, an index and a value, write the value at -- that index in the vectors. - mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s () + mvecsWrite :: Integral i => IShX sh -> IxX sh i -> a -> MixedVecs s sh a -> ST s () -- | Given the shape of this array, an index and a value, write the value at -- that index in the vectors. - mvecsWritePartial :: IShX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () + mvecsWritePartial :: Integral i => IShX (sh ++ sh') -> IxX sh i -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () -- | Given the shape of this array, finalise the vectors into 'XArray's. mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a) @@ -468,8 +468,8 @@ instance Storable a => Elt (Primitive a) where -- TODO: this use of toVector is suboptimal mvecsWritePartial - :: forall sh' sh s. - IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () + :: forall sh' sh s i. Integral i + => IShX (sh ++ sh') -> IxX sh i -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () mvecsWritePartial sh i (M_Primitive sh' arr) (MV_Primitive v) = do let arrsh = X.shape (ssxFromShX sh') arr offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh)) @@ -678,8 +678,8 @@ instance Elt a => Elt (Mixed sh' a) where mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs - mvecsWritePartial :: forall sh1 sh2 s. - IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) + mvecsWritePartial :: forall sh1 sh2 s i. Integral i + => IShX (sh1 ++ sh2) -> IxX sh1 i -> Mixed sh2 (Mixed sh' a) -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a) -> ST s () mvecsWritePartial sh12 idx (M_Nest _ arr) (MV_Nest sh' vecs) @@ -728,8 +728,10 @@ msize = shxSize . mshape -- the entire hierarchy (after distributing out tuples) must be a rectangular -- array. The type of 'mgenerate' allows this requirement to be broken very -- easily, hence the runtime check. -mgenerate :: forall sh a. KnownElt a => IShX sh -> (IIxX sh -> a) -> Mixed sh a -mgenerate sh f = case shxEnum sh of +{-# INLINEABLE mgenerate #-} +mgenerate :: forall sh a i. (KnownElt a, Integral i) + => IShX sh -> (IxX sh i -> a) -> Mixed sh a +mgenerate sh f = case shxEnum' sh of [] -> memptyArrayUnsafe sh firstidx : restidxs -> let firstelem = f (ixxZero' sh) @@ -751,8 +753,7 @@ mgenerate sh f = case shxEnum sh of -- | An optimized special case of `mgenerate', where the function results -- are of a primitive type and so there's not need to verify the shapes --- of them all are equal. This is also generalized to aribitrary @Num@ index --- type compared to @mgenerate@. +-- of them all are equal. {-# INLINE mgeneratePrim #-} mgeneratePrim :: forall sh a i. (PrimElt a, Num i) => IShX sh -> (IxX sh i -> a) -> Mixed sh a |
