aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Ranked.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Ranked.hs')
-rw-r--r--src/Data/Array/Nested/Ranked.hs57
1 files changed, 43 insertions, 14 deletions
diff --git a/src/Data/Array/Nested/Ranked.hs b/src/Data/Array/Nested/Ranked.hs
index 8b95d0f..5cda531 100644
--- a/src/Data/Array/Nested/Ranked.hs
+++ b/src/Data/Array/Nested/Ranked.hs
@@ -137,24 +137,55 @@ rtoVectorP = coerce mtoVectorP
rtoVector :: PrimElt a => Ranked n a -> VS.Vector a
rtoVector = coerce mtoVector
-rfromList1 :: Elt a => NonEmpty a -> Ranked 1 a
-rfromList1 l = Ranked (mfromList1 l)
-
+-- | All arrays in the list, even subarrays inside @a@, must have the same
+-- shape; if they do not, a runtime error will be thrown. See the
+-- documentation of 'mgenerate' for more information about this restriction.
+--
+-- Because the length of the 'NonEmpty' list is unknown, its spine must be
+-- materialised in memory in order to compute its length. If its length is
+-- already known, use 'rfromListOuterN' to be able to stream the list.
+--
+-- If your array is 1-dimensional and contains scalars, use 'rfromList1Prim'.
rfromListOuter :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a
rfromListOuter l
| Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
= Ranked (mfromListOuter (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a)))
+-- | See 'rfromListOuter'. If the list does not have the given length, a
+-- runtime error is thrown. 'rfromList1PrimN' is faster if applicable.
+rfromListOuterN :: forall n a. Elt a => Int -> NonEmpty (Ranked n a) -> Ranked (n + 1) a
+rfromListOuterN n l
+ | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
+ = Ranked (mfromListOuterN n (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a)))
+
+-- | Because the length of the 'NonEmpty' list is unknown, its spine must be
+-- materialised in memory in order to compute its length. If its length is
+-- already known, use 'rfromList1N' to be able to stream the list.
+--
+-- If the elements are scalars, 'rfromList1Prim' is faster.
+rfromList1 :: Elt a => NonEmpty a -> Ranked 1 a
+rfromList1 = coerce mfromList1
+
+-- | If the elements are scalars, 'rfromList1PrimN' is faster. A runtime error
+-- is thrown if the list length does not match the given length.
+rfromList1N :: Elt a => Int -> NonEmpty a -> Ranked 1 a
+rfromList1N = coerce mfromList1N
+
+-- | If the elements are scalars, 'rfromListPrimLinear' is faster.
rfromListLinear :: forall n a. Elt a => IShR n -> NonEmpty a -> Ranked n a
-rfromListLinear sh l = rreshape sh (rfromList1 l)
+rfromListLinear sh l = Ranked (mfromListLinear (shxFromShR sh) l)
+
+-- | Because the length of the list is unknown, its spine must be materialised
+-- in memory in order to compute its length. If its length is already known,
+-- use 'rfromList1PrimN' to be able to stream the list.
+rfromList1Prim :: PrimElt a => [a] -> Ranked 1 a
+rfromList1Prim = coerce mfromList1Prim
-rfromListPrim :: PrimElt a => [a] -> Ranked 1 a
-rfromListPrim l = Ranked (mfromListPrim l)
+rfromList1PrimN :: PrimElt a => Int -> [a] -> Ranked 1 a
+rfromList1PrimN = coerce mfromList1PrimN
-rfromListPrimLinear :: PrimElt a => IShR n -> [a] -> Ranked n a
-rfromListPrimLinear sh l =
- let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
- in Ranked $ fromPrimitive $ M_Primitive (shxFromShR sh) (X.reshape (SUnknown () :!% ZKX) (shxFromShR sh) xarr)
+rfromListPrimLinear :: forall n a. PrimElt a => IShR n -> [a] -> Ranked n a
+rfromListPrimLinear sh l = Ranked (mfromListPrimLinear (shxFromShR sh) l)
rtoList :: Elt a => Ranked 1 a -> [a]
rtoList = map runScalar . rtoListOuter
@@ -254,11 +285,9 @@ rreplicateScal :: forall n a. PrimElt a
rreplicateScal sh x = rfromPrimitive (rreplicateScalP sh x)
rslice :: forall n a. Elt a => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a
-rslice i n arr
+rslice i n (Ranked arr)
| Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
- = rlift (rrank arr)
- (\_ -> X.sliceU i n)
- arr
+ = Ranked (msliceN i n arr)
rrev1 :: forall n a. Elt a => Ranked (n + 1) a -> Ranked (n + 1) a
rrev1 arr =