diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/Data/Array/XArray.hs | 36 |
1 files changed, 26 insertions, 10 deletions
diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs index 29154f1..892c3e0 100644 --- a/src/Data/Array/XArray.hs +++ b/src/Data/Array/XArray.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE BangPatterns #-} {-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveGeneric #-} @@ -296,10 +297,18 @@ fromListOuter :: forall n sh a. Storable a fromListOuter ssh l | Dict <- lemKnownNatRankSSX ssh = case ssh of - SKnown m :!% _ | fromSNat' m /= length l -> - error $ "Data.Array.Mixed.fromListOuter: length of list (" ++ show (length l) ++ ")" ++ - "does not match the type (" ++ show (fromSNat' m) ++ ")" - _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (Rank sh) a] l))) + SKnown m :!% _ -> + let n = fromSNat' m + !t = ORB.fromList [n] (coerce @[XArray sh a] @[S.Array (Rank sh) a] l) + -- forced to fuse the list away (or at least keep in memory at most one cell) + in if n /= length l + then error $ "Data.Array.Mixed.fromListOuter: length of list (" ++ show (length l) ++ ")" ++ + "does not match the type (" ++ show n ++ ")" + else XArray (S.ravel t) -- TODO: do we break fusion with ravel by forcing t? Maybe remove the detour through boxed vectors altogether? + _ -> + let n = length l -- we rather force the list than allocate too long a vector with VS.fromList + t = ORB.fromList [n] (coerce @[XArray sh a] @[S.Array (Rank sh) a] l) + in XArray (S.ravel t) toListOuter :: forall a n sh. Storable a => XArray (n : sh) a -> [XArray sh a] toListOuter (XArray arr@(ORS.A (ORG.A _ t))) = @@ -312,12 +321,19 @@ toListOuter (XArray arr@(ORS.A (ORG.A _ t))) = fromList1 :: Storable a => StaticShX '[n] -> [a] -> XArray '[n] a fromList1 ssh l = - let n = length l - in case ssh of - SKnown m :!% _ | fromSNat' m /= n -> - error $ "Data.Array.Mixed.fromList1: length of list (" ++ show n ++ ")" ++ - "does not match the type (" ++ show (fromSNat' m) ++ ")" - _ -> XArray (S.fromVector [n] (VS.fromListN n l)) + case ssh of + SKnown m :!% _ -> + let n = fromSNat' m + !v = VS.fromListN n l + -- forced to fuse the list away (or at least keep in memory at most one cell) + in if n /= length l + then error $ "Data.Array.Mixed.fromList1: length of list (" ++ show (length l) ++ ")" ++ + "does not match the type (" ++ show n ++ ")" + else XArray (S.fromVector [n] v) + _ -> + let n = length l -- we rather force the list than allocate too long a vector with VS.fromList + v = VS.fromListN n l + in XArray (S.fromVector [n] v) toList1 :: Storable a => XArray '[n] a -> [a] toList1 (XArray arr) = S.toList arr |
