aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Shaped.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Shaped.hs')
-rw-r--r--src/Data/Array/Nested/Shaped.hs24
1 files changed, 17 insertions, 7 deletions
diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs
index 99ad590..36ef24a 100644
--- a/src/Data/Array/Nested/Shaped.hs
+++ b/src/Data/Array/Nested/Shaped.hs
@@ -56,7 +56,7 @@ ssize = shsSize . sshape
sindex :: Elt a => Shaped sh a -> IIxS sh -> a
sindex (Shaped arr) idx = mindex arr (ixxFromIxS idx)
-shsTakeIx :: Proxy sh' -> ShS (sh ++ sh') -> IIxS sh -> ShS sh
+shsTakeIx :: Proxy sh' -> ShS (sh ++ sh') -> IxS sh i -> ShS sh
shsTakeIx _ _ ZIS = ZSS
shsTakeIx p sh (_ :.$ idx) = case sh of n :$$ sh' -> n :$$ shsTakeIx p sh' idx
@@ -70,7 +70,7 @@ sindexPartial sarr@(Shaped arr) idx =
-- | __WARNING__: All values returned from the function must have equal shape.
-- See the documentation of 'mgenerate' for more details.
sgenerate :: forall sh a. KnownElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a
-sgenerate sh f = Shaped (mgenerate (shxFromShS sh) (f . ixsFromIxX sh))
+sgenerate sh f = Shaped (mgenerate (shxFromShS sh) (f . ixsFromIxX))
-- | See 'mgeneratePrim'.
{-# INLINE sgeneratePrim #-}
@@ -81,6 +81,7 @@ sgeneratePrim sh f =
in sfromVector sh $ VS.generate (shsSize sh) g
-- | See the documentation of 'mlift'.
+{-# INLINE slift #-}
slift :: forall sh1 sh2 a. Elt a
=> ShS sh2
-> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b)
@@ -88,23 +89,28 @@ slift :: forall sh1 sh2 a. Elt a
slift sh2 f (Shaped arr) = Shaped (mlift (ssxFromShX (shxFromShS sh2)) f arr)
-- | See the documentation of 'mlift'.
+{-# INLINE slift2 #-}
slift2 :: forall sh1 sh2 sh3 a. Elt a
=> ShS sh3
-> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b -> XArray (MapJust sh3 ++ sh') b)
-> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a
slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (ssxFromShX (shxFromShS sh3)) f arr1 arr2)
+{-# INLINE ssumOuter1PrimP #-}
ssumOuter1PrimP :: forall sh n a. (Storable a, NumElt a)
=> Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a)
ssumOuter1PrimP (Shaped arr) = Shaped (msumOuter1PrimP arr)
+{-# INLINEABLE ssumOuter1Prim #-}
ssumOuter1Prim :: forall sh n a. (NumElt a, PrimElt a)
=> Shaped (n : sh) a -> Shaped sh a
ssumOuter1Prim = sfromPrimitive . ssumOuter1PrimP . stoPrimitive
+{-# INLINE ssumAllPrimP #-}
ssumAllPrimP :: (PrimElt a, NumElt a) => Shaped n (Primitive a) -> a
ssumAllPrimP (Shaped arr) = msumAllPrimP arr
+{-# INLINE ssumAllPrim #-}
ssumAllPrim :: (PrimElt a, NumElt a) => Shaped n a -> a
ssumAllPrim (Shaped arr) = msumAllPrim arr
@@ -124,15 +130,19 @@ sappend = coerce mappend
sscalar :: Elt a => a -> Shaped '[] a
sscalar x = Shaped (mscalar x)
+{-# INLINEABLE sfromVectorP #-}
sfromVectorP :: Storable a => ShS sh -> VS.Vector a -> Shaped sh (Primitive a)
sfromVectorP sh v = Shaped (mfromVectorP (shxFromShS sh) v)
+{-# INLINEABLE sfromVector #-}
sfromVector :: PrimElt a => ShS sh -> VS.Vector a -> Shaped sh a
sfromVector sh v = sfromPrimitive (sfromVectorP sh v)
+{-# INLINEABLE stoVectorP #-}
stoVectorP :: Storable a => Shaped sh (Primitive a) -> VS.Vector a
stoVectorP = coerce mtoVectorP
+{-# INLINEABLE stoVector #-}
stoVector :: PrimElt a => Shaped sh a -> VS.Vector a
stoVector = coerce mtoVector
@@ -246,21 +256,20 @@ sreshape :: (Elt a, Product sh ~ Product sh') => ShS sh' -> Shaped sh a -> Shape
sreshape sh' (Shaped arr) = Shaped (mreshape (shxFromShS sh') arr)
sflatten :: Elt a => Shaped sh a -> Shaped '[Product sh] a
-sflatten arr =
- case shsProduct (sshape arr) of -- TODO: simplify when removing the KnownNat stuff
- n@SNat -> sreshape (n :$$ ZSS) arr
+sflatten arr = sreshape (shsProduct (sshape arr) :$$ ZSS) arr
siota :: (Enum a, PrimElt a) => SNat n -> Shaped '[n] a
siota sn = Shaped (miota sn)
-- | Throws if the array is empty.
sminIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh
-sminIndexPrim sarr@(Shaped arr) = ixsFromIxX (sshape (stoPrimitive sarr)) (mminIndexPrim arr)
+sminIndexPrim (Shaped arr) = ixsFromIxX (mminIndexPrim arr)
-- | Throws if the array is empty.
smaxIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh
-smaxIndexPrim sarr@(Shaped arr) = ixsFromIxX (sshape (stoPrimitive sarr)) (mmaxIndexPrim arr)
+smaxIndexPrim (Shaped arr) = ixsFromIxX (mmaxIndexPrim arr)
+{-# INLINEABLE sdot1Inner #-}
sdot1Inner :: forall sh n a. (PrimElt a, NumElt a)
=> Proxy n -> Shaped (sh ++ '[n]) a -> Shaped (sh ++ '[n]) a -> Shaped sh a
sdot1Inner Proxy sarr1@(Shaped arr1) (Shaped arr2)
@@ -272,6 +281,7 @@ sdot1Inner Proxy sarr1@(Shaped arr1) (Shaped arr2)
-> Shaped (mdot1Inner (Proxy @(Just n)) arr1 arr2)
_ -> error "unreachable"
+{-# INLINE sdot #-}
-- | This has a temporary, suboptimal implementation in terms of 'mflatten'.
-- Prefer 'sdot1Inner' if applicable.
sdot :: (PrimElt a, NumElt a) => Shaped sh a -> Shaped sh a -> a