aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Mixed.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-05-20 22:55:59 +0200
committerTom Smeding <tom@tomsmeding.com>2024-05-20 22:55:59 +0200
commitbeec5f9ab3f37da78bb95e2631871b4e47b46c57 (patch)
tree8efeae83646526ec7bdeaeb1674fa62f3cea2676 /src/Data/Array/Mixed.hs
parent8d495b7e6c21fc843f0538711c2203dfb213b7e1 (diff)
Better {from,to}List set
Diffstat (limited to 'src/Data/Array/Mixed.hs')
-rw-r--r--src/Data/Array/Mixed.hs28
1 files changed, 20 insertions, 8 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
index 2f23903..1dc6b58 100644
--- a/src/Data/Array/Mixed.hs
+++ b/src/Data/Array/Mixed.hs
@@ -728,18 +728,30 @@ sumOuter ssh ssh'
| Refl <- lemAppNil @sh
= sumInner ssh' ssh . transpose2 ssh ssh'
-fromList1 :: forall n sh a. Storable a
- => StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a
-fromList1 ssh l
+fromListOuter :: forall n sh a. Storable a
+ => StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a
+fromListOuter ssh l
| Dict <- lemKnownNatRankSSX ssh
= case ssh of
- SKnown m@SNat :!% _ | natVal m /= fromIntegral (length l) ->
- error $ "Data.Array.Mixed.fromList: length of list (" ++ show (length l) ++ ")" ++
- "does not match the type (" ++ show (natVal m) ++ ")"
+ SKnown m :!% _ | fromSNat' m /= length l ->
+ error $ "Data.Array.Mixed.fromListOuter: length of list (" ++ show (length l) ++ ")" ++
+ "does not match the type (" ++ show (fromSNat' m) ++ ")"
_ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (Rank sh) a] l)))
-toList1 :: Storable a => XArray (n : sh) a -> [XArray sh a]
-toList1 (XArray arr) = coerce (ORB.toList (S.unravel arr))
+toListOuter :: Storable a => XArray (n : sh) a -> [XArray sh a]
+toListOuter (XArray arr) = coerce (ORB.toList (S.unravel arr))
+
+fromList1 :: Storable a => StaticShX '[n] -> [a] -> XArray '[n] a
+fromList1 ssh l =
+ let n = length l
+ in case ssh of
+ SKnown m :!% _ | fromSNat' m /= n ->
+ error $ "Data.Array.Mixed.fromList1: length of list (" ++ show n ++ ")" ++
+ "does not match the type (" ++ show (fromSNat' m) ++ ")"
+ _ -> XArray (S.fromVector [n] (VS.fromListN n l))
+
+toList1 :: Storable a => XArray '[n] a -> [a]
+toList1 (XArray arr) = S.toList arr
-- | Throws if the given shape is not, in fact, empty.
empty :: forall sh a. Storable a => IShX sh -> XArray sh a