aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Ranked.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Ranked.hs')
-rw-r--r--src/Data/Array/Nested/Ranked.hs18
1 files changed, 15 insertions, 3 deletions
diff --git a/src/Data/Array/Nested/Ranked.hs b/src/Data/Array/Nested/Ranked.hs
index d687983..b448685 100644
--- a/src/Data/Array/Nested/Ranked.hs
+++ b/src/Data/Array/Nested/Ranked.hs
@@ -79,6 +79,7 @@ rgeneratePrim sh f =
in rfromVector sh $ VS.generate (shrSize sh) g
-- | See the documentation of 'mlift'.
+{-# INLINE rlift #-}
rlift :: forall n1 n2 a. Elt a
=> SNat n2
-> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b)
@@ -86,12 +87,14 @@ rlift :: forall n1 n2 a. Elt a
rlift sn2 f (Ranked arr) = Ranked (mlift (ssxFromSNat sn2) f arr)
-- | See the documentation of 'mlift2'.
+{-# INLINE rlift2 #-}
rlift2 :: forall n1 n2 n3 a. Elt a
=> SNat n3
-> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b -> XArray (Replicate n3 Nothing ++ sh') b)
-> Ranked n1 a -> Ranked n2 a -> Ranked n3 a
rlift2 sn3 f (Ranked arr1) (Ranked arr2) = Ranked (mlift2 (ssxFromSNat sn3) f arr1 arr2)
+{-# INLINE rsumOuter1PrimP #-}
rsumOuter1PrimP :: forall n a.
(Storable a, NumElt a)
=> Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a)
@@ -99,13 +102,16 @@ rsumOuter1PrimP (Ranked arr)
| Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
= Ranked (msumOuter1PrimP arr)
+{-# INLINEABLE rsumOuter1Prim #-}
rsumOuter1Prim :: forall n a. (NumElt a, PrimElt a)
=> Ranked (n + 1) a -> Ranked n a
rsumOuter1Prim = rfromPrimitive . rsumOuter1PrimP . rtoPrimitive
+{-# INLINE rsumAllPrimP #-}
rsumAllPrimP :: (Storable a, NumElt a) => Ranked n (Primitive a) -> a
rsumAllPrimP (Ranked arr) = msumAllPrimP arr
+{-# INLINE rsumAllPrim #-}
rsumAllPrim :: (PrimElt a, NumElt a) => Ranked n a -> a
rsumAllPrim (Ranked arr) = msumAllPrim arr
@@ -137,17 +143,21 @@ rappend arr1 arr2
rscalar :: Elt a => a -> Ranked 0 a
rscalar x = Ranked (mscalar x)
+{-# INLINEABLE rfromVectorP #-}
rfromVectorP :: forall n a. Storable a => IShR n -> VS.Vector a -> Ranked n (Primitive a)
rfromVectorP sh v
| Dict <- lemKnownReplicate (shrRank sh)
= Ranked (mfromVectorP (shxFromShR sh) v)
+{-# INLINEABLE rfromVector #-}
rfromVector :: forall n a. PrimElt a => IShR n -> VS.Vector a -> Ranked n a
rfromVector sh v = rfromPrimitive (rfromVectorP sh v)
+{-# INLINEABLE rtoVectorP #-}
rtoVectorP :: Storable a => Ranked n (Primitive a) -> VS.Vector a
rtoVectorP = coerce mtoVectorP
+{-# INLINEABLE rtoVector #-}
rtoVector :: PrimElt a => Ranked n a -> VS.Vector a
rtoVector = coerce mtoVector
@@ -220,7 +230,7 @@ rfromOrthotope sn arr
rtoOrthotope :: forall a n. PrimElt a => Ranked n a -> S.Array n a
rtoOrthotope (rtoPrimitive -> Ranked (M_Primitive sh (XArray arr)))
- | Refl <- lemRankReplicate (shrRank $ shrFromShX2 @n sh)
+ | Refl <- lemRankReplicate (shrRank $ shrFromShX @n sh)
= arr
runScalar :: Elt a => Ranked 0 a -> a
@@ -333,6 +343,7 @@ rmaxIndexPrim rarr@(Ranked arr)
| Refl <- lemRankReplicate (rrank (rtoPrimitive rarr))
= ixrFromIxX (mmaxIndexPrim arr)
+{-# INLINEABLE rdot1Inner #-}
rdot1Inner :: forall n a. (PrimElt a, NumElt a) => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked n a
rdot1Inner arr1 arr2
| SNat <- rrank arr1
@@ -341,14 +352,15 @@ rdot1Inner arr1 arr2
-- | This has a temporary, suboptimal implementation in terms of 'mflatten'.
-- Prefer 'rdot1Inner' if applicable.
+{-# INLINE rdot #-}
rdot :: (PrimElt a, NumElt a) => Ranked n a -> Ranked n a -> a
rdot = coerce mdot
rtoXArrayPrimP :: Ranked n (Primitive a) -> (IShR n, XArray (Replicate n Nothing) a)
-rtoXArrayPrimP (Ranked arr) = first shrFromShX2 (mtoXArrayPrimP arr)
+rtoXArrayPrimP (Ranked arr) = first shrFromShX (mtoXArrayPrimP arr)
rtoXArrayPrim :: PrimElt a => Ranked n a -> (IShR n, XArray (Replicate n Nothing) a)
-rtoXArrayPrim (Ranked arr) = first shrFromShX2 (mtoXArrayPrim arr)
+rtoXArrayPrim (Ranked arr) = first shrFromShX (mtoXArrayPrim arr)
rfromXArrayPrimP :: SNat n -> XArray (Replicate n Nothing) a -> Ranked n (Primitive a)
rfromXArrayPrimP sn arr = Ranked (mfromXArrayPrimP (ssxFromShX (X.shape (ssxFromSNat sn) arr)) arr)