aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/XArray.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/XArray.hs')
-rw-r--r--src/Data/Array/XArray.hs86
1 files changed, 45 insertions, 41 deletions
diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs
index 1445ce6..4f5bb08 100644
--- a/src/Data/Array/XArray.hs
+++ b/src/Data/Array/XArray.hs
@@ -17,7 +17,7 @@
module Data.Array.XArray where
import Control.DeepSeq (NFData)
-import Control.Monad (foldM)
+import Control.Monad (foldM_, foldM)
import Control.Monad.ST
import Data.Array.Internal qualified as OI
import Data.Array.Internal.RankedG qualified as ORG
@@ -26,7 +26,7 @@ import Data.Array.RankedS qualified as S
import Data.Coerce
import Data.Foldable (toList)
import Data.Kind
-import Data.List.NonEmpty (NonEmpty)
+import Data.List.NonEmpty (NonEmpty(..))
import Data.Proxy
import Data.Type.Equality
import Data.Type.Ord
@@ -62,6 +62,7 @@ shape = \ssh (XArray arr) -> go ssh (S.shapeL arr)
go (n :!% ssh) (i : l) = fromSMayNat (\_ -> SUnknown i) SKnown n :$% go ssh l
go _ _ = error "Invalid shapeL"
+{-# INLINEABLE fromVector #-}
fromVector :: forall sh a. Storable a => IShX sh -> VS.Vector a -> XArray sh a
fromVector sh v
| Dict <- lemKnownNatRank sh
@@ -87,7 +88,7 @@ cast ssh1 sh2 ssh' (XArray arr)
| Refl <- lemRankApp ssh1 ssh'
, Refl <- lemRankApp (ssxFromShX sh2) ssh'
= let arrsh :: IShX sh1
- (arrsh, _) = shxSplitApp (Proxy @sh') ssh1 (shape (ssxAppend ssh1 ssh') (XArray arr))
+ arrsh = shxTakeSSX (Proxy @sh') ssh1 (shape (ssxAppend ssh1 ssh') (XArray arr))
in if shxToList arrsh == shxToList sh2
then XArray arr
else error $ "Data.Array.Mixed.cast: Cannot cast (" ++ show arrsh ++ ") to (" ++ show sh2 ++ ")"
@@ -184,7 +185,7 @@ rerank :: forall sh sh1 sh2 a b.
-> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b
rerank ssh ssh1 ssh2 f xarr@(XArray arr)
| Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2)
- = let (sh, _) = shxSplitApp (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr)
+ = let sh = shxTakeSSX (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr)
in if 0 `elem` shxToList sh
then XArray (S.fromList (shxToList (shxAppend sh (shxCompleteZeros ssh2))) [])
else case () of
@@ -211,7 +212,7 @@ rerank2 :: forall sh sh1 sh2 a b c.
-> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c
rerank2 ssh ssh1 ssh2 f xarr1@(XArray arr1) (XArray arr2)
| Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2)
- = let (sh, _) = shxSplitApp (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr1)
+ = let sh = shxTakeSSX (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr1)
in if 0 `elem` shxToList sh
then XArray (S.fromList (shxToList (shxAppend sh (shxCompleteZeros ssh2))) [])
else case () of
@@ -274,14 +275,14 @@ sumInner :: forall sh sh' a. (Storable a, NumElt a)
=> StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh a
sumInner ssh ssh' arr
| Refl <- lemAppNil @sh
- = let (_, sh') = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr)
+ = let sh' = shxDropSSX @sh @sh' ssh (shape (ssxAppend ssh ssh') arr)
sh'F = shxFlatten sh' :$% ZSX
ssh'F = ssxFromShX sh'F
go :: XArray (sh ++ '[Flatten sh']) a -> XArray sh a
go (XArray arr')
| Refl <- lemRankApp ssh ssh'F
- , let sn = listxRank (let StaticShX l = ssh in l)
+ , let sn = ssxRank ssh
= XArray (liftO1 (numEltSum1Inner sn) arr')
in go $
@@ -294,7 +295,7 @@ sumOuter :: forall sh sh' a. (Storable a, NumElt a)
=> StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh' a
sumOuter ssh ssh' arr
| Refl <- lemAppNil @sh
- = let (sh, _) = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr)
+ = let sh = shxTakeSSX (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr)
shF = shxFlatten sh :$% ZSX
in sumInner ssh' (ssxFromShX shF) $
transpose2 (ssxFromShX shF) ssh' $
@@ -305,50 +306,48 @@ sumOuter ssh ssh' arr
-- the list's spine must be fully materialised to compute its length before
-- constructing the array. The list can't be empty (not enough information
-- in the given shape to guess the shape of the empty array, in general).
-fromListOuter :: forall n sh a. Storable a
- => StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a
-fromListOuter ssh l
- | Dict <- lemKnownNatRankSSX (ssxTail ssh)
- , let l' = coerce @[XArray sh a] @[S.Array (Rank sh) a] l
- = case ssh of
- _ :!% ZKX ->
- fromList1 ssh (map S.unScalar l')
- SKnown m :!% _ ->
- let n = fromSNat' m
- in XArray (ravelOuterN n l')
- _ ->
- let n = length l
- in XArray (ravelOuterN n l')
+{-# INLINE fromListOuterSN #-}
+fromListOuterSN :: forall n sh a. Storable a
+ => SNat n -> IShX sh -> NonEmpty (XArray sh a) -> XArray (Just n : sh) a
+fromListOuterSN m sh l
+ | Dict <- lemKnownNatRank sh
+ , let l' = coerce @(NonEmpty (XArray sh a)) @(NonEmpty (S.Array (Rank sh) a)) l
+ = case sh of
+ ZSX -> fromList1SN m (map S.unScalar (toList l'))
+ _ -> XArray (ravelOuterN (fromSNat' m) l')
-- | This checks that the list has the given length and that all shapes in the
-- list are equal. The list must be non-empty, and is streamed.
+{-# INLINEABLE ravelOuterN #-}
ravelOuterN :: (KnownNat k, Storable a)
- => Int -> [S.Array k a] -> S.Array (1 + k) a
+ => Int -> NonEmpty (S.Array k a) -> S.Array (1 + k) a
ravelOuterN 0 _ = error "ravelOuterN: N == 0"
-ravelOuterN _ [] = error "ravelOuterN: empty list"
-ravelOuterN k as@(a0 : _) = runST $ do
+ravelOuterN k as@(a0 :| _) = runST $ do
let sh0 = S.shapeL a0
len = product sh0
vecSize = k * len
vec <- VSM.unsafeNew vecSize
- let f !n a =
+ let f !n (ORS.A (ORG.A sht t)) =
if | n >= k ->
error $ "ravelOuterN: list too long " ++ show (n, k)
-- if we do this check just once at the end, we may
-- crash instead of producing an accurate error message
- | S.shapeL a == sh0 -> do
- VS.unsafeCopy (VSM.slice (n * len) len vec) (S.toVector a)
- return $! n + 1
+ | sht == sh0 -> do
+ let g off el = do
+ VS.unsafeCopy (VSM.slice off (VS.length el) vec) el
+ return $! off + VS.length el
+ foldM_ g (n * len) (OI.toVectorListT sht t)
+ return $! n + 1
| otherwise ->
- error $ "ravelOuterN: unequal shapes " ++ show (S.shapeL a, sh0)
+ error $ "ravelOuterN: unequal shapes " ++ show (sht, sh0)
nFinal <- foldM f 0 as
if nFinal == k
then S.fromVector (k : sh0) <$> VS.unsafeFreeze vec
else error $ "ravelOuterN: list too short " ++ show (nFinal, k)
toListOuter :: forall a n sh. Storable a => XArray (n : sh) a -> [XArray sh a]
-toListOuter (XArray arr@(ORS.A (ORG.A _ t))) =
- case S.shapeL arr of
+toListOuter (XArray arr@(ORS.A (ORG.A shArr t))) =
+ case shArr of
[] -> error "impossible"
0 : _ -> []
-- using orthotope's functions here would entail using rerank, which is slow, so we don't
@@ -358,15 +357,20 @@ toListOuter (XArray arr@(ORS.A (ORG.A _ t))) =
-- | If @n@ is an 'SKnown' dimension, the list is streamed. If @n@ is unknown,
-- the list's spine must be fully materialised to compute its length before
-- constructing the array.
-fromList1 :: Storable a => StaticShX '[n] -> [a] -> XArray '[n] a
-fromList1 ssh l =
- case ssh of
- SKnown m :!% _ ->
- let n = fromSNat' m -- do length check and vector construction simultaneously so that l can be streamed
- in XArray (S.fromVector [n] (VGC.fromListNChecked n l))
- _ ->
- let n = length l -- avoid S.fromList because it takes a length _and_ does another length check itself
- in XArray (S.fromVector [n] (VS.fromListN n l))
+{-# INLINE fromList1 #-}
+fromList1 :: Storable a => [a] -> XArray '[Nothing] a
+fromList1 l =
+ let n = length l -- avoid S.fromList because it takes a length _and_ does another length check itself
+ in XArray (S.fromVector [n] (VS.fromListN n l))
+
+-- | If @n@ is an 'SKnown' dimension, the list is streamed. If @n@ is unknown,
+-- the list's spine must be fully materialised to compute its length before
+-- constructing the array.
+{-# INLINE fromList1SN #-}
+fromList1SN :: Storable a => SNat n -> [a] -> XArray '[Just n] a
+fromList1SN m l =
+ let n = fromSNat' m -- do length check and vector construction simultaneously so that l can be streamed
+ in XArray (S.fromVector [n] (VGC.fromListNChecked n l))
toList1 :: Storable a => XArray '[n] a -> [a]
toList1 (XArray arr) = S.toList arr