diff options
Diffstat (limited to 'src/Data/Array/Nested/Mixed.hs')
| -rw-r--r-- | src/Data/Array/Nested/Mixed.hs | 170 |
1 files changed, 125 insertions, 45 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index 182943d..39f00fa 100644 --- a/src/Data/Array/Nested/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -23,12 +23,14 @@ module Data.Array.Nested.Mixed where import Prelude hiding (mconcat) import Control.DeepSeq (NFData(..)) -import Control.Monad (forM_, when) +import Control.Monad (foldM_, forM_, when) import Control.Monad.ST +import Data.Array.Internal qualified as OI +import Data.Array.Internal.RankedG qualified as ORG +import Data.Array.Internal.RankedS qualified as ORS import Data.Array.RankedS qualified as S import Data.Bifunctor (bimap) import Data.Coerce -import Data.Foldable (toList) import Data.Int import Data.Kind (Type) import Data.List.NonEmpty (NonEmpty(..)) @@ -39,6 +41,7 @@ import Data.Vector.Storable qualified as VS import Data.Vector.Storable.Mutable qualified as VSM import Foreign.C.Types (CInt) import Foreign.Storable (Storable) +import Foreign.Storable qualified as Storable import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) import GHC.Generics (Generic) import GHC.TypeLits @@ -237,11 +240,13 @@ instance Elt a => NFData (Mixed sh a) where rnf = mrnf +{-# INLINE mliftNumElt1 #-} mliftNumElt1 :: (PrimElt a, PrimElt b) => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) b) -> Mixed sh a -> Mixed sh b mliftNumElt1 f (toPrimitive -> M_Primitive sh (XArray arr)) = fromPrimitive $ M_Primitive sh (XArray (f (shxRank sh) arr)) +{-# INLINE mliftNumElt2 #-} mliftNumElt2 :: (PrimElt a, PrimElt b, PrimElt c) => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) b -> S.Array (Rank sh) c) -> Mixed sh a -> Mixed sh b -> Mixed sh c @@ -310,7 +315,7 @@ class Elt a where -- | See 'mfromListOuter'. If the list does not have the given length, a -- runtime error is thrown. 'mfromListPrimSN' is faster if applicable. - mfromListOuterSN :: forall sh n. SNat n -> NonEmpty (Mixed sh a) -> Mixed (Just n : sh) a + mfromListOuterSN :: forall sh n. SNat n -> NonEmpty (Mixed sh a) -> Mixed (Just n : sh) a mtoListOuter :: Mixed (n : sh) a -> [Mixed sh a] @@ -355,6 +360,9 @@ class Elt a where -- | Tree giving the shape of every array component. type ShapeTree a + -- | Produces an internal representation of a tree of shapes of (potentially) + -- nested arrays. If the argument is an array, it requires that the array + -- is not empty (it's guaranteed to crash early otherwise). mshapeTree :: a -> ShapeTree a mshapeTreeEq :: Proxy a -> ShapeTree a -> ShapeTree a -> Bool @@ -367,17 +375,21 @@ class Elt a where -- this mixed array. marrayStrides :: Mixed sh a -> Bag [Int] - -- | Given the shape of this array, an index and a value, write the value at + -- | Given a linear index and a value, write the value at -- that index in the vectors. - mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s () + mvecsWriteLinear :: Int -> a -> MixedVecs s sh a -> ST s () - -- | Given the shape of this array, an index and a value, write the value at + -- | Given a linear index and a value, write the value at -- that index in the vectors. - mvecsWritePartial :: IShX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () + mvecsWritePartialLinear :: Proxy sh -> Int -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () -- | Given the shape of this array, finalise the vectors into 'XArray's. mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a) + -- | 'mvecsFreeze' but without copying the mutable vectors before freezing + -- them. If the mutable vectors are mutated after this function, referential + -- transparency is broken. + mvecsUnsafeFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a) -- | Element types for which we have evidence of the (static part of the) shape -- in a type class constraint. Compare the instance contexts of the instances @@ -393,11 +405,18 @@ class Elt a => KnownElt a where -- this vector and an example for the contents. mvecsUnsafeNew :: IShX sh -> a -> ST s (MixedVecs s sh a) + -- | Create initialised vectors for this array type, given the shape of + -- this vector and the chosen element. + mvecsReplicate :: IShX sh -> a -> ST s (MixedVecs s sh a) + mvecsNewEmpty :: Proxy a -> ST s (MixedVecs s sh a) -- Arrays of scalars are basically just arrays of scalars. instance Storable a => Elt (Primitive a) where + -- Somehow, INLINE here can increase allocation with GHC 9.14.1. + -- Maybe that happens in void instances such as @Primitive ()@. + {-# INLINEABLE mshape #-} mshape (M_Primitive sh _) = sh {-# INLINEABLE mindex #-} mindex (M_Primitive _ a) i = Primitive (X.index a i) @@ -405,10 +424,11 @@ instance Storable a => Elt (Primitive a) where mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx i sh) (X.indexPartial a i) mscalar (Primitive x) = M_Primitive ZSX (X.scalar x) mfromListOuterSN sn l@(arr1 :| _) = - let sh = SKnown sn :$% mshape arr1 - in M_Primitive sh (X.fromListOuter (ssxFromShX sh) (map (\(M_Primitive _ a) -> a) (toList l))) + let sh = mshape arr1 + in M_Primitive (SKnown sn :$% sh) (X.fromListOuterSN sn sh ((\(M_Primitive _ a) -> a) <$> l)) mtoListOuter (M_Primitive sh arr) = map (M_Primitive (shxTail sh)) (X.toListOuter arr) + {-# INLINE mlift #-} mlift :: forall sh1 sh2. StaticShX sh2 -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a) @@ -419,6 +439,7 @@ instance Storable a => Elt (Primitive a) where , let result = f ZKX a = M_Primitive (X.shape ssh2 result) result + {-# INLINE mlift2 #-} mlift2 :: forall sh1 sh2 sh3. StaticShX sh3 -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a -> XArray (sh3 ++ '[]) a) @@ -430,6 +451,7 @@ instance Storable a => Elt (Primitive a) where , let result = f ZKX a b = M_Primitive (X.shape ssh3 result) result + {-# INLINE mliftL #-} mliftL :: forall sh1 sh2. StaticShX sh2 -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) @@ -464,18 +486,22 @@ instance Storable a => Elt (Primitive a) where mshapeTreeIsEmpty _ () = False mshowShapeTree _ () = "()" marrayStrides (M_Primitive _ arr) = BOne (X.arrayStrides arr) - mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x + mvecsWriteLinear i (Primitive x) (MV_Primitive v) = VSM.write v i x - -- TODO: this use of toVector is suboptimal - mvecsWritePartial + -- TODO: this use of toVectorListT is suboptimal + mvecsWritePartialLinear :: forall sh' sh s. - IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () - mvecsWritePartial sh i (M_Primitive sh' arr) (MV_Primitive v) = do + Proxy sh -> Int -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () + mvecsWritePartialLinear _ i (M_Primitive sh' arr@(XArray (ORS.A (ORG.A sht t)))) (MV_Primitive v) = do let arrsh = X.shape (ssxFromShX sh') arr - offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh)) - VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr) + offset = i * shxSize arrsh + f off el = do + VS.copy (VSM.slice off (VS.length el) v) el + return $! off + VS.length el + foldM_ f offset (OI.toVectorListT sht t) mvecsFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.freeze v + mvecsUnsafeFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.unsafeFreeze v -- [PRIMITIVE ELEMENT TYPES LIST] deriving via Primitive Bool instance Elt Bool @@ -492,6 +518,7 @@ deriving via Primitive () instance Elt () instance Storable a => KnownElt (Primitive a) where memptyArrayUnsafe sh = M_Primitive sh (X.empty sh) mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (shxSize sh) + mvecsReplicate sh (Primitive a) = MV_Primitive <$> VSM.replicate (shxSize sh) a mvecsNewEmpty _ = MV_Primitive <$> VSM.unsafeNew 0 -- [PRIMITIVE ELEMENT TYPES LIST] @@ -508,16 +535,22 @@ deriving via Primitive () instance KnownElt () -- Arrays of pairs are pairs of arrays. instance (Elt a, Elt b) => Elt (a, b) where + {-# INLINEABLE mshape #-} mshape (M_Tup2 a _) = mshape a + {-# INLINEABLE mindex #-} mindex (M_Tup2 a b) i = (mindex a i, mindex b i) + {-# INLINEABLE mindexPartial #-} mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i) mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y) mfromListOuterSN sn l = M_Tup2 (mfromListOuterSN sn ((\(M_Tup2 x _) -> x) <$> l)) (mfromListOuterSN sn ((\(M_Tup2 _ y) -> y) <$> l)) mtoListOuter (M_Tup2 a b) = zipWith M_Tup2 (mtoListOuter a) (mtoListOuter b) + {-# INLINE mlift #-} mlift ssh2 f (M_Tup2 a b) = M_Tup2 (mlift ssh2 f a) (mlift ssh2 f b) + {-# INLINE mlift2 #-} mlift2 ssh3 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 ssh3 f a x) (mlift2 ssh3 f b y) + {-# INLINE mliftL #-} mliftL ssh2 f = let unzipT2l [] = ([], []) unzipT2l (M_Tup2 a b : l) = let (l1, l2) = unzipT2l l in (a : l1, b : l2) @@ -542,17 +575,19 @@ instance (Elt a, Elt b) => Elt (a, b) where mshapeTreeIsEmpty _ (t1, t2) = mshapeTreeIsEmpty (Proxy @a) t1 && mshapeTreeIsEmpty (Proxy @b) t2 mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")" marrayStrides (M_Tup2 a b) = marrayStrides a <> marrayStrides b - mvecsWrite sh i (x, y) (MV_Tup2 a b) = do - mvecsWrite sh i x a - mvecsWrite sh i y b - mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do - mvecsWritePartial sh i x a - mvecsWritePartial sh i y b + mvecsWriteLinear i (x, y) (MV_Tup2 a b) = do + mvecsWriteLinear i x a + mvecsWriteLinear i y b + mvecsWritePartialLinear proxy i (M_Tup2 x y) (MV_Tup2 a b) = do + mvecsWritePartialLinear proxy i x a + mvecsWritePartialLinear proxy i y b mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b + mvecsUnsafeFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsUnsafeFreeze sh a <*> mvecsUnsafeFreeze sh b instance (KnownElt a, KnownElt b) => KnownElt (a, b) where memptyArrayUnsafe sh = M_Tup2 (memptyArrayUnsafe sh) (memptyArrayUnsafe sh) mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y + mvecsReplicate sh (x, y) = MV_Tup2 <$> mvecsReplicate sh x <*> mvecsReplicate sh y mvecsNewEmpty _ = MV_Tup2 <$> mvecsNewEmpty (Proxy @a) <*> mvecsNewEmpty (Proxy @b) -- Arrays of arrays are just arrays, but with more dimensions. @@ -560,13 +595,16 @@ instance Elt a => Elt (Mixed sh' a) where -- TODO: this is quadratic in the nesting depth because it repeatedly -- truncates the shape vector to one a little shorter. Fix with a -- moverlongShape method, a prefix of which is mshape. + {-# INLINEABLE mshape #-} mshape :: forall sh. Mixed sh (Mixed sh' a) -> IShX sh mshape (M_Nest sh arr) - = fst (shxSplitApp (Proxy @sh') (ssxFromShX sh) (mshape arr)) + = shxTakeSh (Proxy @sh') sh (mshape arr) + {-# INLINEABLE mindex #-} mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a mindex (M_Nest _ arr) = mindexPartial arr + {-# INLINEABLE mindexPartial #-} mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) mindexPartial (M_Nest sh arr) i @@ -581,16 +619,17 @@ instance Elt a => Elt (Mixed sh' a) where mtoListOuter (M_Nest sh arr) = map (M_Nest (shxTail sh)) (mtoListOuter arr) + {-# INLINE mlift #-} mlift :: forall sh1 sh2. StaticShX sh2 -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b) -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) mlift ssh2 f (M_Nest sh1 arr) = let result = mlift (ssxAppend ssh2 ssh') f' arr - (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape result) + sh2 = shxTakeSSX (Proxy @sh') ssh2 (mshape result) in M_Nest sh2 result where - ssh' = ssxFromShX (snd (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape arr))) + ssh' = ssxFromShX (shxDropSh @sh1 @sh' sh1 (mshape arr)) f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b f' sshT @@ -598,16 +637,17 @@ instance Elt a => Elt (Mixed sh' a) where , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) = f (ssxAppend ssh' sshT) + {-# INLINE mlift2 #-} mlift2 :: forall sh1 sh2 sh3. StaticShX sh3 -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b -> XArray (sh3 ++ shT) b) -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) -> Mixed sh3 (Mixed sh' a) mlift2 ssh3 f (M_Nest sh1 arr1) (M_Nest _ arr2) = let result = mlift2 (ssxAppend ssh3 ssh') f' arr1 arr2 - (sh3, _) = shxSplitApp (Proxy @sh') ssh3 (mshape result) + sh3 = shxTakeSSX (Proxy @sh') ssh3 (mshape result) in M_Nest sh3 result where - ssh' = ssxFromShX (snd (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape arr1))) + ssh' = ssxFromShX (shxDropSh @sh1 @sh' sh1 (mshape arr1)) f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b f' sshT @@ -616,16 +656,17 @@ instance Elt a => Elt (Mixed sh' a) where , Refl <- lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT) = f (ssxAppend ssh' sshT) + {-# INLINE mliftL #-} mliftL :: forall sh1 sh2. StaticShX sh2 -> (forall shT b. Storable b => StaticShX shT -> NonEmpty (XArray (sh1 ++ shT) b) -> NonEmpty (XArray (sh2 ++ shT) b)) -> NonEmpty (Mixed sh1 (Mixed sh' a)) -> NonEmpty (Mixed sh2 (Mixed sh' a)) mliftL ssh2 f l@(M_Nest sh1 arr1 :| _) = let result = mliftL (ssxAppend ssh2 ssh') f' (fmap (\(M_Nest _ arr) -> arr) l) - (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape (NE.head result)) + sh2 = shxTakeSSX (Proxy @sh') ssh2 (mshape (NE.head result)) in fmap (M_Nest sh2) result where - ssh' = ssxFromShX (snd (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape arr1))) + ssh' = ssxFromShX (shxDropSh @sh1 @sh' sh1 (mshape arr1)) f' :: forall shT b. Storable b => StaticShX shT -> NonEmpty (XArray ((sh1 ++ sh') ++ shT) b) -> NonEmpty (XArray ((sh2 ++ sh') ++ shT) b) f' sshT @@ -658,12 +699,13 @@ instance Elt a => Elt (Mixed sh' a) where mconcat :: NonEmpty (Mixed (Nothing : sh) (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a) mconcat l@(M_Nest sh1 _ :| _) = let result = mconcat (fmap (\(M_Nest _ arr) -> arr) l) - in M_Nest (fst (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape result))) result + in M_Nest (shxTakeSh (Proxy @sh') sh1 (mshape result)) result mrnf (M_Nest sh arr) = rnf sh `seq` mrnf arr type ShapeTree (Mixed sh' a) = (IShX sh', ShapeTree a) + -- This requires that @arr@ is not empty. mshapeTree :: Mixed sh' a -> ShapeTree (Mixed sh' a) mshapeTree arr = (mshape arr, mshapeTree (mindex arr (ixxZero (ssxFromShX (mshape arr))))) @@ -676,17 +718,20 @@ instance Elt a => Elt (Mixed sh' a) where marrayStrides (M_Nest _ arr) = marrayStrides arr - mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs + mvecsWriteLinear :: forall s sh. Int -> Mixed sh' a -> MixedVecs s sh (Mixed sh' a) -> ST s () + mvecsWriteLinear idx val (MV_Nest _ vecs) = mvecsWritePartialLinear (Proxy @sh) idx val vecs - mvecsWritePartial :: forall sh1 sh2 s. - IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) - -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a) - -> ST s () - mvecsWritePartial sh12 idx (M_Nest _ arr) (MV_Nest sh' vecs) + mvecsWritePartialLinear + :: forall sh1 sh2 s. + Proxy sh1 -> Int -> Mixed sh2 (Mixed sh' a) + -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a) + -> ST s () + mvecsWritePartialLinear proxy idx (M_Nest _ arr) (MV_Nest _ vecs) | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') - = mvecsWritePartial (shxAppend sh12 sh') idx arr vecs + = mvecsWritePartialLinear proxy idx arr vecs mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (shxAppend sh sh') vecs + mvecsUnsafeFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsUnsafeFreeze (shxAppend sh sh') vecs instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where memptyArrayUnsafe sh = M_Nest sh (memptyArrayUnsafe (shxAppend sh (shxCompleteZeros (knownShX @sh')))) @@ -697,9 +742,28 @@ instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where where sh' = mshape example + mvecsReplicate sh example = do + vecs <- mvecsUnsafeNew sh example + forM_ [0 .. shxSize sh - 1] $ \idx -> mvecsWriteLinear idx example vecs + -- this is a slow case, but the alternative, mvecsUnsafeNew with manual + -- writing in a loop, leads to every case being as slow + return vecs + mvecsNewEmpty _ = MV_Nest (shxCompleteZeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a) +-- | Given the shape of this array, an index and a value, write the value at +-- that index in the vectors. +{-# INLINE mvecsWrite #-} +mvecsWrite :: Elt a => IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s () +mvecsWrite sh idx val vecs = mvecsWriteLinear (ixxToLinear sh idx) val vecs + +-- | Given the shape of this array, an index and a value, write the value at +-- that index in the vectors. +{-# INLINE mvecsWritePartial #-} +mvecsWritePartial :: forall sh sh' s a. Elt a => IShX sh -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () +mvecsWritePartial sh idx val vecs = mvecsWritePartialLinear (Proxy @sh) (ixxToLinear sh idx) val vecs + -- TODO: should we provide a function that's just memptyArrayUnsafe but with a size==0 check? That may save someone a transpose somewhere memptyArray :: forall sh a. KnownElt a => IShX sh -> Mixed (Just 0 : sh) a memptyArray sh = memptyArrayUnsafe (SKnown SNat :$% sh) @@ -746,7 +810,7 @@ mgenerate sh f = case shxEnum sh of when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $ error "Data.Array.Nested mgenerate: generated values do not have equal shapes" mvecsWrite sh idx val vecs - mvecsFreeze sh vecs + mvecsUnsafeFreeze sh vecs -- | An optimized special case of 'mgenerate', where the function results -- are of a primitive type and so there's not need to check that all shapes @@ -759,19 +823,23 @@ mgeneratePrim sh f = let g i = f (ixxFromLinear sh i) in mfromVector sh $ VS.generate (shxSize sh) g +{-# INLINEABLE msumOuter1PrimP #-} msumOuter1PrimP :: forall sh n a. (Storable a, NumElt a) => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a) msumOuter1PrimP (M_Primitive (n :$% sh) arr) = let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX in M_Primitive sh (X.sumOuter nssh (ssxFromShX sh) arr) +{-# INLINEABLE msumOuter1Prim #-} msumOuter1Prim :: forall sh n a. (NumElt a, PrimElt a) => Mixed (n : sh) a -> Mixed sh a msumOuter1Prim = fromPrimitive . msumOuter1PrimP @sh @n @a . toPrimitive +{-# INLINEABLE msumAllPrimP #-} msumAllPrimP :: (Storable a, NumElt a) => Mixed sh (Primitive a) -> a msumAllPrimP (M_Primitive sh arr) = X.sumFull (ssxFromShX sh) arr +{-# INLINEABLE msumAllPrim #-} msumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a msumAllPrim arr = msumAllPrimP (toPrimitive arr) @@ -782,7 +850,7 @@ mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2 sn :$% sh = mshape arr1 sm :$% _ = mshape arr2 ssh = ssxFromShX sh - snm :: SMayNat () SNat (AddMaybe n m) + snm :: SMayNat () (AddMaybe n m) snm = case (sn, sm) of (SUnknown{}, _) -> SUnknown () (SKnown{}, SUnknown{}) -> SUnknown () @@ -792,15 +860,19 @@ mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2 => StaticShX sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (AddMaybe n m : sh ++ sh') b f ssh' = X.append (ssxAppend ssh ssh') +{-# INLINEABLE mfromVectorP #-} mfromVectorP :: forall sh a. Storable a => IShX sh -> VS.Vector a -> Mixed sh (Primitive a) mfromVectorP sh v = M_Primitive sh (X.fromVector sh v) +{-# INLINEABLE mfromVector #-} mfromVector :: forall sh a. PrimElt a => IShX sh -> VS.Vector a -> Mixed sh a mfromVector sh v = fromPrimitive (mfromVectorP sh v) +{-# INLINEABLE mtoVectorP #-} mtoVectorP :: Storable a => Mixed sh (Primitive a) -> VS.Vector a mtoVectorP (M_Primitive _ v) = X.toVector v +{-# INLINEABLE mtoVector #-} mtoVector :: PrimElt a => Mixed sh a -> VS.Vector a mtoVector arr = mtoVectorP (toPrimitive arr) @@ -856,7 +928,7 @@ mfromListLinear sh l = mreshape sh (mfromList1N (shxSize sh) l) mfromList1Prim :: PrimElt a => [a] -> Mixed '[Nothing] a mfromList1Prim l = let ssh = SUnknown () :!% ZKX - xarr = X.fromList1 ssh l + xarr = X.fromList1 l in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr mfromList1PrimN :: PrimElt a => Int -> [a] -> Mixed '[Nothing] a @@ -865,11 +937,15 @@ mfromList1PrimN n l = Just sn -> mcastPartial (SKnown sn :!% ZKX) (SUnknown () :!% ZKX) Proxy (mfromList1PrimSN sn l) Nothing -> error $ "mfromList1PrimN: length negative (" ++ show n ++ ")" -mfromList1PrimSN :: PrimElt a => SNat n -> [a] -> Mixed '[Just n] a +mfromList1PrimSN :: forall n a. PrimElt a => SNat n -> [a] -> Mixed '[Just n] a mfromList1PrimSN sn l = - let ssh = SKnown sn :!% ZKX - xarr = X.fromList1 ssh l - in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr + let ssh = SKnown sn :$% ZSX + in fromPrimitive $ M_Primitive ssh + $ if Storable.sizeOf (undefined :: a) > 0 + then X.fromList1SN sn l + else case l of -- don't force the list if all elements are the same + a0 : _ -> X.replicateScal ssh a0 + [] -> X.fromList1SN sn l mfromListPrimLinear :: forall sh a. PrimElt a => IShX sh -> [a] -> Mixed sh a mfromListPrimLinear sh l = @@ -886,7 +962,7 @@ munScalar :: Elt a => Mixed '[] a -> a munScalar arr = mindex arr ZIX mnest :: forall sh sh' a. Elt a => StaticShX sh -> Mixed (sh ++ sh') a -> Mixed sh (Mixed sh' a) -mnest ssh arr = M_Nest (fst (shxSplitApp (Proxy @sh') ssh (mshape arr))) arr +mnest ssh arr = M_Nest (shxTakeSSX (Proxy @sh') ssh (mshape arr)) arr munNest :: Mixed sh (Mixed sh' a) -> Mixed (sh ++ sh') a munNest (M_Nest _ arr) = arr @@ -999,6 +1075,7 @@ mmaxIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh mmaxIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) = ixxFromList (ssxFromShX sh) (numEltMaxIndex (shxRank sh) (fromO arr)) +{-# INLINEABLE mdot1Inner #-} mdot1Inner :: forall sh n a. (PrimElt a, NumElt a) => Proxy n -> Mixed (sh ++ '[n]) a -> Mixed (sh ++ '[n]) a -> Mixed sh a mdot1Inner _ (toPrimitive -> M_Primitive sh1 (XArray a)) (toPrimitive -> M_Primitive sh2 (XArray b)) @@ -1014,6 +1091,7 @@ mdot1Inner _ (toPrimitive -> M_Primitive sh1 (XArray a)) (toPrimitive -> M_Primi -- | This has a temporary, suboptimal implementation in terms of 'mflatten'. -- Prefer 'mdot1Inner' if applicable. +{-# INLINEABLE mdot #-} mdot :: (PrimElt a, NumElt a) => Mixed sh a -> Mixed sh a -> a mdot a b = munScalar $ @@ -1032,11 +1110,13 @@ mfromXArrayPrimP ssh arr = M_Primitive (X.shape ssh arr) arr mfromXArrayPrim :: PrimElt a => StaticShX sh -> XArray sh a -> Mixed sh a mfromXArrayPrim = (fromPrimitive .) . mfromXArrayPrimP +{-# INLINE mliftPrim #-} mliftPrim :: (PrimElt a, PrimElt b) => (a -> b) -> Mixed sh a -> Mixed sh b mliftPrim f (toPrimitive -> M_Primitive sh (X.XArray arr)) = fromPrimitive $ M_Primitive sh (X.XArray (S.mapA f arr)) +{-# INLINE mliftPrim2 #-} mliftPrim2 :: (PrimElt a, PrimElt b, PrimElt c) => (a -> b -> c) -> Mixed sh a -> Mixed sh b -> Mixed sh c |
