aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/Data/Array/XArray.hs36
1 files changed, 26 insertions, 10 deletions
diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs
index 29154f1..892c3e0 100644
--- a/src/Data/Array/XArray.hs
+++ b/src/Data/Array/XArray.hs
@@ -1,3 +1,4 @@
+{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
@@ -296,10 +297,18 @@ fromListOuter :: forall n sh a. Storable a
fromListOuter ssh l
| Dict <- lemKnownNatRankSSX ssh
= case ssh of
- SKnown m :!% _ | fromSNat' m /= length l ->
- error $ "Data.Array.Mixed.fromListOuter: length of list (" ++ show (length l) ++ ")" ++
- "does not match the type (" ++ show (fromSNat' m) ++ ")"
- _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (Rank sh) a] l)))
+ SKnown m :!% _ ->
+ let n = fromSNat' m
+ !t = ORB.fromList [n] (coerce @[XArray sh a] @[S.Array (Rank sh) a] l)
+ -- forced to fuse the list away (or at least keep in memory at most one cell)
+ in if n /= length l
+ then error $ "Data.Array.Mixed.fromListOuter: length of list (" ++ show (length l) ++ ")" ++
+ "does not match the type (" ++ show n ++ ")"
+ else XArray (S.ravel t) -- TODO: do we break fusion with ravel by forcing t? Maybe remove the detour through boxed vectors altogether?
+ _ ->
+ let n = length l -- we rather force the list than allocate too long a vector with VS.fromList
+ t = ORB.fromList [n] (coerce @[XArray sh a] @[S.Array (Rank sh) a] l)
+ in XArray (S.ravel t)
toListOuter :: forall a n sh. Storable a => XArray (n : sh) a -> [XArray sh a]
toListOuter (XArray arr@(ORS.A (ORG.A _ t))) =
@@ -312,12 +321,19 @@ toListOuter (XArray arr@(ORS.A (ORG.A _ t))) =
fromList1 :: Storable a => StaticShX '[n] -> [a] -> XArray '[n] a
fromList1 ssh l =
- let n = length l
- in case ssh of
- SKnown m :!% _ | fromSNat' m /= n ->
- error $ "Data.Array.Mixed.fromList1: length of list (" ++ show n ++ ")" ++
- "does not match the type (" ++ show (fromSNat' m) ++ ")"
- _ -> XArray (S.fromVector [n] (VS.fromListN n l))
+ case ssh of
+ SKnown m :!% _ ->
+ let n = fromSNat' m
+ !v = VS.fromListN n l
+ -- forced to fuse the list away (or at least keep in memory at most one cell)
+ in if n /= length l
+ then error $ "Data.Array.Mixed.fromList1: length of list (" ++ show (length l) ++ ")" ++
+ "does not match the type (" ++ show n ++ ")"
+ else XArray (S.fromVector [n] v)
+ _ ->
+ let n = length l -- we rather force the list than allocate too long a vector with VS.fromList
+ v = VS.fromListN n l
+ in XArray (S.fromVector [n] v)
toList1 :: Storable a => XArray '[n] a -> [a]
toList1 (XArray arr) = S.toList arr