aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Shaped.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-11-17 22:16:11 +0100
committerTom Smeding <tom@tomsmeding.com>2025-11-17 22:21:31 +0100
commitba798129655503d7e69de271d956cceaef4cef56 (patch)
treea906aaf03581060307a18d8347b8e4e5cef82a54 /src/Data/Array/Nested/Shaped.hs
parent0766e22df98179ce7debb179e544716bccfbca24 (diff)
Provide explicit-length versions of fromList functions
Diffstat (limited to 'src/Data/Array/Nested/Shaped.hs')
-rw-r--r--src/Data/Array/Nested/Shaped.hs40
1 files changed, 26 insertions, 14 deletions
diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs
index 198a068..4a3ed8d 100644
--- a/src/Data/Array/Nested/Shaped.hs
+++ b/src/Data/Array/Nested/Shaped.hs
@@ -123,26 +123,38 @@ stoVectorP = coerce mtoVectorP
stoVector :: PrimElt a => Shaped sh a -> VS.Vector a
stoVector = coerce mtoVector
-sfromList1 :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a
-sfromList1 sn = Shaped . mcast (SKnown sn :!% ZKX) . mfromList1
-
+-- | All arrays in the list, even subarrays inside @a@, must have the same
+-- shape; if they do not, a runtime error will be thrown. See the
+-- documentation of 'mgenerate' for more information about this restriction.
+--
+-- Because the length of the 'NonEmpty' list is unknown, its spine must be
+-- materialised in memory in order to compute its length. If its length is
+-- already known, use 'sfromListOuterSN' to be able to stream the list.
+--
+-- If your array is 1-dimensional and contains scalars, use 'sfromList1Prim'.
sfromListOuter :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a
-sfromListOuter sn l = Shaped (mcastPartial (SUnknown () :!% ZKX) (SKnown sn :!% ZKX) Proxy $ mfromListOuter (coerce l))
+sfromListOuter = coerce mfromListOuterSN
+
+-- | Because the length of the 'NonEmpty' list is unknown, its spine must be
+-- materialised in memory in order to compute its length. If its length is
+-- already known, use 'sfromList1SN' to be able to stream the list.
+--
+-- If the elements are scalars, 'sfromList1Prim' is faster.
+sfromList1 :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a
+sfromList1 = coerce mfromList1SN
+-- | If the elements are scalars, 'sfromListPrimLinear' is faster.
sfromListLinear :: forall sh a. Elt a => ShS sh -> NonEmpty a -> Shaped sh a
sfromListLinear sh l = Shaped (mfromListLinear (shxFromShS sh) l)
-sfromListPrim :: forall n a. PrimElt a => SNat n -> [a] -> Shaped '[n] a
-sfromListPrim sn l
- | Refl <- lemAppNil @'[Just n]
- = let ssh = SUnknown () :!% ZKX
- xarr = X.cast ssh (SKnown sn :$% ZSX) ZKX (X.fromList1 ssh l)
- in Shaped $ fromPrimitive $ M_Primitive (X.shape (SKnown sn :!% ZKX) xarr) xarr
+-- | Because the length of the list is unknown, its spine must be materialised
+-- in memory in order to compute its length. If its length is already known,
+-- use 'sfromList1PrimN' to be able to stream the list.
+sfromList1Prim :: forall n a. PrimElt a => SNat n -> [a] -> Shaped '[n] a
+sfromList1Prim = coerce mfromList1PrimSN
-sfromListPrimLinear :: PrimElt a => ShS sh -> [a] -> Shaped sh a
-sfromListPrimLinear sh l =
- let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
- in Shaped $ fromPrimitive $ M_Primitive (shxFromShS sh) (X.reshape (SUnknown () :!% ZKX) (shxFromShS sh) xarr)
+sfromListPrimLinear :: forall sh a. PrimElt a => ShS sh -> [a] -> Shaped sh a
+sfromListPrimLinear sh l = Shaped (mfromListPrimLinear (shxFromShS sh) l)
stoList :: Elt a => Shaped '[n] a -> [a]
stoList = map sunScalar . stoListOuter