diff options
Diffstat (limited to 'src/Data/Array/XArray.hs')
| -rw-r--r-- | src/Data/Array/XArray.hs | 86 |
1 files changed, 45 insertions, 41 deletions
diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs index 1445ce6..4f5bb08 100644 --- a/src/Data/Array/XArray.hs +++ b/src/Data/Array/XArray.hs @@ -17,7 +17,7 @@ module Data.Array.XArray where import Control.DeepSeq (NFData) -import Control.Monad (foldM) +import Control.Monad (foldM_, foldM) import Control.Monad.ST import Data.Array.Internal qualified as OI import Data.Array.Internal.RankedG qualified as ORG @@ -26,7 +26,7 @@ import Data.Array.RankedS qualified as S import Data.Coerce import Data.Foldable (toList) import Data.Kind -import Data.List.NonEmpty (NonEmpty) +import Data.List.NonEmpty (NonEmpty(..)) import Data.Proxy import Data.Type.Equality import Data.Type.Ord @@ -62,6 +62,7 @@ shape = \ssh (XArray arr) -> go ssh (S.shapeL arr) go (n :!% ssh) (i : l) = fromSMayNat (\_ -> SUnknown i) SKnown n :$% go ssh l go _ _ = error "Invalid shapeL" +{-# INLINEABLE fromVector #-} fromVector :: forall sh a. Storable a => IShX sh -> VS.Vector a -> XArray sh a fromVector sh v | Dict <- lemKnownNatRank sh @@ -87,7 +88,7 @@ cast ssh1 sh2 ssh' (XArray arr) | Refl <- lemRankApp ssh1 ssh' , Refl <- lemRankApp (ssxFromShX sh2) ssh' = let arrsh :: IShX sh1 - (arrsh, _) = shxSplitApp (Proxy @sh') ssh1 (shape (ssxAppend ssh1 ssh') (XArray arr)) + arrsh = shxTakeSSX (Proxy @sh') ssh1 (shape (ssxAppend ssh1 ssh') (XArray arr)) in if shxToList arrsh == shxToList sh2 then XArray arr else error $ "Data.Array.Mixed.cast: Cannot cast (" ++ show arrsh ++ ") to (" ++ show sh2 ++ ")" @@ -184,7 +185,7 @@ rerank :: forall sh sh1 sh2 a b. -> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b rerank ssh ssh1 ssh2 f xarr@(XArray arr) | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) - = let (sh, _) = shxSplitApp (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr) + = let sh = shxTakeSSX (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr) in if 0 `elem` shxToList sh then XArray (S.fromList (shxToList (shxAppend sh (shxCompleteZeros ssh2))) []) else case () of @@ -211,7 +212,7 @@ rerank2 :: forall sh sh1 sh2 a b c. -> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c rerank2 ssh ssh1 ssh2 f xarr1@(XArray arr1) (XArray arr2) | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) - = let (sh, _) = shxSplitApp (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr1) + = let sh = shxTakeSSX (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr1) in if 0 `elem` shxToList sh then XArray (S.fromList (shxToList (shxAppend sh (shxCompleteZeros ssh2))) []) else case () of @@ -274,14 +275,14 @@ sumInner :: forall sh sh' a. (Storable a, NumElt a) => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh a sumInner ssh ssh' arr | Refl <- lemAppNil @sh - = let (_, sh') = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) + = let sh' = shxDropSSX @sh @sh' ssh (shape (ssxAppend ssh ssh') arr) sh'F = shxFlatten sh' :$% ZSX ssh'F = ssxFromShX sh'F go :: XArray (sh ++ '[Flatten sh']) a -> XArray sh a go (XArray arr') | Refl <- lemRankApp ssh ssh'F - , let sn = listxRank (let StaticShX l = ssh in l) + , let sn = ssxRank ssh = XArray (liftO1 (numEltSum1Inner sn) arr') in go $ @@ -294,7 +295,7 @@ sumOuter :: forall sh sh' a. (Storable a, NumElt a) => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh' a sumOuter ssh ssh' arr | Refl <- lemAppNil @sh - = let (sh, _) = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) + = let sh = shxTakeSSX (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) shF = shxFlatten sh :$% ZSX in sumInner ssh' (ssxFromShX shF) $ transpose2 (ssxFromShX shF) ssh' $ @@ -305,50 +306,48 @@ sumOuter ssh ssh' arr -- the list's spine must be fully materialised to compute its length before -- constructing the array. The list can't be empty (not enough information -- in the given shape to guess the shape of the empty array, in general). -fromListOuter :: forall n sh a. Storable a - => StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a -fromListOuter ssh l - | Dict <- lemKnownNatRankSSX (ssxTail ssh) - , let l' = coerce @[XArray sh a] @[S.Array (Rank sh) a] l - = case ssh of - _ :!% ZKX -> - fromList1 ssh (map S.unScalar l') - SKnown m :!% _ -> - let n = fromSNat' m - in XArray (ravelOuterN n l') - _ -> - let n = length l - in XArray (ravelOuterN n l') +{-# INLINE fromListOuterSN #-} +fromListOuterSN :: forall n sh a. Storable a + => SNat n -> IShX sh -> NonEmpty (XArray sh a) -> XArray (Just n : sh) a +fromListOuterSN m sh l + | Dict <- lemKnownNatRank sh + , let l' = coerce @(NonEmpty (XArray sh a)) @(NonEmpty (S.Array (Rank sh) a)) l + = case sh of + ZSX -> fromList1SN m (map S.unScalar (toList l')) + _ -> XArray (ravelOuterN (fromSNat' m) l') -- | This checks that the list has the given length and that all shapes in the -- list are equal. The list must be non-empty, and is streamed. +{-# INLINEABLE ravelOuterN #-} ravelOuterN :: (KnownNat k, Storable a) - => Int -> [S.Array k a] -> S.Array (1 + k) a + => Int -> NonEmpty (S.Array k a) -> S.Array (1 + k) a ravelOuterN 0 _ = error "ravelOuterN: N == 0" -ravelOuterN _ [] = error "ravelOuterN: empty list" -ravelOuterN k as@(a0 : _) = runST $ do +ravelOuterN k as@(a0 :| _) = runST $ do let sh0 = S.shapeL a0 len = product sh0 vecSize = k * len vec <- VSM.unsafeNew vecSize - let f !n a = + let f !n (ORS.A (ORG.A sht t)) = if | n >= k -> error $ "ravelOuterN: list too long " ++ show (n, k) -- if we do this check just once at the end, we may -- crash instead of producing an accurate error message - | S.shapeL a == sh0 -> do - VS.unsafeCopy (VSM.slice (n * len) len vec) (S.toVector a) - return $! n + 1 + | sht == sh0 -> do + let g off el = do + VS.unsafeCopy (VSM.slice off (VS.length el) vec) el + return $! off + VS.length el + foldM_ g (n * len) (OI.toVectorListT sht t) + return $! n + 1 | otherwise -> - error $ "ravelOuterN: unequal shapes " ++ show (S.shapeL a, sh0) + error $ "ravelOuterN: unequal shapes " ++ show (sht, sh0) nFinal <- foldM f 0 as if nFinal == k then S.fromVector (k : sh0) <$> VS.unsafeFreeze vec else error $ "ravelOuterN: list too short " ++ show (nFinal, k) toListOuter :: forall a n sh. Storable a => XArray (n : sh) a -> [XArray sh a] -toListOuter (XArray arr@(ORS.A (ORG.A _ t))) = - case S.shapeL arr of +toListOuter (XArray arr@(ORS.A (ORG.A shArr t))) = + case shArr of [] -> error "impossible" 0 : _ -> [] -- using orthotope's functions here would entail using rerank, which is slow, so we don't @@ -358,15 +357,20 @@ toListOuter (XArray arr@(ORS.A (ORG.A _ t))) = -- | If @n@ is an 'SKnown' dimension, the list is streamed. If @n@ is unknown, -- the list's spine must be fully materialised to compute its length before -- constructing the array. -fromList1 :: Storable a => StaticShX '[n] -> [a] -> XArray '[n] a -fromList1 ssh l = - case ssh of - SKnown m :!% _ -> - let n = fromSNat' m -- do length check and vector construction simultaneously so that l can be streamed - in XArray (S.fromVector [n] (VGC.fromListNChecked n l)) - _ -> - let n = length l -- avoid S.fromList because it takes a length _and_ does another length check itself - in XArray (S.fromVector [n] (VS.fromListN n l)) +{-# INLINE fromList1 #-} +fromList1 :: Storable a => [a] -> XArray '[Nothing] a +fromList1 l = + let n = length l -- avoid S.fromList because it takes a length _and_ does another length check itself + in XArray (S.fromVector [n] (VS.fromListN n l)) + +-- | If @n@ is an 'SKnown' dimension, the list is streamed. If @n@ is unknown, +-- the list's spine must be fully materialised to compute its length before +-- constructing the array. +{-# INLINE fromList1SN #-} +fromList1SN :: Storable a => SNat n -> [a] -> XArray '[Just n] a +fromList1SN m l = + let n = fromSNat' m -- do length check and vector construction simultaneously so that l can be streamed + in XArray (S.fromVector [n] (VGC.fromListNChecked n l)) toList1 :: Storable a => XArray '[n] a -> [a] toList1 (XArray arr) = S.toList arr |
