From 2cf2817f321f705cb0d97d2188c17067915507ea Mon Sep 17 00:00:00 2001 From: Mikolaj Konarski Date: Wed, 24 Dec 2025 19:31:38 +0100 Subject: Inline most lifting wrappers This results in only marginal performance gain, probably because they are already small enough to be specialized and/or inlined automatically, but these pragmas ensure it remains so regardless of changes in GHC heuristics. --- src/Data/Array/Nested/Mixed.hs | 13 +++++++++++++ src/Data/Array/Nested/Ranked.hs | 2 ++ src/Data/Array/Nested/Ranked/Base.hs | 3 +++ src/Data/Array/Nested/Shaped.hs | 2 ++ src/Data/Array/Nested/Shaped/Base.hs | 3 +++ src/Data/Array/Strided/Orthotope.hs | 5 +++++ 6 files changed, 28 insertions(+) (limited to 'src/Data/Array') diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index eb05eaa..2b5c5b6 100644 --- a/src/Data/Array/Nested/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -237,11 +237,13 @@ instance Elt a => NFData (Mixed sh a) where rnf = mrnf +{-# INLINE mliftNumElt1 #-} mliftNumElt1 :: (PrimElt a, PrimElt b) => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) b) -> Mixed sh a -> Mixed sh b mliftNumElt1 f (toPrimitive -> M_Primitive sh (XArray arr)) = fromPrimitive $ M_Primitive sh (XArray (f (shxRank sh) arr)) +{-# INLINE mliftNumElt2 #-} mliftNumElt2 :: (PrimElt a, PrimElt b, PrimElt c) => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) b -> S.Array (Rank sh) c) -> Mixed sh a -> Mixed sh b -> Mixed sh c @@ -417,6 +419,7 @@ instance Storable a => Elt (Primitive a) where in M_Primitive sh (X.fromListOuter (ssxFromShX sh) (map (\(M_Primitive _ a) -> a) (toList l))) mtoListOuter (M_Primitive sh arr) = map (M_Primitive (shxTail sh)) (X.toListOuter arr) + {-# INLINE mlift #-} mlift :: forall sh1 sh2. StaticShX sh2 -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a) @@ -427,6 +430,7 @@ instance Storable a => Elt (Primitive a) where , let result = f ZKX a = M_Primitive (X.shape ssh2 result) result + {-# INLINE mlift2 #-} mlift2 :: forall sh1 sh2 sh3. StaticShX sh3 -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a -> XArray (sh3 ++ '[]) a) @@ -438,6 +442,7 @@ instance Storable a => Elt (Primitive a) where , let result = f ZKX a b = M_Primitive (X.shape ssh3 result) result + {-# INLINE mliftL #-} mliftL :: forall sh1 sh2. StaticShX sh2 -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) @@ -526,8 +531,11 @@ instance (Elt a, Elt b) => Elt (a, b) where M_Tup2 (mfromListOuterSN sn ((\(M_Tup2 x _) -> x) <$> l)) (mfromListOuterSN sn ((\(M_Tup2 _ y) -> y) <$> l)) mtoListOuter (M_Tup2 a b) = zipWith M_Tup2 (mtoListOuter a) (mtoListOuter b) + {-# INLINE mlift #-} mlift ssh2 f (M_Tup2 a b) = M_Tup2 (mlift ssh2 f a) (mlift ssh2 f b) + {-# INLINE mlift2 #-} mlift2 ssh3 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 ssh3 f a x) (mlift2 ssh3 f b y) + {-# INLINE mliftL #-} mliftL ssh2 f = let unzipT2l [] = ([], []) unzipT2l (M_Tup2 a b : l) = let (l1, l2) = unzipT2l l in (a : l1, b : l2) @@ -593,6 +601,7 @@ instance Elt a => Elt (Mixed sh' a) where mtoListOuter (M_Nest sh arr) = map (M_Nest (shxTail sh)) (mtoListOuter arr) + {-# INLINE mlift #-} mlift :: forall sh1 sh2. StaticShX sh2 -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b) @@ -610,6 +619,7 @@ instance Elt a => Elt (Mixed sh' a) where , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) = f (ssxAppend ssh' sshT) + {-# INLINE mlift2 #-} mlift2 :: forall sh1 sh2 sh3. StaticShX sh3 -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b -> XArray (sh3 ++ shT) b) @@ -628,6 +638,7 @@ instance Elt a => Elt (Mixed sh' a) where , Refl <- lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT) = f (ssxAppend ssh' sshT) + {-# INLINE mliftL #-} mliftL :: forall sh1 sh2. StaticShX sh2 -> (forall shT b. Storable b => StaticShX shT -> NonEmpty (XArray (sh1 ++ shT) b) -> NonEmpty (XArray (sh2 ++ shT) b)) @@ -1066,11 +1077,13 @@ mfromXArrayPrimP ssh arr = M_Primitive (X.shape ssh arr) arr mfromXArrayPrim :: PrimElt a => StaticShX sh -> XArray sh a -> Mixed sh a mfromXArrayPrim = (fromPrimitive .) . mfromXArrayPrimP +{-# INLINE mliftPrim #-} mliftPrim :: (PrimElt a, PrimElt b) => (a -> b) -> Mixed sh a -> Mixed sh b mliftPrim f (toPrimitive -> M_Primitive sh (X.XArray arr)) = fromPrimitive $ M_Primitive sh (X.XArray (S.mapA f arr)) +{-# INLINE mliftPrim2 #-} mliftPrim2 :: (PrimElt a, PrimElt b, PrimElt c) => (a -> b -> c) -> Mixed sh a -> Mixed sh b -> Mixed sh c diff --git a/src/Data/Array/Nested/Ranked.hs b/src/Data/Array/Nested/Ranked.hs index ccbab63..8faff6d 100644 --- a/src/Data/Array/Nested/Ranked.hs +++ b/src/Data/Array/Nested/Ranked.hs @@ -79,6 +79,7 @@ rgeneratePrim sh f = in rfromVector sh $ VS.generate (shrSize sh) g -- | See the documentation of 'mlift'. +{-# INLINE rlift #-} rlift :: forall n1 n2 a. Elt a => SNat n2 -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b) @@ -86,6 +87,7 @@ rlift :: forall n1 n2 a. Elt a rlift sn2 f (Ranked arr) = Ranked (mlift (ssxFromSNat sn2) f arr) -- | See the documentation of 'mlift2'. +{-# INLINE rlift2 #-} rlift2 :: forall n1 n2 n3 a. Elt a => SNat n3 -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b -> XArray (Replicate n3 Nothing ++ sh') b) diff --git a/src/Data/Array/Nested/Ranked/Base.hs b/src/Data/Array/Nested/Ranked/Base.hs index 236cb05..9d88815 100644 --- a/src/Data/Array/Nested/Ranked/Base.hs +++ b/src/Data/Array/Nested/Ranked/Base.hs @@ -99,6 +99,7 @@ instance Elt a => Elt (Ranked n a) where mtoListOuter (M_Ranked arr) = coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoListOuter arr) + {-# INLINE mlift #-} mlift :: forall sh1 sh2. StaticShX sh2 -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) @@ -107,6 +108,7 @@ instance Elt a => Elt (Ranked n a) where coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $ mlift ssh2 f arr + {-# INLINE mlift2 #-} mlift2 :: forall sh1 sh2 sh3. StaticShX sh3 -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) @@ -115,6 +117,7 @@ instance Elt a => Elt (Ranked n a) where coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $ mlift2 ssh3 f arr1 arr2 + {-# INLINE mliftL #-} mliftL :: forall sh1 sh2. StaticShX sh2 -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs index 23a4fc8..5c52220 100644 --- a/src/Data/Array/Nested/Shaped.hs +++ b/src/Data/Array/Nested/Shaped.hs @@ -81,6 +81,7 @@ sgeneratePrim sh f = in sfromVector sh $ VS.generate (shsSize sh) g -- | See the documentation of 'mlift'. +{-# INLINE slift #-} slift :: forall sh1 sh2 a. Elt a => ShS sh2 -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b) @@ -88,6 +89,7 @@ slift :: forall sh1 sh2 a. Elt a slift sh2 f (Shaped arr) = Shaped (mlift (ssxFromShX (shxFromShS sh2)) f arr) -- | See the documentation of 'mlift'. +{-# INLINE slift2 #-} slift2 :: forall sh1 sh2 sh3 a. Elt a => ShS sh3 -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b -> XArray (MapJust sh3 ++ sh') b) diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs index b86bfe5..8ef61dd 100644 --- a/src/Data/Array/Nested/Shaped/Base.hs +++ b/src/Data/Array/Nested/Shaped/Base.hs @@ -96,6 +96,7 @@ instance Elt a => Elt (Shaped sh a) where mtoListOuter (M_Shaped arr) = coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoListOuter arr) + {-# INLINE mlift #-} mlift :: forall sh1 sh2. StaticShX sh2 -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) @@ -104,6 +105,7 @@ instance Elt a => Elt (Shaped sh a) where coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ mlift ssh2 f arr + {-# INLINE mlift2 #-} mlift2 :: forall sh1 sh2 sh3. StaticShX sh3 -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) @@ -112,6 +114,7 @@ instance Elt a => Elt (Shaped sh a) where coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $ mlift2 ssh3 f arr1 arr2 + {-# INLINE mliftL #-} mliftL :: forall sh1 sh2. StaticShX sh2 -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) diff --git a/src/Data/Array/Strided/Orthotope.hs b/src/Data/Array/Strided/Orthotope.hs index 5c38d14..e2cd17c 100644 --- a/src/Data/Array/Strided/Orthotope.hs +++ b/src/Data/Array/Strided/Orthotope.hs @@ -24,14 +24,19 @@ fromO (RS.A (RG.A sh (OI.T strides offset vec))) = AS.Array sh strides offset ve toO :: AS.Array n a -> RS.Array n a toO (AS.Array sh strides offset vec) = RS.A (RG.A sh (OI.T strides offset vec)) +{-# INLINE liftO1 #-} liftO1 :: (AS.Array n a -> AS.Array n' b) -> RS.Array n a -> RS.Array n' b liftO1 f = toO . f . fromO +{-# INLINE liftO2 #-} liftO2 :: (AS.Array n a -> AS.Array n1 b -> AS.Array n2 c) -> RS.Array n a -> RS.Array n1 b -> RS.Array n2 c liftO2 f x y = toO (f (fromO x) (fromO y)) +-- We don't inline this lifting function, because its code is not just +-- a wrapper, being relatively long and expensive. +{-# INLINEABLE liftVEltwise1 #-} liftVEltwise1 :: (Storable a, Storable b) => SNat n -> (VS.Vector a -> VS.Vector b) -- cgit v1.2.3-70-g09d2