aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/XArray.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/XArray.hs')
-rw-r--r--src/Data/Array/XArray.hs104
1 files changed, 84 insertions, 20 deletions
diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs
index 9776e21..6389e67 100644
--- a/src/Data/Array/XArray.hs
+++ b/src/Data/Array/XArray.hs
@@ -1,8 +1,11 @@
+{-# LANGUAGE BangPatterns #-}
+{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneKindSignatures #-}
@@ -14,10 +17,11 @@
module Data.Array.XArray where
import Control.DeepSeq (NFData)
+import Control.Monad (foldM)
+import Control.Monad.ST
import Data.Array.Internal qualified as OI
import Data.Array.Internal.RankedG qualified as ORG
import Data.Array.Internal.RankedS qualified as ORS
-import Data.Array.Ranked qualified as ORB
import Data.Array.RankedS qualified as S
import Data.Coerce
import Data.Foldable (toList)
@@ -26,10 +30,15 @@ import Data.List.NonEmpty (NonEmpty)
import Data.Proxy
import Data.Type.Equality
import Data.Type.Ord
+import Data.Vector.Generic.Checked qualified as VGC
import Data.Vector.Storable qualified as VS
+import Data.Vector.Storable.Mutable qualified as VSM
import Foreign.Storable (Storable)
import GHC.Generics (Generic)
import GHC.TypeLits
+#if !MIN_VERSION_GLASGOW_HASKELL(9,8,0,0)
+import Unsafe.Coerce (unsafeCoerce)
+#endif
import Data.Array.Nested.Lemmas
import Data.Array.Nested.Mixed.Shape
@@ -108,15 +117,23 @@ generate sh f = fromVector sh $ VS.generate (shxSize sh) (f . ixxFromLinear sh)
-- XArray . S.fromVector (shxShapeL sh)
-- <$> VS.generateM (shxSize sh) (f . ixxFromLinear sh)
+{-# INLINEABLE indexPartial #-}
indexPartial :: Storable a => XArray (sh ++ sh') a -> IIxX sh -> XArray sh' a
indexPartial (XArray arr) ZIX = XArray arr
indexPartial (XArray arr) (i :.% idx) = indexPartial (XArray (S.index arr i)) idx
+{- Strangely, this increases allocation and there's no noticeable speedup:
+indexPartial (XArray (ORS.A (ORG.A sh t))) ix =
+ let linear = OI.offset t + sum (zipWith (*) (ixxToList ix) (OI.strides t))
+ len = ixxLength ix
+ in XArray (ORS.A (ORG.A (drop len sh)
+ OI.T{ OI.strides = drop len (OI.strides t)
+ , OI.offset = linear
+ , OI.values = OI.values t })) -}
+{-# INLINEABLE index #-}
index :: forall sh a. Storable a => XArray sh a -> IIxX sh -> a
-index xarr i
- | Refl <- lemAppNil @sh
- = let XArray arr' = indexPartial xarr i :: XArray '[] a
- in S.unScalar arr'
+index (XArray (ORS.A (ORG.A _ t))) i =
+ OI.values t VS.! (OI.offset t + sum (zipWith (*) (toList i) (OI.strides t)))
append :: forall n m sh a. Storable a
=> StaticShX sh -> XArray (n : sh) a -> XArray (m : sh) a -> XArray (AddMaybe n m : sh) a
@@ -217,7 +234,12 @@ transpose ssh perm (XArray arr)
, Refl <- lemRankApp (ssxPermute perm (ssxTakeLen perm ssh)) (ssxDropLen perm ssh)
, Refl <- lemRankPermute (Proxy @(TakeLen is sh)) perm
, Refl <- lemRankDropLen ssh perm
+#if MIN_VERSION_GLASGOW_HASKELL(9,8,0,0)
= XArray (S.transpose (permToList' perm) arr)
+#else
+ = XArray (unsafeCoerce (S.transpose (permToList' perm) arr))
+#endif
+
-- | The list argument gives indices into the original dimension list.
--
@@ -243,7 +265,7 @@ transpose2 ssh1 ssh2 (XArray arr)
, Dict <- lemKnownNatRankSSX (ssxAppend ssh2 ssh1)
, Refl <- lemRankAppComm ssh1 ssh2
, let n1 = ssxLength ssh1
- = XArray (S.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr)
+ = XArray (S.transpose (ssxIotaFrom ssh2 n1 ++ ssxIotaFrom ssh1 0) arr)
sumFull :: (Storable a, NumElt a) => StaticShX sh -> XArray sh a -> a
sumFull _ (XArray arr) =
@@ -283,30 +305,72 @@ sumOuter ssh ssh' arr
reshapePartial ssh ssh' shF $
arr
+-- | If @n@ is an 'SKnown' dimension, the list is streamed. If @n@ is unknown,
+-- the list's spine must be fully materialised to compute its length before
+-- constructing the array. The list can't be empty (not enough information
+-- in the given shape to guess the shape of the empty array, in general).
fromListOuter :: forall n sh a. Storable a
=> StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a
fromListOuter ssh l
- | Dict <- lemKnownNatRankSSX ssh
+ | Dict <- lemKnownNatRankSSX (ssxTail ssh)
+ , let l' = coerce @[XArray sh a] @[S.Array (Rank sh) a] l
= case ssh of
- SKnown m :!% _ | fromSNat' m /= length l ->
- error $ "Data.Array.Mixed.fromListOuter: length of list (" ++ show (length l) ++ ")" ++
- "does not match the type (" ++ show (fromSNat' m) ++ ")"
- _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (Rank sh) a] l)))
+ _ :!% ZKX ->
+ fromList1 ssh (map S.unScalar l')
+ SKnown m :!% _ ->
+ let n = fromSNat' m
+ in XArray (ravelOuterN n l')
+ _ ->
+ let n = length l
+ in XArray (ravelOuterN n l')
+
+-- | This checks that the list has the given length and that all shapes in the
+-- list are equal. The list must be non-empty, and is streamed.
+ravelOuterN :: (KnownNat k, Storable a)
+ => Int -> [S.Array k a] -> S.Array (1 + k) a
+ravelOuterN 0 _ = error "ravelOuterN: N == 0"
+ravelOuterN _ [] = error "ravelOuterN: empty list"
+ravelOuterN k as@(a0 : _) = runST $ do
+ let sh0 = S.shapeL a0
+ len = product sh0
+ vecSize = k * len
+ vec <- VSM.unsafeNew vecSize
+ let f !n a =
+ if | n >= k ->
+ error $ "ravelOuterN: list too long " ++ show (n, k)
+ -- if we do this check just once at the end, we may
+ -- crash instead of producing an accurate error message
+ | S.shapeL a == sh0 -> do
+ VS.unsafeCopy (VSM.slice (n * len) len vec) (S.toVector a)
+ return $! n + 1
+ | otherwise ->
+ error $ "ravelOuterN: unequal shapes " ++ show (S.shapeL a, sh0)
+ nFinal <- foldM f 0 as
+ if nFinal == k
+ then S.fromVector (k : sh0) <$> VS.unsafeFreeze vec
+ else error $ "ravelOuterN: list too short " ++ show (nFinal, k)
-toListOuter :: Storable a => XArray (n : sh) a -> [XArray sh a]
-toListOuter (XArray arr) =
+toListOuter :: forall a n sh. Storable a => XArray (n : sh) a -> [XArray sh a]
+toListOuter (XArray arr@(ORS.A (ORG.A _ t))) =
case S.shapeL arr of
+ [] -> error "impossible"
0 : _ -> []
- _ -> coerce (ORB.toList (S.unravel arr))
+ -- using orthotope's functions here would entail using rerank, which is slow, so we don't
+ [_] | Refl <- (unsafeCoerceRefl :: sh :~: '[]) -> coerce (map S.scalar $ S.toList arr)
+ n : sh -> coerce $ map (ORG.A sh . OI.indexT t) [0 .. n - 1]
+-- | If @n@ is an 'SKnown' dimension, the list is streamed. If @n@ is unknown,
+-- the list's spine must be fully materialised to compute its length before
+-- constructing the array.
fromList1 :: Storable a => StaticShX '[n] -> [a] -> XArray '[n] a
fromList1 ssh l =
- let n = length l
- in case ssh of
- SKnown m :!% _ | fromSNat' m /= n ->
- error $ "Data.Array.Mixed.fromList1: length of list (" ++ show n ++ ")" ++
- "does not match the type (" ++ show (fromSNat' m) ++ ")"
- _ -> XArray (S.fromVector [n] (VS.fromListN n l))
+ case ssh of
+ SKnown m :!% _ ->
+ let n = fromSNat' m -- do length check and vector construction simultaneously so that l can be streamed
+ in XArray (S.fromVector [n] (VGC.fromListNChecked n l))
+ _ ->
+ let n = length l -- avoid S.fromList because it takes a length _and_ does another length check itself
+ in XArray (S.fromVector [n] (VS.fromListN n l))
toList1 :: Storable a => XArray '[n] a -> [a]
toList1 (XArray arr) = S.toList arr