aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Mixed.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Mixed.hs')
-rw-r--r--src/Data/Array/Nested/Mixed.hs19
1 files changed, 19 insertions, 0 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs
index 2b5c5b6..ffbc993 100644
--- a/src/Data/Array/Nested/Mixed.hs
+++ b/src/Data/Array/Nested/Mixed.hs
@@ -408,6 +408,9 @@ class Elt a => KnownElt a where
-- Arrays of scalars are basically just arrays of scalars.
instance Storable a => Elt (Primitive a) where
+ -- Somehow, INLINE here can increase allocation with GHC 9.14.1.
+ -- Maybe that happens in void instances such as @Primitive ()@.
+ {-# INLINEABLE mshape #-}
mshape (M_Primitive sh _) = sh
{-# INLINEABLE mindex #-}
mindex (M_Primitive _ a) i = Primitive (X.index a i)
@@ -523,8 +526,11 @@ deriving via Primitive () instance KnownElt ()
-- Arrays of pairs are pairs of arrays.
instance (Elt a, Elt b) => Elt (a, b) where
+ {-# INLINEABLE mshape #-}
mshape (M_Tup2 a _) = mshape a
+ {-# INLINEABLE mindex #-}
mindex (M_Tup2 a b) i = (mindex a i, mindex b i)
+ {-# INLINEABLE mindexPartial #-}
mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i)
mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y)
mfromListOuterSN sn l =
@@ -580,13 +586,16 @@ instance Elt a => Elt (Mixed sh' a) where
-- TODO: this is quadratic in the nesting depth because it repeatedly
-- truncates the shape vector to one a little shorter. Fix with a
-- moverlongShape method, a prefix of which is mshape.
+ {-# INLINEABLE mshape #-}
mshape :: forall sh. Mixed sh (Mixed sh' a) -> IShX sh
mshape (M_Nest sh arr)
= fst (shxSplitApp (Proxy @sh') (ssxFromShX sh) (mshape arr))
+ {-# INLINEABLE mindex #-}
mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a
mindex (M_Nest _ arr) = mindexPartial arr
+ {-# INLINEABLE mindexPartial #-}
mindexPartial :: forall sh1 sh2.
Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)
mindexPartial (M_Nest sh arr) i
@@ -804,19 +813,23 @@ mgeneratePrim sh f =
let g i = f (ixxFromLinear sh i)
in mfromVector sh $ VS.generate (shxSize sh) g
+{-# INLINEABLE msumOuter1PrimP #-}
msumOuter1PrimP :: forall sh n a. (Storable a, NumElt a)
=> Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a)
msumOuter1PrimP (M_Primitive (n :$% sh) arr) =
let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX
in M_Primitive sh (X.sumOuter nssh (ssxFromShX sh) arr)
+{-# INLINEABLE msumOuter1Prim #-}
msumOuter1Prim :: forall sh n a. (NumElt a, PrimElt a)
=> Mixed (n : sh) a -> Mixed sh a
msumOuter1Prim = fromPrimitive . msumOuter1PrimP @sh @n @a . toPrimitive
+{-# INLINEABLE msumAllPrimP #-}
msumAllPrimP :: (Storable a, NumElt a) => Mixed sh (Primitive a) -> a
msumAllPrimP (M_Primitive sh arr) = X.sumFull (ssxFromShX sh) arr
+{-# INLINEABLE msumAllPrim #-}
msumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a
msumAllPrim arr = msumAllPrimP (toPrimitive arr)
@@ -837,15 +850,19 @@ mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2
=> StaticShX sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (AddMaybe n m : sh ++ sh') b
f ssh' = X.append (ssxAppend ssh ssh')
+{-# INLINEABLE mfromVectorP #-}
mfromVectorP :: forall sh a. Storable a => IShX sh -> VS.Vector a -> Mixed sh (Primitive a)
mfromVectorP sh v = M_Primitive sh (X.fromVector sh v)
+{-# INLINEABLE mfromVector #-}
mfromVector :: forall sh a. PrimElt a => IShX sh -> VS.Vector a -> Mixed sh a
mfromVector sh v = fromPrimitive (mfromVectorP sh v)
+{-# INLINEABLE mtoVectorP #-}
mtoVectorP :: Storable a => Mixed sh (Primitive a) -> VS.Vector a
mtoVectorP (M_Primitive _ v) = X.toVector v
+{-# INLINEABLE mtoVector #-}
mtoVector :: PrimElt a => Mixed sh a -> VS.Vector a
mtoVector arr = mtoVectorP (toPrimitive arr)
@@ -1044,6 +1061,7 @@ mmaxIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh
mmaxIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) =
ixxFromList (ssxFromShX sh) (numEltMaxIndex (shxRank sh) (fromO arr))
+{-# INLINEABLE mdot1Inner #-}
mdot1Inner :: forall sh n a. (PrimElt a, NumElt a)
=> Proxy n -> Mixed (sh ++ '[n]) a -> Mixed (sh ++ '[n]) a -> Mixed sh a
mdot1Inner _ (toPrimitive -> M_Primitive sh1 (XArray a)) (toPrimitive -> M_Primitive sh2 (XArray b))
@@ -1059,6 +1077,7 @@ mdot1Inner _ (toPrimitive -> M_Primitive sh1 (XArray a)) (toPrimitive -> M_Primi
-- | This has a temporary, suboptimal implementation in terms of 'mflatten'.
-- Prefer 'mdot1Inner' if applicable.
+{-# INLINEABLE mdot #-}
mdot :: (PrimElt a, NumElt a) => Mixed sh a -> Mixed sh a -> a
mdot a b =
munScalar $