aboutsummaryrefslogtreecommitdiff
path: root/src/Data
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data')
-rw-r--r--src/Data/Array/XArray.hs44
1 files changed, 38 insertions, 6 deletions
diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs
index ee83654..64bbb0c 100644
--- a/src/Data/Array/XArray.hs
+++ b/src/Data/Array/XArray.hs
@@ -1,9 +1,11 @@
+{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneKindSignatures #-}
@@ -15,10 +17,11 @@
module Data.Array.XArray where
import Control.DeepSeq (NFData)
+import Control.Monad (foldM)
+import Control.Monad.ST
import Data.Array.Internal qualified as OI
import Data.Array.Internal.RankedG qualified as ORG
import Data.Array.Internal.RankedS qualified as ORS
-import Data.Array.Ranked qualified as ORB
import Data.Array.RankedS qualified as S
import Data.Coerce
import Data.Foldable (toList)
@@ -27,9 +30,9 @@ import Data.List.NonEmpty (NonEmpty)
import Data.Proxy
import Data.Type.Equality
import Data.Type.Ord
-import Data.Vector qualified as V
import Data.Vector.Generic.Checked qualified as VGC
import Data.Vector.Storable qualified as VS
+import Data.Vector.Storable.Mutable qualified as VSM
import Foreign.Storable (Storable)
import GHC.Generics (Generic)
import GHC.TypeLits
@@ -237,6 +240,7 @@ transpose ssh perm (XArray arr)
= XArray (unsafeCoerce (S.transpose (permToList' perm) arr))
#endif
+
-- | The list argument gives indices into the original dimension list.
--
-- The permutation (the list) must have length <= @n@. If it is longer, this
@@ -303,21 +307,49 @@ sumOuter ssh ssh' arr
-- | If @n@ is an 'SKnown' dimension, the list is streamed. If @n@ is unknown,
-- the list's spine must be fully materialised to compute its length before
--- constructing the array.
+-- constructing the array. The list can't be empty (not enough information
+-- in the given shape to guess the shape of the empty array, in general).
fromListOuter :: forall n sh a. Storable a
=> StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a
fromListOuter ssh l
- | Dict <- lemKnownNatRankSSX ssh
+ | Dict <- lemKnownNatRankSSX (ssxTail ssh)
, let l' = coerce @[XArray sh a] @[S.Array (Rank sh) a] l
= case ssh of
_ :!% ZKX ->
fromList1 ssh (map S.unScalar l')
SKnown m :!% _ ->
let n = fromSNat' m
- in XArray (S.ravel (ORB.fromVector [n] (VGC.fromListNChecked n l')))
+ in XArray (ravelOuterN n l')
_ ->
let n = length l
- in XArray (S.ravel (ORB.fromVector [n] (V.fromListN n l')))
+ in XArray (ravelOuterN n l')
+
+-- This checks that the list has the given length and that each shape
+-- on the list equals the given shape. The list is streamed.
+-- The list can't be empty.
+ravelOuterN :: (KnownNat k, Storable a)
+ => Int -> [S.Array k a] -> S.Array (1 + k) a
+ravelOuterN 0 _ = error "ravelOuterN: N == 0"
+ravelOuterN _ [] = error "ravelOuterN: empty list"
+ravelOuterN k as@(a0 : _) = runST $ do
+ let sh0 = S.shapeL a0
+ len = product sh0
+ vecSize = k * len
+ vec <- VSM.unsafeNew vecSize
+ let f !n a =
+ if | n >= k ->
+ error $ "ravelOuterN: list too long " ++ show (n, k)
+ -- if we do this check just once at the end, we may
+ -- crash instead of producing an accurate error message
+ | S.shapeL a == sh0 -> do
+ VS.unsafeCopy (VSM.slice (n * len) len vec) (S.toVector a)
+ return $! n + 1
+ | otherwise ->
+ error $ "ravelOuterN: unequal shapes " ++ show (S.shapeL a, sh0)
+ nFinal <- foldM f 0 as
+ if nFinal == k
+ then S.fromVector (k : sh0) <$> VS.unsafeFreeze vec
+ else error $ "ravelOuterN: list too short " ++ show (nFinal, k)
toListOuter :: forall a n sh. Storable a => XArray (n : sh) a -> [XArray sh a]
toListOuter (XArray arr@(ORS.A (ORG.A _ t))) =