aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-06-04 14:51:27 +0200
committerTom Smeding <tom@tomsmeding.com>2025-06-04 14:51:27 +0200
commit8b698856bdef15def2681ee9cc97a4f5d6d52d54 (patch)
treec472d66e53441102ef51ad2be554d1340ced5ca4 /src
parent8bc791a6b5a725e1fa3699a2c260eacb51a4e5fa (diff)
Reorganise and clean up {from,to}List functions
Diffstat (limited to 'src')
-rw-r--r--src/Data/Array/Nested.hs15
-rw-r--r--src/Data/Array/Nested/Mixed.hs18
-rw-r--r--src/Data/Array/Nested/Ranked.hs32
-rw-r--r--src/Data/Array/Nested/Shaped.hs21
4 files changed, 37 insertions, 49 deletions
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index 1ad2559..bb22d29 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -10,8 +10,9 @@ module Data.Array.Nested (
rtranspose, rappend, rconcat, rscalar, rfromVector, rtoVector, runScalar,
remptyArray,
rrerank,
- rreplicate, rreplicateScal, rfromListOuter, rfromList1, rfromList1Prim, rtoListOuter, rtoList1,
- rfromListLinear, rfromListPrimLinear, rtoListLinear,
+ rreplicate, rreplicateScal,
+ rfromList1, rfromListOuter, rfromListLinear, rfromListPrim, rfromListPrimLinear,
+ rtoList, rtoListOuter, rtoListLinear,
rslice, rrev1, rreshape, rflatten, riota,
rminIndexPrim, rmaxIndexPrim, rdot1Inner, rdot,
rnest, runNest, rzip, runzip,
@@ -36,8 +37,9 @@ module Data.Array.Nested (
-- TODO: sconcat? What should its type be?
semptyArray,
srerank,
- sreplicate, sreplicateScal, sfromListOuter, sfromList1, sfromList1Prim, stoListOuter, stoList1,
- sfromListLinear, sfromListPrimLinear, stoListLinear,
+ sreplicate, sreplicateScal,
+ sfromList1, sfromListOuter, sfromListLinear, sfromListPrim, sfromListPrimLinear,
+ stoList, stoListOuter, stoListLinear,
sslice, srev1, sreshape, sflatten, siota,
sminIndexPrim, smaxIndexPrim, sdot1Inner, sdot,
snest, sunNest, szip, sunzip,
@@ -63,8 +65,9 @@ module Data.Array.Nested (
mtranspose, mappend, mconcat, mscalar, mfromVector, mtoVector, munScalar,
memptyArray,
mrerank,
- mreplicate, mreplicateScal, mfromListOuter, mfromList1, mfromList1Prim, mtoListOuter, mtoList1,
- mfromListLinear, mfromListPrimLinear, mtoListLinear,
+ mreplicate, mreplicateScal,
+ mfromList1, mfromListOuter, mfromListLinear, mfromListPrim, mfromListPrimLinear,
+ mtoList, mtoListOuter, mtoListLinear,
mslice, mrev1, mreshape, mflatten, miota,
mminIndexPrim, mmaxIndexPrim, mdot1Inner, mdot,
mnest, munNest, mzip, munzip,
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs
index 221393f..9ec8d9d 100644
--- a/src/Data/Array/Nested/Mixed.hs
+++ b/src/Data/Array/Nested/Mixed.hs
@@ -784,14 +784,10 @@ mtoVector arr = mtoVectorP (toPrimitive arr)
mfromList1 :: Elt a => NonEmpty a -> Mixed '[Nothing] a
mfromList1 = mfromListOuter . fmap mscalar -- TODO: optimise?
-mfromList1Prim :: PrimElt a => [a] -> Mixed '[Nothing] a
-mfromList1Prim l =
- let ssh = SUnknown () :!% ZKX
- xarr = X.fromList1 ssh l
- in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr
-
-mtoList1 :: Elt a => Mixed '[n] a -> [a]
-mtoList1 = map munScalar . mtoListOuter
+-- This forall is there so that a simple type application can constrain the
+-- shape, in case the user wants to use OverloadedLists for the shape.
+mfromListLinear :: forall sh a. Elt a => IShX sh -> NonEmpty a -> Mixed sh a
+mfromListLinear sh l = mreshape sh (mfromList1 l)
mfromListPrim :: PrimElt a => [a] -> Mixed '[Nothing] a
mfromListPrim l =
@@ -804,10 +800,8 @@ mfromListPrimLinear sh l =
let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
in fromPrimitive $ M_Primitive sh (X.reshape (SUnknown () :!% ZKX) sh xarr)
--- This forall is there so that a simple type application can constrain the
--- shape, in case the user wants to use OverloadedLists for the shape.
-mfromListLinear :: forall sh a. Elt a => IShX sh -> NonEmpty a -> Mixed sh a
-mfromListLinear sh l = mreshape sh (mfromList1 l)
+mtoList :: Elt a => Mixed '[n] a -> [a]
+mtoList = map munScalar . mtoListOuter
mtoListLinear :: Elt a => Mixed sh a -> [a]
mtoListLinear arr = map (mindex arr) (shxEnum (mshape arr)) -- TODO: optimise
diff --git a/src/Data/Array/Nested/Ranked.hs b/src/Data/Array/Nested/Ranked.hs
index e5b8970..8591af7 100644
--- a/src/Data/Array/Nested/Ranked.hs
+++ b/src/Data/Array/Nested/Ranked.hs
@@ -137,38 +137,32 @@ rtoVectorP = coerce mtoVectorP
rtoVector :: PrimElt a => Ranked n a -> VS.Vector a
rtoVector = coerce mtoVector
-rfromListOuter :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a
-rfromListOuter l
- | Refl <- lemReplicateSucc @(Nothing @Nat) @n
- = Ranked (mfromListOuter (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a)))
-
rfromList1 :: Elt a => NonEmpty a -> Ranked 1 a
rfromList1 l = Ranked (mfromList1 l)
-rfromList1Prim :: PrimElt a => [a] -> Ranked 1 a
-rfromList1Prim l = Ranked (mfromList1Prim l)
-
-rtoListOuter :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a]
-rtoListOuter (Ranked arr)
+rfromListOuter :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a
+rfromListOuter l
| Refl <- lemReplicateSucc @(Nothing @Nat) @n
- = coerce (mtoListOuter @a @Nothing @(Replicate n Nothing) arr)
+ = Ranked (mfromListOuter (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a)))
-rtoList1 :: Elt a => Ranked 1 a -> [a]
-rtoList1 = map runScalar . rtoListOuter
+rfromListLinear :: forall n a. Elt a => IShR n -> NonEmpty a -> Ranked n a
+rfromListLinear sh l = rreshape sh (rfromList1 l)
rfromListPrim :: PrimElt a => [a] -> Ranked 1 a
-rfromListPrim l =
- let ssh = SUnknown () :!% ZKX
- xarr = X.fromList1 ssh l
- in Ranked $ fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr
+rfromListPrim l = Ranked (mfromListPrim l)
rfromListPrimLinear :: PrimElt a => IShR n -> [a] -> Ranked n a
rfromListPrimLinear sh l =
let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
in Ranked $ fromPrimitive $ M_Primitive (shxFromShR sh) (X.reshape (SUnknown () :!% ZKX) (shxFromShR sh) xarr)
-rfromListLinear :: forall n a. Elt a => IShR n -> NonEmpty a -> Ranked n a
-rfromListLinear sh l = rreshape sh (rfromList1 l)
+rtoList :: Elt a => Ranked 1 a -> [a]
+rtoList = map runScalar . rtoListOuter
+
+rtoListOuter :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a]
+rtoListOuter (Ranked arr)
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @n
+ = coerce (mtoListOuter @a @Nothing @(Replicate n Nothing) arr)
rtoListLinear :: Elt a => Ranked n a -> [a]
rtoListLinear (Ranked arr) = mtoListLinear arr
diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs
index 01982a8..aaba367 100644
--- a/src/Data/Array/Nested/Shaped.hs
+++ b/src/Data/Array/Nested/Shaped.hs
@@ -123,20 +123,14 @@ stoVectorP = coerce mtoVectorP
stoVector :: PrimElt a => Shaped sh a -> VS.Vector a
stoVector = coerce mtoVector
-sfromListOuter :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a
-sfromListOuter sn l = Shaped (mcastPartial (SUnknown () :!% ZKX) (SKnown sn :!% ZKX) Proxy $ mfromListOuter (coerce l))
-
sfromList1 :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a
sfromList1 sn = Shaped . mcast (SKnown sn :!% ZKX) . mfromList1
-sfromList1Prim :: PrimElt a => SNat n -> [a] -> Shaped '[n] a
-sfromList1Prim sn = Shaped . mcast (SKnown sn :!% ZKX) . mfromList1Prim
-
-stoListOuter :: Elt a => Shaped (n : sh) a -> [Shaped sh a]
-stoListOuter (Shaped arr) = coerce (mtoListOuter arr)
+sfromListOuter :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a
+sfromListOuter sn l = Shaped (mcastPartial (SUnknown () :!% ZKX) (SKnown sn :!% ZKX) Proxy $ mfromListOuter (coerce l))
-stoList1 :: Elt a => Shaped '[n] a -> [a]
-stoList1 = map sunScalar . stoListOuter
+sfromListLinear :: forall sh a. Elt a => ShS sh -> NonEmpty a -> Shaped sh a
+sfromListLinear sh l = Shaped (mfromListLinear (shxFromShS sh) l)
sfromListPrim :: forall n a. PrimElt a => SNat n -> [a] -> Shaped '[n] a
sfromListPrim sn l
@@ -150,8 +144,11 @@ sfromListPrimLinear sh l =
let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
in Shaped $ fromPrimitive $ M_Primitive (shxFromShS sh) (X.reshape (SUnknown () :!% ZKX) (shxFromShS sh) xarr)
-sfromListLinear :: forall sh a. Elt a => ShS sh -> NonEmpty a -> Shaped sh a
-sfromListLinear sh l = Shaped (mfromListLinear (shxFromShS sh) l)
+stoList :: Elt a => Shaped '[n] a -> [a]
+stoList = map sunScalar . stoListOuter
+
+stoListOuter :: Elt a => Shaped (n : sh) a -> [Shaped sh a]
+stoListOuter (Shaped arr) = coerce (mtoListOuter arr)
stoListLinear :: Elt a => Shaped sh a -> [a]
stoListLinear (Shaped arr) = mtoListLinear arr