aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Mixed.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Mixed.hs')
-rw-r--r--src/Data/Array/Nested/Mixed.hs170
1 files changed, 125 insertions, 45 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs
index 182943d..39f00fa 100644
--- a/src/Data/Array/Nested/Mixed.hs
+++ b/src/Data/Array/Nested/Mixed.hs
@@ -23,12 +23,14 @@ module Data.Array.Nested.Mixed where
import Prelude hiding (mconcat)
import Control.DeepSeq (NFData(..))
-import Control.Monad (forM_, when)
+import Control.Monad (foldM_, forM_, when)
import Control.Monad.ST
+import Data.Array.Internal qualified as OI
+import Data.Array.Internal.RankedG qualified as ORG
+import Data.Array.Internal.RankedS qualified as ORS
import Data.Array.RankedS qualified as S
import Data.Bifunctor (bimap)
import Data.Coerce
-import Data.Foldable (toList)
import Data.Int
import Data.Kind (Type)
import Data.List.NonEmpty (NonEmpty(..))
@@ -39,6 +41,7 @@ import Data.Vector.Storable qualified as VS
import Data.Vector.Storable.Mutable qualified as VSM
import Foreign.C.Types (CInt)
import Foreign.Storable (Storable)
+import Foreign.Storable qualified as Storable
import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp)
import GHC.Generics (Generic)
import GHC.TypeLits
@@ -237,11 +240,13 @@ instance Elt a => NFData (Mixed sh a) where
rnf = mrnf
+{-# INLINE mliftNumElt1 #-}
mliftNumElt1 :: (PrimElt a, PrimElt b)
=> (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) b)
-> Mixed sh a -> Mixed sh b
mliftNumElt1 f (toPrimitive -> M_Primitive sh (XArray arr)) = fromPrimitive $ M_Primitive sh (XArray (f (shxRank sh) arr))
+{-# INLINE mliftNumElt2 #-}
mliftNumElt2 :: (PrimElt a, PrimElt b, PrimElt c)
=> (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) b -> S.Array (Rank sh) c)
-> Mixed sh a -> Mixed sh b -> Mixed sh c
@@ -310,7 +315,7 @@ class Elt a where
-- | See 'mfromListOuter'. If the list does not have the given length, a
-- runtime error is thrown. 'mfromListPrimSN' is faster if applicable.
- mfromListOuterSN :: forall sh n. SNat n -> NonEmpty (Mixed sh a) -> Mixed (Just n : sh) a
+ mfromListOuterSN :: forall sh n. SNat n -> NonEmpty (Mixed sh a) -> Mixed (Just n : sh) a
mtoListOuter :: Mixed (n : sh) a -> [Mixed sh a]
@@ -355,6 +360,9 @@ class Elt a where
-- | Tree giving the shape of every array component.
type ShapeTree a
+ -- | Produces an internal representation of a tree of shapes of (potentially)
+ -- nested arrays. If the argument is an array, it requires that the array
+ -- is not empty (it's guaranteed to crash early otherwise).
mshapeTree :: a -> ShapeTree a
mshapeTreeEq :: Proxy a -> ShapeTree a -> ShapeTree a -> Bool
@@ -367,17 +375,21 @@ class Elt a where
-- this mixed array.
marrayStrides :: Mixed sh a -> Bag [Int]
- -- | Given the shape of this array, an index and a value, write the value at
+ -- | Given a linear index and a value, write the value at
-- that index in the vectors.
- mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s ()
+ mvecsWriteLinear :: Int -> a -> MixedVecs s sh a -> ST s ()
- -- | Given the shape of this array, an index and a value, write the value at
+ -- | Given a linear index and a value, write the value at
-- that index in the vectors.
- mvecsWritePartial :: IShX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s ()
+ mvecsWritePartialLinear :: Proxy sh -> Int -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s ()
-- | Given the shape of this array, finalise the vectors into 'XArray's.
mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a)
+ -- | 'mvecsFreeze' but without copying the mutable vectors before freezing
+ -- them. If the mutable vectors are mutated after this function, referential
+ -- transparency is broken.
+ mvecsUnsafeFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a)
-- | Element types for which we have evidence of the (static part of the) shape
-- in a type class constraint. Compare the instance contexts of the instances
@@ -393,11 +405,18 @@ class Elt a => KnownElt a where
-- this vector and an example for the contents.
mvecsUnsafeNew :: IShX sh -> a -> ST s (MixedVecs s sh a)
+ -- | Create initialised vectors for this array type, given the shape of
+ -- this vector and the chosen element.
+ mvecsReplicate :: IShX sh -> a -> ST s (MixedVecs s sh a)
+
mvecsNewEmpty :: Proxy a -> ST s (MixedVecs s sh a)
-- Arrays of scalars are basically just arrays of scalars.
instance Storable a => Elt (Primitive a) where
+ -- Somehow, INLINE here can increase allocation with GHC 9.14.1.
+ -- Maybe that happens in void instances such as @Primitive ()@.
+ {-# INLINEABLE mshape #-}
mshape (M_Primitive sh _) = sh
{-# INLINEABLE mindex #-}
mindex (M_Primitive _ a) i = Primitive (X.index a i)
@@ -405,10 +424,11 @@ instance Storable a => Elt (Primitive a) where
mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx i sh) (X.indexPartial a i)
mscalar (Primitive x) = M_Primitive ZSX (X.scalar x)
mfromListOuterSN sn l@(arr1 :| _) =
- let sh = SKnown sn :$% mshape arr1
- in M_Primitive sh (X.fromListOuter (ssxFromShX sh) (map (\(M_Primitive _ a) -> a) (toList l)))
+ let sh = mshape arr1
+ in M_Primitive (SKnown sn :$% sh) (X.fromListOuterSN sn sh ((\(M_Primitive _ a) -> a) <$> l))
mtoListOuter (M_Primitive sh arr) = map (M_Primitive (shxTail sh)) (X.toListOuter arr)
+ {-# INLINE mlift #-}
mlift :: forall sh1 sh2.
StaticShX sh2
-> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a)
@@ -419,6 +439,7 @@ instance Storable a => Elt (Primitive a) where
, let result = f ZKX a
= M_Primitive (X.shape ssh2 result) result
+ {-# INLINE mlift2 #-}
mlift2 :: forall sh1 sh2 sh3.
StaticShX sh3
-> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a -> XArray (sh3 ++ '[]) a)
@@ -430,6 +451,7 @@ instance Storable a => Elt (Primitive a) where
, let result = f ZKX a b
= M_Primitive (X.shape ssh3 result) result
+ {-# INLINE mliftL #-}
mliftL :: forall sh1 sh2.
StaticShX sh2
-> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b))
@@ -464,18 +486,22 @@ instance Storable a => Elt (Primitive a) where
mshapeTreeIsEmpty _ () = False
mshowShapeTree _ () = "()"
marrayStrides (M_Primitive _ arr) = BOne (X.arrayStrides arr)
- mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x
+ mvecsWriteLinear i (Primitive x) (MV_Primitive v) = VSM.write v i x
- -- TODO: this use of toVector is suboptimal
- mvecsWritePartial
+ -- TODO: this use of toVectorListT is suboptimal
+ mvecsWritePartialLinear
:: forall sh' sh s.
- IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s ()
- mvecsWritePartial sh i (M_Primitive sh' arr) (MV_Primitive v) = do
+ Proxy sh -> Int -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s ()
+ mvecsWritePartialLinear _ i (M_Primitive sh' arr@(XArray (ORS.A (ORG.A sht t)))) (MV_Primitive v) = do
let arrsh = X.shape (ssxFromShX sh') arr
- offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh))
- VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr)
+ offset = i * shxSize arrsh
+ f off el = do
+ VS.copy (VSM.slice off (VS.length el) v) el
+ return $! off + VS.length el
+ foldM_ f offset (OI.toVectorListT sht t)
mvecsFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.freeze v
+ mvecsUnsafeFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.unsafeFreeze v
-- [PRIMITIVE ELEMENT TYPES LIST]
deriving via Primitive Bool instance Elt Bool
@@ -492,6 +518,7 @@ deriving via Primitive () instance Elt ()
instance Storable a => KnownElt (Primitive a) where
memptyArrayUnsafe sh = M_Primitive sh (X.empty sh)
mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (shxSize sh)
+ mvecsReplicate sh (Primitive a) = MV_Primitive <$> VSM.replicate (shxSize sh) a
mvecsNewEmpty _ = MV_Primitive <$> VSM.unsafeNew 0
-- [PRIMITIVE ELEMENT TYPES LIST]
@@ -508,16 +535,22 @@ deriving via Primitive () instance KnownElt ()
-- Arrays of pairs are pairs of arrays.
instance (Elt a, Elt b) => Elt (a, b) where
+ {-# INLINEABLE mshape #-}
mshape (M_Tup2 a _) = mshape a
+ {-# INLINEABLE mindex #-}
mindex (M_Tup2 a b) i = (mindex a i, mindex b i)
+ {-# INLINEABLE mindexPartial #-}
mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i)
mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y)
mfromListOuterSN sn l =
M_Tup2 (mfromListOuterSN sn ((\(M_Tup2 x _) -> x) <$> l))
(mfromListOuterSN sn ((\(M_Tup2 _ y) -> y) <$> l))
mtoListOuter (M_Tup2 a b) = zipWith M_Tup2 (mtoListOuter a) (mtoListOuter b)
+ {-# INLINE mlift #-}
mlift ssh2 f (M_Tup2 a b) = M_Tup2 (mlift ssh2 f a) (mlift ssh2 f b)
+ {-# INLINE mlift2 #-}
mlift2 ssh3 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 ssh3 f a x) (mlift2 ssh3 f b y)
+ {-# INLINE mliftL #-}
mliftL ssh2 f =
let unzipT2l [] = ([], [])
unzipT2l (M_Tup2 a b : l) = let (l1, l2) = unzipT2l l in (a : l1, b : l2)
@@ -542,17 +575,19 @@ instance (Elt a, Elt b) => Elt (a, b) where
mshapeTreeIsEmpty _ (t1, t2) = mshapeTreeIsEmpty (Proxy @a) t1 && mshapeTreeIsEmpty (Proxy @b) t2
mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")"
marrayStrides (M_Tup2 a b) = marrayStrides a <> marrayStrides b
- mvecsWrite sh i (x, y) (MV_Tup2 a b) = do
- mvecsWrite sh i x a
- mvecsWrite sh i y b
- mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do
- mvecsWritePartial sh i x a
- mvecsWritePartial sh i y b
+ mvecsWriteLinear i (x, y) (MV_Tup2 a b) = do
+ mvecsWriteLinear i x a
+ mvecsWriteLinear i y b
+ mvecsWritePartialLinear proxy i (M_Tup2 x y) (MV_Tup2 a b) = do
+ mvecsWritePartialLinear proxy i x a
+ mvecsWritePartialLinear proxy i y b
mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b
+ mvecsUnsafeFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsUnsafeFreeze sh a <*> mvecsUnsafeFreeze sh b
instance (KnownElt a, KnownElt b) => KnownElt (a, b) where
memptyArrayUnsafe sh = M_Tup2 (memptyArrayUnsafe sh) (memptyArrayUnsafe sh)
mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y
+ mvecsReplicate sh (x, y) = MV_Tup2 <$> mvecsReplicate sh x <*> mvecsReplicate sh y
mvecsNewEmpty _ = MV_Tup2 <$> mvecsNewEmpty (Proxy @a) <*> mvecsNewEmpty (Proxy @b)
-- Arrays of arrays are just arrays, but with more dimensions.
@@ -560,13 +595,16 @@ instance Elt a => Elt (Mixed sh' a) where
-- TODO: this is quadratic in the nesting depth because it repeatedly
-- truncates the shape vector to one a little shorter. Fix with a
-- moverlongShape method, a prefix of which is mshape.
+ {-# INLINEABLE mshape #-}
mshape :: forall sh. Mixed sh (Mixed sh' a) -> IShX sh
mshape (M_Nest sh arr)
- = fst (shxSplitApp (Proxy @sh') (ssxFromShX sh) (mshape arr))
+ = shxTakeSh (Proxy @sh') sh (mshape arr)
+ {-# INLINEABLE mindex #-}
mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a
mindex (M_Nest _ arr) = mindexPartial arr
+ {-# INLINEABLE mindexPartial #-}
mindexPartial :: forall sh1 sh2.
Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)
mindexPartial (M_Nest sh arr) i
@@ -581,16 +619,17 @@ instance Elt a => Elt (Mixed sh' a) where
mtoListOuter (M_Nest sh arr) = map (M_Nest (shxTail sh)) (mtoListOuter arr)
+ {-# INLINE mlift #-}
mlift :: forall sh1 sh2.
StaticShX sh2
-> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b)
-> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a)
mlift ssh2 f (M_Nest sh1 arr) =
let result = mlift (ssxAppend ssh2 ssh') f' arr
- (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape result)
+ sh2 = shxTakeSSX (Proxy @sh') ssh2 (mshape result)
in M_Nest sh2 result
where
- ssh' = ssxFromShX (snd (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape arr)))
+ ssh' = ssxFromShX (shxDropSh @sh1 @sh' sh1 (mshape arr))
f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b
f' sshT
@@ -598,16 +637,17 @@ instance Elt a => Elt (Mixed sh' a) where
, Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT)
= f (ssxAppend ssh' sshT)
+ {-# INLINE mlift2 #-}
mlift2 :: forall sh1 sh2 sh3.
StaticShX sh3
-> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b -> XArray (sh3 ++ shT) b)
-> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) -> Mixed sh3 (Mixed sh' a)
mlift2 ssh3 f (M_Nest sh1 arr1) (M_Nest _ arr2) =
let result = mlift2 (ssxAppend ssh3 ssh') f' arr1 arr2
- (sh3, _) = shxSplitApp (Proxy @sh') ssh3 (mshape result)
+ sh3 = shxTakeSSX (Proxy @sh') ssh3 (mshape result)
in M_Nest sh3 result
where
- ssh' = ssxFromShX (snd (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape arr1)))
+ ssh' = ssxFromShX (shxDropSh @sh1 @sh' sh1 (mshape arr1))
f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b
f' sshT
@@ -616,16 +656,17 @@ instance Elt a => Elt (Mixed sh' a) where
, Refl <- lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT)
= f (ssxAppend ssh' sshT)
+ {-# INLINE mliftL #-}
mliftL :: forall sh1 sh2.
StaticShX sh2
-> (forall shT b. Storable b => StaticShX shT -> NonEmpty (XArray (sh1 ++ shT) b) -> NonEmpty (XArray (sh2 ++ shT) b))
-> NonEmpty (Mixed sh1 (Mixed sh' a)) -> NonEmpty (Mixed sh2 (Mixed sh' a))
mliftL ssh2 f l@(M_Nest sh1 arr1 :| _) =
let result = mliftL (ssxAppend ssh2 ssh') f' (fmap (\(M_Nest _ arr) -> arr) l)
- (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape (NE.head result))
+ sh2 = shxTakeSSX (Proxy @sh') ssh2 (mshape (NE.head result))
in fmap (M_Nest sh2) result
where
- ssh' = ssxFromShX (snd (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape arr1)))
+ ssh' = ssxFromShX (shxDropSh @sh1 @sh' sh1 (mshape arr1))
f' :: forall shT b. Storable b => StaticShX shT -> NonEmpty (XArray ((sh1 ++ sh') ++ shT) b) -> NonEmpty (XArray ((sh2 ++ sh') ++ shT) b)
f' sshT
@@ -658,12 +699,13 @@ instance Elt a => Elt (Mixed sh' a) where
mconcat :: NonEmpty (Mixed (Nothing : sh) (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a)
mconcat l@(M_Nest sh1 _ :| _) =
let result = mconcat (fmap (\(M_Nest _ arr) -> arr) l)
- in M_Nest (fst (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape result))) result
+ in M_Nest (shxTakeSh (Proxy @sh') sh1 (mshape result)) result
mrnf (M_Nest sh arr) = rnf sh `seq` mrnf arr
type ShapeTree (Mixed sh' a) = (IShX sh', ShapeTree a)
+ -- This requires that @arr@ is not empty.
mshapeTree :: Mixed sh' a -> ShapeTree (Mixed sh' a)
mshapeTree arr = (mshape arr, mshapeTree (mindex arr (ixxZero (ssxFromShX (mshape arr)))))
@@ -676,17 +718,20 @@ instance Elt a => Elt (Mixed sh' a) where
marrayStrides (M_Nest _ arr) = marrayStrides arr
- mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs
+ mvecsWriteLinear :: forall s sh. Int -> Mixed sh' a -> MixedVecs s sh (Mixed sh' a) -> ST s ()
+ mvecsWriteLinear idx val (MV_Nest _ vecs) = mvecsWritePartialLinear (Proxy @sh) idx val vecs
- mvecsWritePartial :: forall sh1 sh2 s.
- IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)
- -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a)
- -> ST s ()
- mvecsWritePartial sh12 idx (M_Nest _ arr) (MV_Nest sh' vecs)
+ mvecsWritePartialLinear
+ :: forall sh1 sh2 s.
+ Proxy sh1 -> Int -> Mixed sh2 (Mixed sh' a)
+ -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a)
+ -> ST s ()
+ mvecsWritePartialLinear proxy idx (M_Nest _ arr) (MV_Nest _ vecs)
| Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
- = mvecsWritePartial (shxAppend sh12 sh') idx arr vecs
+ = mvecsWritePartialLinear proxy idx arr vecs
mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (shxAppend sh sh') vecs
+ mvecsUnsafeFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsUnsafeFreeze (shxAppend sh sh') vecs
instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where
memptyArrayUnsafe sh = M_Nest sh (memptyArrayUnsafe (shxAppend sh (shxCompleteZeros (knownShX @sh'))))
@@ -697,9 +742,28 @@ instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where
where
sh' = mshape example
+ mvecsReplicate sh example = do
+ vecs <- mvecsUnsafeNew sh example
+ forM_ [0 .. shxSize sh - 1] $ \idx -> mvecsWriteLinear idx example vecs
+ -- this is a slow case, but the alternative, mvecsUnsafeNew with manual
+ -- writing in a loop, leads to every case being as slow
+ return vecs
+
mvecsNewEmpty _ = MV_Nest (shxCompleteZeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a)
+-- | Given the shape of this array, an index and a value, write the value at
+-- that index in the vectors.
+{-# INLINE mvecsWrite #-}
+mvecsWrite :: Elt a => IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s ()
+mvecsWrite sh idx val vecs = mvecsWriteLinear (ixxToLinear sh idx) val vecs
+
+-- | Given the shape of this array, an index and a value, write the value at
+-- that index in the vectors.
+{-# INLINE mvecsWritePartial #-}
+mvecsWritePartial :: forall sh sh' s a. Elt a => IShX sh -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s ()
+mvecsWritePartial sh idx val vecs = mvecsWritePartialLinear (Proxy @sh) (ixxToLinear sh idx) val vecs
+
-- TODO: should we provide a function that's just memptyArrayUnsafe but with a size==0 check? That may save someone a transpose somewhere
memptyArray :: forall sh a. KnownElt a => IShX sh -> Mixed (Just 0 : sh) a
memptyArray sh = memptyArrayUnsafe (SKnown SNat :$% sh)
@@ -746,7 +810,7 @@ mgenerate sh f = case shxEnum sh of
when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $
error "Data.Array.Nested mgenerate: generated values do not have equal shapes"
mvecsWrite sh idx val vecs
- mvecsFreeze sh vecs
+ mvecsUnsafeFreeze sh vecs
-- | An optimized special case of 'mgenerate', where the function results
-- are of a primitive type and so there's not need to check that all shapes
@@ -759,19 +823,23 @@ mgeneratePrim sh f =
let g i = f (ixxFromLinear sh i)
in mfromVector sh $ VS.generate (shxSize sh) g
+{-# INLINEABLE msumOuter1PrimP #-}
msumOuter1PrimP :: forall sh n a. (Storable a, NumElt a)
=> Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a)
msumOuter1PrimP (M_Primitive (n :$% sh) arr) =
let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX
in M_Primitive sh (X.sumOuter nssh (ssxFromShX sh) arr)
+{-# INLINEABLE msumOuter1Prim #-}
msumOuter1Prim :: forall sh n a. (NumElt a, PrimElt a)
=> Mixed (n : sh) a -> Mixed sh a
msumOuter1Prim = fromPrimitive . msumOuter1PrimP @sh @n @a . toPrimitive
+{-# INLINEABLE msumAllPrimP #-}
msumAllPrimP :: (Storable a, NumElt a) => Mixed sh (Primitive a) -> a
msumAllPrimP (M_Primitive sh arr) = X.sumFull (ssxFromShX sh) arr
+{-# INLINEABLE msumAllPrim #-}
msumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a
msumAllPrim arr = msumAllPrimP (toPrimitive arr)
@@ -782,7 +850,7 @@ mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2
sn :$% sh = mshape arr1
sm :$% _ = mshape arr2
ssh = ssxFromShX sh
- snm :: SMayNat () SNat (AddMaybe n m)
+ snm :: SMayNat () (AddMaybe n m)
snm = case (sn, sm) of
(SUnknown{}, _) -> SUnknown ()
(SKnown{}, SUnknown{}) -> SUnknown ()
@@ -792,15 +860,19 @@ mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2
=> StaticShX sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (AddMaybe n m : sh ++ sh') b
f ssh' = X.append (ssxAppend ssh ssh')
+{-# INLINEABLE mfromVectorP #-}
mfromVectorP :: forall sh a. Storable a => IShX sh -> VS.Vector a -> Mixed sh (Primitive a)
mfromVectorP sh v = M_Primitive sh (X.fromVector sh v)
+{-# INLINEABLE mfromVector #-}
mfromVector :: forall sh a. PrimElt a => IShX sh -> VS.Vector a -> Mixed sh a
mfromVector sh v = fromPrimitive (mfromVectorP sh v)
+{-# INLINEABLE mtoVectorP #-}
mtoVectorP :: Storable a => Mixed sh (Primitive a) -> VS.Vector a
mtoVectorP (M_Primitive _ v) = X.toVector v
+{-# INLINEABLE mtoVector #-}
mtoVector :: PrimElt a => Mixed sh a -> VS.Vector a
mtoVector arr = mtoVectorP (toPrimitive arr)
@@ -856,7 +928,7 @@ mfromListLinear sh l = mreshape sh (mfromList1N (shxSize sh) l)
mfromList1Prim :: PrimElt a => [a] -> Mixed '[Nothing] a
mfromList1Prim l =
let ssh = SUnknown () :!% ZKX
- xarr = X.fromList1 ssh l
+ xarr = X.fromList1 l
in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr
mfromList1PrimN :: PrimElt a => Int -> [a] -> Mixed '[Nothing] a
@@ -865,11 +937,15 @@ mfromList1PrimN n l =
Just sn -> mcastPartial (SKnown sn :!% ZKX) (SUnknown () :!% ZKX) Proxy (mfromList1PrimSN sn l)
Nothing -> error $ "mfromList1PrimN: length negative (" ++ show n ++ ")"
-mfromList1PrimSN :: PrimElt a => SNat n -> [a] -> Mixed '[Just n] a
+mfromList1PrimSN :: forall n a. PrimElt a => SNat n -> [a] -> Mixed '[Just n] a
mfromList1PrimSN sn l =
- let ssh = SKnown sn :!% ZKX
- xarr = X.fromList1 ssh l
- in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr
+ let ssh = SKnown sn :$% ZSX
+ in fromPrimitive $ M_Primitive ssh
+ $ if Storable.sizeOf (undefined :: a) > 0
+ then X.fromList1SN sn l
+ else case l of -- don't force the list if all elements are the same
+ a0 : _ -> X.replicateScal ssh a0
+ [] -> X.fromList1SN sn l
mfromListPrimLinear :: forall sh a. PrimElt a => IShX sh -> [a] -> Mixed sh a
mfromListPrimLinear sh l =
@@ -886,7 +962,7 @@ munScalar :: Elt a => Mixed '[] a -> a
munScalar arr = mindex arr ZIX
mnest :: forall sh sh' a. Elt a => StaticShX sh -> Mixed (sh ++ sh') a -> Mixed sh (Mixed sh' a)
-mnest ssh arr = M_Nest (fst (shxSplitApp (Proxy @sh') ssh (mshape arr))) arr
+mnest ssh arr = M_Nest (shxTakeSSX (Proxy @sh') ssh (mshape arr)) arr
munNest :: Mixed sh (Mixed sh' a) -> Mixed (sh ++ sh') a
munNest (M_Nest _ arr) = arr
@@ -999,6 +1075,7 @@ mmaxIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh
mmaxIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) =
ixxFromList (ssxFromShX sh) (numEltMaxIndex (shxRank sh) (fromO arr))
+{-# INLINEABLE mdot1Inner #-}
mdot1Inner :: forall sh n a. (PrimElt a, NumElt a)
=> Proxy n -> Mixed (sh ++ '[n]) a -> Mixed (sh ++ '[n]) a -> Mixed sh a
mdot1Inner _ (toPrimitive -> M_Primitive sh1 (XArray a)) (toPrimitive -> M_Primitive sh2 (XArray b))
@@ -1014,6 +1091,7 @@ mdot1Inner _ (toPrimitive -> M_Primitive sh1 (XArray a)) (toPrimitive -> M_Primi
-- | This has a temporary, suboptimal implementation in terms of 'mflatten'.
-- Prefer 'mdot1Inner' if applicable.
+{-# INLINEABLE mdot #-}
mdot :: (PrimElt a, NumElt a) => Mixed sh a -> Mixed sh a -> a
mdot a b =
munScalar $
@@ -1032,11 +1110,13 @@ mfromXArrayPrimP ssh arr = M_Primitive (X.shape ssh arr) arr
mfromXArrayPrim :: PrimElt a => StaticShX sh -> XArray sh a -> Mixed sh a
mfromXArrayPrim = (fromPrimitive .) . mfromXArrayPrimP
+{-# INLINE mliftPrim #-}
mliftPrim :: (PrimElt a, PrimElt b)
=> (a -> b)
-> Mixed sh a -> Mixed sh b
mliftPrim f (toPrimitive -> M_Primitive sh (X.XArray arr)) = fromPrimitive $ M_Primitive sh (X.XArray (S.mapA f arr))
+{-# INLINE mliftPrim2 #-}
mliftPrim2 :: (PrimElt a, PrimElt b, PrimElt c)
=> (a -> b -> c)
-> Mixed sh a -> Mixed sh b -> Mixed sh c