aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/Data/Array/Nested/Mixed.hs8
-rw-r--r--src/Data/Array/Nested/Mixed/Shape.hs8
-rw-r--r--src/Data/Array/Nested/Ranked.hs4
-rw-r--r--src/Data/Array/Nested/Shaped.hs4
4 files changed, 15 insertions, 9 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs
index 515e867..d658ed3 100644
--- a/src/Data/Array/Nested/Mixed.hs
+++ b/src/Data/Array/Nested/Mixed.hs
@@ -751,9 +751,11 @@ mgenerate sh f = case shxEnum sh of
-- | An optimized special case of `mgenerate', where the function results
-- are of a primitive type and so there's not need to verify the shapes
--- of them all are equal.
-mgeneratePrim :: forall sh a. PrimElt a
- => IShX sh -> (IIxX sh -> a) -> Mixed sh a
+-- of them all are equal. This is also generalized to aribitrary @Num@ index
+-- type compared to @mgenerate@.
+{-# INLINE mgeneratePrim #-}
+mgeneratePrim :: forall sh a i. (PrimElt a, Num i)
+ => IShX sh -> (IxX sh i -> a) -> Mixed sh a
mgeneratePrim sh f =
let g i = f (ixxFromLinear sh i)
in mfromVector sh $ VS.generate (shxSize sh) g
diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs
index 066ae8e..ed03310 100644
--- a/src/Data/Array/Nested/Mixed/Shape.hs
+++ b/src/Data/Array/Nested/Mixed/Shape.hs
@@ -275,7 +275,7 @@ ixxZipWith _ ZIX ZIX = ZIX
ixxZipWith f (i :.% is) (j :.% js) = f i j :.% ixxZipWith f is js
{-# INLINEABLE ixxFromLinear #-}
-ixxFromLinear :: IShX sh -> Int -> IIxX sh
+ixxFromLinear :: Num i => IShX sh -> Int -> IxX sh i
ixxFromLinear = \sh -> -- give this function arity 1 so that suffixes is shared when it's called many times
let suffixes = drop 1 (scanr (*) 1 (shxToList sh))
in \i ->
@@ -286,14 +286,14 @@ ixxFromLinear = \sh -> -- give this function arity 1 so that suffixes is shared
(n :$% sh', suff : suffs) ->
let (q, r) = i `quotRem` suff
in if q >= fromSMayNat' n then outrange sh i else
- q :.% fromLin sh' suffs r
+ fromIntegral q :.% fromLin sh' suffs r
_ -> error "impossible"
where
- fromLin :: IShX sh -> [Int] -> Int -> IxX sh Int
+ fromLin :: Num i => IShX sh -> [Int] -> Int -> IxX sh i
fromLin ZSX _ !_ = ZIX
fromLin (_ :$% sh') (suff : suffs) i =
let (q, r) = i `quotRem` suff -- suff == shrSize sh'
- in q :.% fromLin sh' suffs r
+ in fromIntegral q :.% fromLin sh' suffs r
fromLin _ _ _ = error "impossible"
{-# NOINLINE outrange #-}
diff --git a/src/Data/Array/Nested/Ranked.hs b/src/Data/Array/Nested/Ranked.hs
index 9504247..37925fb 100644
--- a/src/Data/Array/Nested/Ranked.hs
+++ b/src/Data/Array/Nested/Ranked.hs
@@ -71,7 +71,9 @@ rgenerate sh f
-- TODO: this would be shorter and faster written with rfromVector,
-- but unfortunately we don't have ixrFromLinear
-rgeneratePrim :: forall n a. PrimElt a => IShR n -> (IIxR n -> a) -> Ranked n a
+{-# INLINE rgeneratePrim #-}
+rgeneratePrim :: forall n a i. (PrimElt a, Num i)
+ => IShR n -> (IxR n i -> a) -> Ranked n a
rgeneratePrim sh f
| sn@SNat <- shrRank sh
, Dict <- lemKnownReplicate sn
diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs
index 31a7706..075549d 100644
--- a/src/Data/Array/Nested/Shaped.hs
+++ b/src/Data/Array/Nested/Shaped.hs
@@ -72,7 +72,9 @@ sindexPartial sarr@(Shaped arr) idx =
sgenerate :: forall sh a. KnownElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a
sgenerate sh f = Shaped (mgenerate (shxFromShS sh) (f . ixsFromIxX sh))
-sgeneratePrim :: forall sh a. PrimElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a
+{-# INLINE sgeneratePrim #-}
+sgeneratePrim :: forall sh a i. (PrimElt a, Num i)
+ => ShS sh -> (IxS sh i -> a) -> Shaped sh a
sgeneratePrim sh f = Shaped (mgeneratePrim (shxFromShS sh) (f . ixsFromIxX sh))
-- | See the documentation of 'mlift'.