aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array')
-rw-r--r--src/Data/Array/Nested/Mixed.hs13
-rw-r--r--src/Data/Array/Nested/Ranked.hs2
-rw-r--r--src/Data/Array/Nested/Ranked/Base.hs3
-rw-r--r--src/Data/Array/Nested/Shaped.hs2
-rw-r--r--src/Data/Array/Nested/Shaped/Base.hs3
-rw-r--r--src/Data/Array/Strided/Orthotope.hs5
6 files changed, 28 insertions, 0 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs
index eb05eaa..2b5c5b6 100644
--- a/src/Data/Array/Nested/Mixed.hs
+++ b/src/Data/Array/Nested/Mixed.hs
@@ -237,11 +237,13 @@ instance Elt a => NFData (Mixed sh a) where
rnf = mrnf
+{-# INLINE mliftNumElt1 #-}
mliftNumElt1 :: (PrimElt a, PrimElt b)
=> (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) b)
-> Mixed sh a -> Mixed sh b
mliftNumElt1 f (toPrimitive -> M_Primitive sh (XArray arr)) = fromPrimitive $ M_Primitive sh (XArray (f (shxRank sh) arr))
+{-# INLINE mliftNumElt2 #-}
mliftNumElt2 :: (PrimElt a, PrimElt b, PrimElt c)
=> (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) b -> S.Array (Rank sh) c)
-> Mixed sh a -> Mixed sh b -> Mixed sh c
@@ -417,6 +419,7 @@ instance Storable a => Elt (Primitive a) where
in M_Primitive sh (X.fromListOuter (ssxFromShX sh) (map (\(M_Primitive _ a) -> a) (toList l)))
mtoListOuter (M_Primitive sh arr) = map (M_Primitive (shxTail sh)) (X.toListOuter arr)
+ {-# INLINE mlift #-}
mlift :: forall sh1 sh2.
StaticShX sh2
-> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a)
@@ -427,6 +430,7 @@ instance Storable a => Elt (Primitive a) where
, let result = f ZKX a
= M_Primitive (X.shape ssh2 result) result
+ {-# INLINE mlift2 #-}
mlift2 :: forall sh1 sh2 sh3.
StaticShX sh3
-> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a -> XArray (sh3 ++ '[]) a)
@@ -438,6 +442,7 @@ instance Storable a => Elt (Primitive a) where
, let result = f ZKX a b
= M_Primitive (X.shape ssh3 result) result
+ {-# INLINE mliftL #-}
mliftL :: forall sh1 sh2.
StaticShX sh2
-> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b))
@@ -526,8 +531,11 @@ instance (Elt a, Elt b) => Elt (a, b) where
M_Tup2 (mfromListOuterSN sn ((\(M_Tup2 x _) -> x) <$> l))
(mfromListOuterSN sn ((\(M_Tup2 _ y) -> y) <$> l))
mtoListOuter (M_Tup2 a b) = zipWith M_Tup2 (mtoListOuter a) (mtoListOuter b)
+ {-# INLINE mlift #-}
mlift ssh2 f (M_Tup2 a b) = M_Tup2 (mlift ssh2 f a) (mlift ssh2 f b)
+ {-# INLINE mlift2 #-}
mlift2 ssh3 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 ssh3 f a x) (mlift2 ssh3 f b y)
+ {-# INLINE mliftL #-}
mliftL ssh2 f =
let unzipT2l [] = ([], [])
unzipT2l (M_Tup2 a b : l) = let (l1, l2) = unzipT2l l in (a : l1, b : l2)
@@ -593,6 +601,7 @@ instance Elt a => Elt (Mixed sh' a) where
mtoListOuter (M_Nest sh arr) = map (M_Nest (shxTail sh)) (mtoListOuter arr)
+ {-# INLINE mlift #-}
mlift :: forall sh1 sh2.
StaticShX sh2
-> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b)
@@ -610,6 +619,7 @@ instance Elt a => Elt (Mixed sh' a) where
, Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT)
= f (ssxAppend ssh' sshT)
+ {-# INLINE mlift2 #-}
mlift2 :: forall sh1 sh2 sh3.
StaticShX sh3
-> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b -> XArray (sh3 ++ shT) b)
@@ -628,6 +638,7 @@ instance Elt a => Elt (Mixed sh' a) where
, Refl <- lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT)
= f (ssxAppend ssh' sshT)
+ {-# INLINE mliftL #-}
mliftL :: forall sh1 sh2.
StaticShX sh2
-> (forall shT b. Storable b => StaticShX shT -> NonEmpty (XArray (sh1 ++ shT) b) -> NonEmpty (XArray (sh2 ++ shT) b))
@@ -1066,11 +1077,13 @@ mfromXArrayPrimP ssh arr = M_Primitive (X.shape ssh arr) arr
mfromXArrayPrim :: PrimElt a => StaticShX sh -> XArray sh a -> Mixed sh a
mfromXArrayPrim = (fromPrimitive .) . mfromXArrayPrimP
+{-# INLINE mliftPrim #-}
mliftPrim :: (PrimElt a, PrimElt b)
=> (a -> b)
-> Mixed sh a -> Mixed sh b
mliftPrim f (toPrimitive -> M_Primitive sh (X.XArray arr)) = fromPrimitive $ M_Primitive sh (X.XArray (S.mapA f arr))
+{-# INLINE mliftPrim2 #-}
mliftPrim2 :: (PrimElt a, PrimElt b, PrimElt c)
=> (a -> b -> c)
-> Mixed sh a -> Mixed sh b -> Mixed sh c
diff --git a/src/Data/Array/Nested/Ranked.hs b/src/Data/Array/Nested/Ranked.hs
index ccbab63..8faff6d 100644
--- a/src/Data/Array/Nested/Ranked.hs
+++ b/src/Data/Array/Nested/Ranked.hs
@@ -79,6 +79,7 @@ rgeneratePrim sh f =
in rfromVector sh $ VS.generate (shrSize sh) g
-- | See the documentation of 'mlift'.
+{-# INLINE rlift #-}
rlift :: forall n1 n2 a. Elt a
=> SNat n2
-> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b)
@@ -86,6 +87,7 @@ rlift :: forall n1 n2 a. Elt a
rlift sn2 f (Ranked arr) = Ranked (mlift (ssxFromSNat sn2) f arr)
-- | See the documentation of 'mlift2'.
+{-# INLINE rlift2 #-}
rlift2 :: forall n1 n2 n3 a. Elt a
=> SNat n3
-> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b -> XArray (Replicate n3 Nothing ++ sh') b)
diff --git a/src/Data/Array/Nested/Ranked/Base.hs b/src/Data/Array/Nested/Ranked/Base.hs
index 236cb05..9d88815 100644
--- a/src/Data/Array/Nested/Ranked/Base.hs
+++ b/src/Data/Array/Nested/Ranked/Base.hs
@@ -99,6 +99,7 @@ instance Elt a => Elt (Ranked n a) where
mtoListOuter (M_Ranked arr) =
coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoListOuter arr)
+ {-# INLINE mlift #-}
mlift :: forall sh1 sh2.
StaticShX sh2
-> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
@@ -107,6 +108,7 @@ instance Elt a => Elt (Ranked n a) where
coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $
mlift ssh2 f arr
+ {-# INLINE mlift2 #-}
mlift2 :: forall sh1 sh2 sh3.
StaticShX sh3
-> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
@@ -115,6 +117,7 @@ instance Elt a => Elt (Ranked n a) where
coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $
mlift2 ssh3 f arr1 arr2
+ {-# INLINE mliftL #-}
mliftL :: forall sh1 sh2.
StaticShX sh2
-> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b))
diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs
index 23a4fc8..5c52220 100644
--- a/src/Data/Array/Nested/Shaped.hs
+++ b/src/Data/Array/Nested/Shaped.hs
@@ -81,6 +81,7 @@ sgeneratePrim sh f =
in sfromVector sh $ VS.generate (shsSize sh) g
-- | See the documentation of 'mlift'.
+{-# INLINE slift #-}
slift :: forall sh1 sh2 a. Elt a
=> ShS sh2
-> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b)
@@ -88,6 +89,7 @@ slift :: forall sh1 sh2 a. Elt a
slift sh2 f (Shaped arr) = Shaped (mlift (ssxFromShX (shxFromShS sh2)) f arr)
-- | See the documentation of 'mlift'.
+{-# INLINE slift2 #-}
slift2 :: forall sh1 sh2 sh3 a. Elt a
=> ShS sh3
-> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b -> XArray (MapJust sh3 ++ sh') b)
diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs
index b86bfe5..8ef61dd 100644
--- a/src/Data/Array/Nested/Shaped/Base.hs
+++ b/src/Data/Array/Nested/Shaped/Base.hs
@@ -96,6 +96,7 @@ instance Elt a => Elt (Shaped sh a) where
mtoListOuter (M_Shaped arr)
= coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoListOuter arr)
+ {-# INLINE mlift #-}
mlift :: forall sh1 sh2.
StaticShX sh2
-> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
@@ -104,6 +105,7 @@ instance Elt a => Elt (Shaped sh a) where
coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
mlift ssh2 f arr
+ {-# INLINE mlift2 #-}
mlift2 :: forall sh1 sh2 sh3.
StaticShX sh3
-> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
@@ -112,6 +114,7 @@ instance Elt a => Elt (Shaped sh a) where
coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $
mlift2 ssh3 f arr1 arr2
+ {-# INLINE mliftL #-}
mliftL :: forall sh1 sh2.
StaticShX sh2
-> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b))
diff --git a/src/Data/Array/Strided/Orthotope.hs b/src/Data/Array/Strided/Orthotope.hs
index 5c38d14..e2cd17c 100644
--- a/src/Data/Array/Strided/Orthotope.hs
+++ b/src/Data/Array/Strided/Orthotope.hs
@@ -24,14 +24,19 @@ fromO (RS.A (RG.A sh (OI.T strides offset vec))) = AS.Array sh strides offset ve
toO :: AS.Array n a -> RS.Array n a
toO (AS.Array sh strides offset vec) = RS.A (RG.A sh (OI.T strides offset vec))
+{-# INLINE liftO1 #-}
liftO1 :: (AS.Array n a -> AS.Array n' b)
-> RS.Array n a -> RS.Array n' b
liftO1 f = toO . f . fromO
+{-# INLINE liftO2 #-}
liftO2 :: (AS.Array n a -> AS.Array n1 b -> AS.Array n2 c)
-> RS.Array n a -> RS.Array n1 b -> RS.Array n2 c
liftO2 f x y = toO (f (fromO x) (fromO y))
+-- We don't inline this lifting function, because its code is not just
+-- a wrapper, being relatively long and expensive.
+{-# INLINEABLE liftVEltwise1 #-}
liftVEltwise1 :: (Storable a, Storable b)
=> SNat n
-> (VS.Vector a -> VS.Vector b)