diff options
Diffstat (limited to 'src/Data/Array/Nested/Shaped.hs')
| -rw-r--r-- | src/Data/Array/Nested/Shaped.hs | 10 |
1 files changed, 10 insertions, 0 deletions
diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs index 5c52220..36ef24a 100644 --- a/src/Data/Array/Nested/Shaped.hs +++ b/src/Data/Array/Nested/Shaped.hs @@ -96,17 +96,21 @@ slift2 :: forall sh1 sh2 sh3 a. Elt a -> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (ssxFromShX (shxFromShS sh3)) f arr1 arr2) +{-# INLINE ssumOuter1PrimP #-} ssumOuter1PrimP :: forall sh n a. (Storable a, NumElt a) => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a) ssumOuter1PrimP (Shaped arr) = Shaped (msumOuter1PrimP arr) +{-# INLINEABLE ssumOuter1Prim #-} ssumOuter1Prim :: forall sh n a. (NumElt a, PrimElt a) => Shaped (n : sh) a -> Shaped sh a ssumOuter1Prim = sfromPrimitive . ssumOuter1PrimP . stoPrimitive +{-# INLINE ssumAllPrimP #-} ssumAllPrimP :: (PrimElt a, NumElt a) => Shaped n (Primitive a) -> a ssumAllPrimP (Shaped arr) = msumAllPrimP arr +{-# INLINE ssumAllPrim #-} ssumAllPrim :: (PrimElt a, NumElt a) => Shaped n a -> a ssumAllPrim (Shaped arr) = msumAllPrim arr @@ -126,15 +130,19 @@ sappend = coerce mappend sscalar :: Elt a => a -> Shaped '[] a sscalar x = Shaped (mscalar x) +{-# INLINEABLE sfromVectorP #-} sfromVectorP :: Storable a => ShS sh -> VS.Vector a -> Shaped sh (Primitive a) sfromVectorP sh v = Shaped (mfromVectorP (shxFromShS sh) v) +{-# INLINEABLE sfromVector #-} sfromVector :: PrimElt a => ShS sh -> VS.Vector a -> Shaped sh a sfromVector sh v = sfromPrimitive (sfromVectorP sh v) +{-# INLINEABLE stoVectorP #-} stoVectorP :: Storable a => Shaped sh (Primitive a) -> VS.Vector a stoVectorP = coerce mtoVectorP +{-# INLINEABLE stoVector #-} stoVector :: PrimElt a => Shaped sh a -> VS.Vector a stoVector = coerce mtoVector @@ -261,6 +269,7 @@ sminIndexPrim (Shaped arr) = ixsFromIxX (mminIndexPrim arr) smaxIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh smaxIndexPrim (Shaped arr) = ixsFromIxX (mmaxIndexPrim arr) +{-# INLINEABLE sdot1Inner #-} sdot1Inner :: forall sh n a. (PrimElt a, NumElt a) => Proxy n -> Shaped (sh ++ '[n]) a -> Shaped (sh ++ '[n]) a -> Shaped sh a sdot1Inner Proxy sarr1@(Shaped arr1) (Shaped arr2) @@ -272,6 +281,7 @@ sdot1Inner Proxy sarr1@(Shaped arr1) (Shaped arr2) -> Shaped (mdot1Inner (Proxy @(Just n)) arr1 arr2) _ -> error "unreachable" +{-# INLINE sdot #-} -- | This has a temporary, suboptimal implementation in terms of 'mflatten'. -- Prefer 'sdot1Inner' if applicable. sdot :: (PrimElt a, NumElt a) => Shaped sh a -> Shaped sh a -> a |
