aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Shaped.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Shaped.hs')
-rw-r--r--src/Data/Array/Nested/Shaped.hs122
1 files changed, 72 insertions, 50 deletions
diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs
index 01982a8..99ad590 100644
--- a/src/Data/Array/Nested/Shaped.hs
+++ b/src/Data/Array/Nested/Shaped.hs
@@ -42,7 +42,7 @@ import Data.Array.XArray (XArray)
import Data.Array.XArray qualified as X
-semptyArray :: KnownElt a => ShS sh -> Shaped (0 : sh) a
+semptyArray :: forall sh a. KnownElt a => ShS sh -> Shaped (0 : sh) a
semptyArray sh = Shaped (memptyArray (shxFromShS sh))
srank :: Elt a => Shaped sh a -> SNat (Rank sh)
@@ -52,6 +52,7 @@ srank = shsRank . sshape
ssize :: Elt a => Shaped sh a -> Int
ssize = shsSize . sshape
+{-# INLINEABLE sindex #-}
sindex :: Elt a => Shaped sh a -> IIxS sh -> a
sindex (Shaped arr) idx = mindex arr (ixxFromIxS idx)
@@ -59,6 +60,7 @@ shsTakeIx :: Proxy sh' -> ShS (sh ++ sh') -> IIxS sh -> ShS sh
shsTakeIx _ _ ZIS = ZSS
shsTakeIx p sh (_ :.$ idx) = case sh of n :$$ sh' -> n :$$ shsTakeIx p sh' idx
+{-# INLINEABLE sindexPartial #-}
sindexPartial :: forall sh1 sh2 a. Elt a => Shaped (sh1 ++ sh2) a -> IIxS sh1 -> Shaped sh2 a
sindexPartial sarr@(Shaped arr) idx =
Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2)
@@ -70,6 +72,14 @@ sindexPartial sarr@(Shaped arr) idx =
sgenerate :: forall sh a. KnownElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a
sgenerate sh f = Shaped (mgenerate (shxFromShS sh) (f . ixsFromIxX sh))
+-- | See 'mgeneratePrim'.
+{-# INLINE sgeneratePrim #-}
+sgeneratePrim :: forall sh a i. (PrimElt a, Num i)
+ => ShS sh -> (IxS sh i -> a) -> Shaped sh a
+sgeneratePrim sh f =
+ let g i = f (ixsFromLinear sh i)
+ in sfromVector sh $ VS.generate (shsSize sh) g
+
-- | See the documentation of 'mlift'.
slift :: forall sh1 sh2 a. Elt a
=> ShS sh2
@@ -84,13 +94,16 @@ slift2 :: forall sh1 sh2 sh3 a. Elt a
-> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a
slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (ssxFromShX (shxFromShS sh3)) f arr1 arr2)
-ssumOuter1P :: forall sh n a. (Storable a, NumElt a)
- => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a)
-ssumOuter1P (Shaped arr) = Shaped (msumOuter1P arr)
+ssumOuter1PrimP :: forall sh n a. (Storable a, NumElt a)
+ => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a)
+ssumOuter1PrimP (Shaped arr) = Shaped (msumOuter1PrimP arr)
+
+ssumOuter1Prim :: forall sh n a. (NumElt a, PrimElt a)
+ => Shaped (n : sh) a -> Shaped sh a
+ssumOuter1Prim = sfromPrimitive . ssumOuter1PrimP . stoPrimitive
-ssumOuter1 :: forall sh n a. (NumElt a, PrimElt a)
- => Shaped (n : sh) a -> Shaped sh a
-ssumOuter1 = sfromPrimitive . ssumOuter1P . stoPrimitive
+ssumAllPrimP :: (PrimElt a, NumElt a) => Shaped n (Primitive a) -> a
+ssumAllPrimP (Shaped arr) = msumAllPrimP arr
ssumAllPrim :: (PrimElt a, NumElt a) => Shaped n a -> a
ssumAllPrim (Shaped arr) = msumAllPrim arr
@@ -123,35 +136,44 @@ stoVectorP = coerce mtoVectorP
stoVector :: PrimElt a => Shaped sh a -> VS.Vector a
stoVector = coerce mtoVector
+-- | All arrays in the list, even subarrays inside @a@, must have the same
+-- shape; if they do not, a runtime error will be thrown. See the
+-- documentation of 'mgenerate' for more information about this restriction.
+--
+-- Because the length of the 'NonEmpty' list is unknown, its spine must be
+-- materialised in memory in order to compute its length. If its length is
+-- already known, use 'sfromListOuterSN' to be able to stream the list.
+--
+-- If your array is 1-dimensional and contains scalars, use 'sfromList1Prim'.
sfromListOuter :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a
-sfromListOuter sn l = Shaped (mcastPartial (SUnknown () :!% ZKX) (SKnown sn :!% ZKX) Proxy $ mfromListOuter (coerce l))
+sfromListOuter = coerce mfromListOuterSN
+-- | Because the length of the 'NonEmpty' list is unknown, its spine must be
+-- materialised in memory in order to compute its length. If its length is
+-- already known, use 'sfromList1SN' to be able to stream the list.
+--
+-- If the elements are scalars, 'sfromList1Prim' is faster.
sfromList1 :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a
-sfromList1 sn = Shaped . mcast (SKnown sn :!% ZKX) . mfromList1
-
-sfromList1Prim :: PrimElt a => SNat n -> [a] -> Shaped '[n] a
-sfromList1Prim sn = Shaped . mcast (SKnown sn :!% ZKX) . mfromList1Prim
+sfromList1 = coerce mfromList1SN
-stoListOuter :: Elt a => Shaped (n : sh) a -> [Shaped sh a]
-stoListOuter (Shaped arr) = coerce (mtoListOuter arr)
+-- | If the elements are scalars, 'sfromListPrimLinear' is faster.
+sfromListLinear :: forall sh a. Elt a => ShS sh -> NonEmpty a -> Shaped sh a
+sfromListLinear sh l = Shaped (mfromListLinear (shxFromShS sh) l)
-stoList1 :: Elt a => Shaped '[n] a -> [a]
-stoList1 = map sunScalar . stoListOuter
+-- | Because the length of the list is unknown, its spine must be materialised
+-- in memory in order to compute its length. If its length is already known,
+-- use 'sfromList1PrimN' to be able to stream the list.
+sfromList1Prim :: forall n a. PrimElt a => SNat n -> [a] -> Shaped '[n] a
+sfromList1Prim = coerce mfromList1PrimSN
-sfromListPrim :: forall n a. PrimElt a => SNat n -> [a] -> Shaped '[n] a
-sfromListPrim sn l
- | Refl <- lemAppNil @'[Just n]
- = let ssh = SUnknown () :!% ZKX
- xarr = X.cast ssh (SKnown sn :$% ZSX) ZKX (X.fromList1 ssh l)
- in Shaped $ fromPrimitive $ M_Primitive (X.shape (SKnown sn :!% ZKX) xarr) xarr
+sfromListPrimLinear :: forall sh a. PrimElt a => ShS sh -> [a] -> Shaped sh a
+sfromListPrimLinear sh l = Shaped (mfromListPrimLinear (shxFromShS sh) l)
-sfromListPrimLinear :: PrimElt a => ShS sh -> [a] -> Shaped sh a
-sfromListPrimLinear sh l =
- let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
- in Shaped $ fromPrimitive $ M_Primitive (shxFromShS sh) (X.reshape (SUnknown () :!% ZKX) (shxFromShS sh) xarr)
+stoList :: Elt a => Shaped '[n] a -> [a]
+stoList = map sunScalar . stoListOuter
-sfromListLinear :: forall sh a. Elt a => ShS sh -> NonEmpty a -> Shaped sh a
-sfromListLinear sh l = Shaped (mfromListLinear (shxFromShS sh) l)
+stoListOuter :: Elt a => Shaped (n : sh) a -> [Shaped sh a]
+stoListOuter (Shaped arr) = coerce (mtoListOuter arr)
stoListLinear :: Elt a => Shaped sh a -> [a]
stoListLinear (Shaped arr) = mtoListLinear arr
@@ -176,41 +198,41 @@ sunNest sarr@(Shaped (M_Shaped (M_Nest _ arr)))
| Refl <- lemMapJustApp (sshape sarr) (Proxy @sh')
= Shaped arr
-szip :: Shaped sh a -> Shaped sh b -> Shaped sh (a, b)
+szip :: (Elt a, Elt b) => Shaped sh a -> Shaped sh b -> Shaped sh (a, b)
szip = coerce mzip
sunzip :: Shaped sh (a, b) -> (Shaped sh a, Shaped sh b)
sunzip = coerce munzip
-srerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b)
- => ShS sh -> ShS sh2
- -> (Shaped sh1 (Primitive a) -> Shaped sh2 (Primitive b))
- -> Shaped (sh ++ sh1) (Primitive a) -> Shaped (sh ++ sh2) (Primitive b)
-srerankP sh sh2 f sarr@(Shaped arr)
- | Refl <- lemMapJustApp sh (Proxy @sh1)
- , Refl <- lemMapJustApp sh (Proxy @sh2)
- = Shaped (mrerankP (ssxFromShX (shxTakeSSX (Proxy @(MapJust sh1)) (shxFromShS (sshape sarr)) (ssxFromShX (shxFromShS sh))))
- (shxFromShS sh2)
- (\a -> let Shaped r = f (Shaped a) in r)
- arr)
+srerankPrimP :: forall sh1 sh2 sh a b. (Storable a, Storable b)
+ => ShS sh2
+ -> (Shaped sh1 (Primitive a) -> Shaped sh2 (Primitive b))
+ -> Shaped sh (Shaped sh1 (Primitive a)) -> Shaped sh (Shaped sh2 (Primitive b))
+srerankPrimP sh2 f (Shaped (M_Shaped arr))
+ = Shaped (M_Shaped (mrerankPrimP (shxFromShS sh2)
+ (\a -> let Shaped r = f (Shaped a) in r)
+ arr))
-srerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b)
- => ShS sh -> ShS sh2
- -> (Shaped sh1 a -> Shaped sh2 b)
- -> Shaped (sh ++ sh1) a -> Shaped (sh ++ sh2) b
-srerank sh sh2 f (stoPrimitive -> arr) =
- sfromPrimitive $ srerankP sh sh2 (stoPrimitive . f . sfromPrimitive) arr
+-- | See the caveats at 'mrerankPrim'.
+srerankPrim :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b)
+ => ShS sh2
+ -> (Shaped sh1 a -> Shaped sh2 b)
+ -> Shaped sh (Shaped sh1 a) -> Shaped sh (Shaped sh2 b)
+srerankPrim sh2 f (Shaped (M_Shaped arr)) =
+ Shaped (M_Shaped (mrerankPrim (shxFromShS sh2)
+ (\a -> let Shaped r = f (Shaped a) in r)
+ arr))
sreplicate :: forall sh sh' a. Elt a => ShS sh -> Shaped sh' a -> Shaped (sh ++ sh') a
sreplicate sh (Shaped arr)
| Refl <- lemMapJustApp sh (Proxy @sh')
= Shaped (mreplicate (shxFromShS sh) arr)
-sreplicateScalP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a)
-sreplicateScalP sh x = Shaped (mreplicateScalP (shxFromShS sh) x)
+sreplicatePrimP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a)
+sreplicatePrimP sh x = Shaped (mreplicatePrimP (shxFromShS sh) x)
-sreplicateScal :: PrimElt a => ShS sh -> a -> Shaped sh a
-sreplicateScal sh x = sfromPrimitive (sreplicateScalP sh x)
+sreplicatePrim :: forall sh a. PrimElt a => ShS sh -> a -> Shaped sh a
+sreplicatePrim sh x = sfromPrimitive (sreplicatePrimP sh x)
sslice :: Elt a => SNat i -> SNat n -> Shaped (i + n + k : sh) a -> Shaped (n : sh) a
sslice i n@SNat arr =