aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Shaped.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Shaped.hs')
-rw-r--r--src/Data/Array/Nested/Shaped.hs10
1 files changed, 10 insertions, 0 deletions
diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs
index 5c52220..36ef24a 100644
--- a/src/Data/Array/Nested/Shaped.hs
+++ b/src/Data/Array/Nested/Shaped.hs
@@ -96,17 +96,21 @@ slift2 :: forall sh1 sh2 sh3 a. Elt a
-> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a
slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (ssxFromShX (shxFromShS sh3)) f arr1 arr2)
+{-# INLINE ssumOuter1PrimP #-}
ssumOuter1PrimP :: forall sh n a. (Storable a, NumElt a)
=> Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a)
ssumOuter1PrimP (Shaped arr) = Shaped (msumOuter1PrimP arr)
+{-# INLINEABLE ssumOuter1Prim #-}
ssumOuter1Prim :: forall sh n a. (NumElt a, PrimElt a)
=> Shaped (n : sh) a -> Shaped sh a
ssumOuter1Prim = sfromPrimitive . ssumOuter1PrimP . stoPrimitive
+{-# INLINE ssumAllPrimP #-}
ssumAllPrimP :: (PrimElt a, NumElt a) => Shaped n (Primitive a) -> a
ssumAllPrimP (Shaped arr) = msumAllPrimP arr
+{-# INLINE ssumAllPrim #-}
ssumAllPrim :: (PrimElt a, NumElt a) => Shaped n a -> a
ssumAllPrim (Shaped arr) = msumAllPrim arr
@@ -126,15 +130,19 @@ sappend = coerce mappend
sscalar :: Elt a => a -> Shaped '[] a
sscalar x = Shaped (mscalar x)
+{-# INLINEABLE sfromVectorP #-}
sfromVectorP :: Storable a => ShS sh -> VS.Vector a -> Shaped sh (Primitive a)
sfromVectorP sh v = Shaped (mfromVectorP (shxFromShS sh) v)
+{-# INLINEABLE sfromVector #-}
sfromVector :: PrimElt a => ShS sh -> VS.Vector a -> Shaped sh a
sfromVector sh v = sfromPrimitive (sfromVectorP sh v)
+{-# INLINEABLE stoVectorP #-}
stoVectorP :: Storable a => Shaped sh (Primitive a) -> VS.Vector a
stoVectorP = coerce mtoVectorP
+{-# INLINEABLE stoVector #-}
stoVector :: PrimElt a => Shaped sh a -> VS.Vector a
stoVector = coerce mtoVector
@@ -261,6 +269,7 @@ sminIndexPrim (Shaped arr) = ixsFromIxX (mminIndexPrim arr)
smaxIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh
smaxIndexPrim (Shaped arr) = ixsFromIxX (mmaxIndexPrim arr)
+{-# INLINEABLE sdot1Inner #-}
sdot1Inner :: forall sh n a. (PrimElt a, NumElt a)
=> Proxy n -> Shaped (sh ++ '[n]) a -> Shaped (sh ++ '[n]) a -> Shaped sh a
sdot1Inner Proxy sarr1@(Shaped arr1) (Shaped arr2)
@@ -272,6 +281,7 @@ sdot1Inner Proxy sarr1@(Shaped arr1) (Shaped arr2)
-> Shaped (mdot1Inner (Proxy @(Just n)) arr1 arr2)
_ -> error "unreachable"
+{-# INLINE sdot #-}
-- | This has a temporary, suboptimal implementation in terms of 'mflatten'.
-- Prefer 'sdot1Inner' if applicable.
sdot :: (PrimElt a, NumElt a) => Shaped sh a -> Shaped sh a -> a