diff options
Diffstat (limited to 'src/Data/Array/XArray.hs')
| -rw-r--r-- | src/Data/Array/XArray.hs | 104 |
1 files changed, 84 insertions, 20 deletions
diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs index 9776e21..6389e67 100644 --- a/src/Data/Array/XArray.hs +++ b/src/Data/Array/XArray.hs @@ -1,8 +1,11 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE MultiWayIf #-} {-# LANGUAGE NoStarIsType #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneKindSignatures #-} @@ -14,10 +17,11 @@ module Data.Array.XArray where import Control.DeepSeq (NFData) +import Control.Monad (foldM) +import Control.Monad.ST import Data.Array.Internal qualified as OI import Data.Array.Internal.RankedG qualified as ORG import Data.Array.Internal.RankedS qualified as ORS -import Data.Array.Ranked qualified as ORB import Data.Array.RankedS qualified as S import Data.Coerce import Data.Foldable (toList) @@ -26,10 +30,15 @@ import Data.List.NonEmpty (NonEmpty) import Data.Proxy import Data.Type.Equality import Data.Type.Ord +import Data.Vector.Generic.Checked qualified as VGC import Data.Vector.Storable qualified as VS +import Data.Vector.Storable.Mutable qualified as VSM import Foreign.Storable (Storable) import GHC.Generics (Generic) import GHC.TypeLits +#if !MIN_VERSION_GLASGOW_HASKELL(9,8,0,0) +import Unsafe.Coerce (unsafeCoerce) +#endif import Data.Array.Nested.Lemmas import Data.Array.Nested.Mixed.Shape @@ -108,15 +117,23 @@ generate sh f = fromVector sh $ VS.generate (shxSize sh) (f . ixxFromLinear sh) -- XArray . S.fromVector (shxShapeL sh) -- <$> VS.generateM (shxSize sh) (f . ixxFromLinear sh) +{-# INLINEABLE indexPartial #-} indexPartial :: Storable a => XArray (sh ++ sh') a -> IIxX sh -> XArray sh' a indexPartial (XArray arr) ZIX = XArray arr indexPartial (XArray arr) (i :.% idx) = indexPartial (XArray (S.index arr i)) idx +{- Strangely, this increases allocation and there's no noticeable speedup: +indexPartial (XArray (ORS.A (ORG.A sh t))) ix = + let linear = OI.offset t + sum (zipWith (*) (ixxToList ix) (OI.strides t)) + len = ixxLength ix + in XArray (ORS.A (ORG.A (drop len sh) + OI.T{ OI.strides = drop len (OI.strides t) + , OI.offset = linear + , OI.values = OI.values t })) -} +{-# INLINEABLE index #-} index :: forall sh a. Storable a => XArray sh a -> IIxX sh -> a -index xarr i - | Refl <- lemAppNil @sh - = let XArray arr' = indexPartial xarr i :: XArray '[] a - in S.unScalar arr' +index (XArray (ORS.A (ORG.A _ t))) i = + OI.values t VS.! (OI.offset t + sum (zipWith (*) (toList i) (OI.strides t))) append :: forall n m sh a. Storable a => StaticShX sh -> XArray (n : sh) a -> XArray (m : sh) a -> XArray (AddMaybe n m : sh) a @@ -217,7 +234,12 @@ transpose ssh perm (XArray arr) , Refl <- lemRankApp (ssxPermute perm (ssxTakeLen perm ssh)) (ssxDropLen perm ssh) , Refl <- lemRankPermute (Proxy @(TakeLen is sh)) perm , Refl <- lemRankDropLen ssh perm +#if MIN_VERSION_GLASGOW_HASKELL(9,8,0,0) = XArray (S.transpose (permToList' perm) arr) +#else + = XArray (unsafeCoerce (S.transpose (permToList' perm) arr)) +#endif + -- | The list argument gives indices into the original dimension list. -- @@ -243,7 +265,7 @@ transpose2 ssh1 ssh2 (XArray arr) , Dict <- lemKnownNatRankSSX (ssxAppend ssh2 ssh1) , Refl <- lemRankAppComm ssh1 ssh2 , let n1 = ssxLength ssh1 - = XArray (S.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr) + = XArray (S.transpose (ssxIotaFrom ssh2 n1 ++ ssxIotaFrom ssh1 0) arr) sumFull :: (Storable a, NumElt a) => StaticShX sh -> XArray sh a -> a sumFull _ (XArray arr) = @@ -283,30 +305,72 @@ sumOuter ssh ssh' arr reshapePartial ssh ssh' shF $ arr +-- | If @n@ is an 'SKnown' dimension, the list is streamed. If @n@ is unknown, +-- the list's spine must be fully materialised to compute its length before +-- constructing the array. The list can't be empty (not enough information +-- in the given shape to guess the shape of the empty array, in general). fromListOuter :: forall n sh a. Storable a => StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a fromListOuter ssh l - | Dict <- lemKnownNatRankSSX ssh + | Dict <- lemKnownNatRankSSX (ssxTail ssh) + , let l' = coerce @[XArray sh a] @[S.Array (Rank sh) a] l = case ssh of - SKnown m :!% _ | fromSNat' m /= length l -> - error $ "Data.Array.Mixed.fromListOuter: length of list (" ++ show (length l) ++ ")" ++ - "does not match the type (" ++ show (fromSNat' m) ++ ")" - _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (Rank sh) a] l))) + _ :!% ZKX -> + fromList1 ssh (map S.unScalar l') + SKnown m :!% _ -> + let n = fromSNat' m + in XArray (ravelOuterN n l') + _ -> + let n = length l + in XArray (ravelOuterN n l') + +-- | This checks that the list has the given length and that all shapes in the +-- list are equal. The list must be non-empty, and is streamed. +ravelOuterN :: (KnownNat k, Storable a) + => Int -> [S.Array k a] -> S.Array (1 + k) a +ravelOuterN 0 _ = error "ravelOuterN: N == 0" +ravelOuterN _ [] = error "ravelOuterN: empty list" +ravelOuterN k as@(a0 : _) = runST $ do + let sh0 = S.shapeL a0 + len = product sh0 + vecSize = k * len + vec <- VSM.unsafeNew vecSize + let f !n a = + if | n >= k -> + error $ "ravelOuterN: list too long " ++ show (n, k) + -- if we do this check just once at the end, we may + -- crash instead of producing an accurate error message + | S.shapeL a == sh0 -> do + VS.unsafeCopy (VSM.slice (n * len) len vec) (S.toVector a) + return $! n + 1 + | otherwise -> + error $ "ravelOuterN: unequal shapes " ++ show (S.shapeL a, sh0) + nFinal <- foldM f 0 as + if nFinal == k + then S.fromVector (k : sh0) <$> VS.unsafeFreeze vec + else error $ "ravelOuterN: list too short " ++ show (nFinal, k) -toListOuter :: Storable a => XArray (n : sh) a -> [XArray sh a] -toListOuter (XArray arr) = +toListOuter :: forall a n sh. Storable a => XArray (n : sh) a -> [XArray sh a] +toListOuter (XArray arr@(ORS.A (ORG.A _ t))) = case S.shapeL arr of + [] -> error "impossible" 0 : _ -> [] - _ -> coerce (ORB.toList (S.unravel arr)) + -- using orthotope's functions here would entail using rerank, which is slow, so we don't + [_] | Refl <- (unsafeCoerceRefl :: sh :~: '[]) -> coerce (map S.scalar $ S.toList arr) + n : sh -> coerce $ map (ORG.A sh . OI.indexT t) [0 .. n - 1] +-- | If @n@ is an 'SKnown' dimension, the list is streamed. If @n@ is unknown, +-- the list's spine must be fully materialised to compute its length before +-- constructing the array. fromList1 :: Storable a => StaticShX '[n] -> [a] -> XArray '[n] a fromList1 ssh l = - let n = length l - in case ssh of - SKnown m :!% _ | fromSNat' m /= n -> - error $ "Data.Array.Mixed.fromList1: length of list (" ++ show n ++ ")" ++ - "does not match the type (" ++ show (fromSNat' m) ++ ")" - _ -> XArray (S.fromVector [n] (VS.fromListN n l)) + case ssh of + SKnown m :!% _ -> + let n = fromSNat' m -- do length check and vector construction simultaneously so that l can be streamed + in XArray (S.fromVector [n] (VGC.fromListNChecked n l)) + _ -> + let n = length l -- avoid S.fromList because it takes a length _and_ does another length check itself + in XArray (S.fromVector [n] (VS.fromListN n l)) toList1 :: Storable a => XArray '[n] a -> [a] toList1 (XArray arr) = S.toList arr |
