diff options
| author | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2025-12-02 13:16:33 +0100 |
|---|---|---|
| committer | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2025-12-02 13:16:33 +0100 |
| commit | 9f47aa6a2bcd772388a5d5150ca7254e4eb95bc2 (patch) | |
| tree | 4c3a1b8a7b1a734e83f161f2b1be58ce4470cfa3 | |
| parent | ba5a31c976f80421464af1af8d6ab1e2a154cd83 (diff) | |
Generalize mgeneratePrim to potentially avoid @fmap fromIntegral@
| -rw-r--r-- | src/Data/Array/Nested/Mixed.hs | 8 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Mixed/Shape.hs | 8 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Ranked.hs | 4 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Shaped.hs | 4 |
4 files changed, 15 insertions, 9 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index 515e867..d658ed3 100644 --- a/src/Data/Array/Nested/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -751,9 +751,11 @@ mgenerate sh f = case shxEnum sh of -- | An optimized special case of `mgenerate', where the function results -- are of a primitive type and so there's not need to verify the shapes --- of them all are equal. -mgeneratePrim :: forall sh a. PrimElt a - => IShX sh -> (IIxX sh -> a) -> Mixed sh a +-- of them all are equal. This is also generalized to aribitrary @Num@ index +-- type compared to @mgenerate@. +{-# INLINE mgeneratePrim #-} +mgeneratePrim :: forall sh a i. (PrimElt a, Num i) + => IShX sh -> (IxX sh i -> a) -> Mixed sh a mgeneratePrim sh f = let g i = f (ixxFromLinear sh i) in mfromVector sh $ VS.generate (shxSize sh) g diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs index 066ae8e..ed03310 100644 --- a/src/Data/Array/Nested/Mixed/Shape.hs +++ b/src/Data/Array/Nested/Mixed/Shape.hs @@ -275,7 +275,7 @@ ixxZipWith _ ZIX ZIX = ZIX ixxZipWith f (i :.% is) (j :.% js) = f i j :.% ixxZipWith f is js {-# INLINEABLE ixxFromLinear #-} -ixxFromLinear :: IShX sh -> Int -> IIxX sh +ixxFromLinear :: Num i => IShX sh -> Int -> IxX sh i ixxFromLinear = \sh -> -- give this function arity 1 so that suffixes is shared when it's called many times let suffixes = drop 1 (scanr (*) 1 (shxToList sh)) in \i -> @@ -286,14 +286,14 @@ ixxFromLinear = \sh -> -- give this function arity 1 so that suffixes is shared (n :$% sh', suff : suffs) -> let (q, r) = i `quotRem` suff in if q >= fromSMayNat' n then outrange sh i else - q :.% fromLin sh' suffs r + fromIntegral q :.% fromLin sh' suffs r _ -> error "impossible" where - fromLin :: IShX sh -> [Int] -> Int -> IxX sh Int + fromLin :: Num i => IShX sh -> [Int] -> Int -> IxX sh i fromLin ZSX _ !_ = ZIX fromLin (_ :$% sh') (suff : suffs) i = let (q, r) = i `quotRem` suff -- suff == shrSize sh' - in q :.% fromLin sh' suffs r + in fromIntegral q :.% fromLin sh' suffs r fromLin _ _ _ = error "impossible" {-# NOINLINE outrange #-} diff --git a/src/Data/Array/Nested/Ranked.hs b/src/Data/Array/Nested/Ranked.hs index 9504247..37925fb 100644 --- a/src/Data/Array/Nested/Ranked.hs +++ b/src/Data/Array/Nested/Ranked.hs @@ -71,7 +71,9 @@ rgenerate sh f -- TODO: this would be shorter and faster written with rfromVector, -- but unfortunately we don't have ixrFromLinear -rgeneratePrim :: forall n a. PrimElt a => IShR n -> (IIxR n -> a) -> Ranked n a +{-# INLINE rgeneratePrim #-} +rgeneratePrim :: forall n a i. (PrimElt a, Num i) + => IShR n -> (IxR n i -> a) -> Ranked n a rgeneratePrim sh f | sn@SNat <- shrRank sh , Dict <- lemKnownReplicate sn diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs index 31a7706..075549d 100644 --- a/src/Data/Array/Nested/Shaped.hs +++ b/src/Data/Array/Nested/Shaped.hs @@ -72,7 +72,9 @@ sindexPartial sarr@(Shaped arr) idx = sgenerate :: forall sh a. KnownElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a sgenerate sh f = Shaped (mgenerate (shxFromShS sh) (f . ixsFromIxX sh)) -sgeneratePrim :: forall sh a. PrimElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a +{-# INLINE sgeneratePrim #-} +sgeneratePrim :: forall sh a i. (PrimElt a, Num i) + => ShS sh -> (IxS sh i -> a) -> Shaped sh a sgeneratePrim sh f = Shaped (mgeneratePrim (shxFromShS sh) (f . ixsFromIxX sh)) -- | See the documentation of 'mlift'. |
