diff options
Diffstat (limited to 'src/Data/Array/Nested/Mixed.hs')
| -rw-r--r-- | src/Data/Array/Nested/Mixed.hs | 19 |
1 files changed, 19 insertions, 0 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index 2b5c5b6..ffbc993 100644 --- a/src/Data/Array/Nested/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -408,6 +408,9 @@ class Elt a => KnownElt a where -- Arrays of scalars are basically just arrays of scalars. instance Storable a => Elt (Primitive a) where + -- Somehow, INLINE here can increase allocation with GHC 9.14.1. + -- Maybe that happens in void instances such as @Primitive ()@. + {-# INLINEABLE mshape #-} mshape (M_Primitive sh _) = sh {-# INLINEABLE mindex #-} mindex (M_Primitive _ a) i = Primitive (X.index a i) @@ -523,8 +526,11 @@ deriving via Primitive () instance KnownElt () -- Arrays of pairs are pairs of arrays. instance (Elt a, Elt b) => Elt (a, b) where + {-# INLINEABLE mshape #-} mshape (M_Tup2 a _) = mshape a + {-# INLINEABLE mindex #-} mindex (M_Tup2 a b) i = (mindex a i, mindex b i) + {-# INLINEABLE mindexPartial #-} mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i) mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y) mfromListOuterSN sn l = @@ -580,13 +586,16 @@ instance Elt a => Elt (Mixed sh' a) where -- TODO: this is quadratic in the nesting depth because it repeatedly -- truncates the shape vector to one a little shorter. Fix with a -- moverlongShape method, a prefix of which is mshape. + {-# INLINEABLE mshape #-} mshape :: forall sh. Mixed sh (Mixed sh' a) -> IShX sh mshape (M_Nest sh arr) = fst (shxSplitApp (Proxy @sh') (ssxFromShX sh) (mshape arr)) + {-# INLINEABLE mindex #-} mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a mindex (M_Nest _ arr) = mindexPartial arr + {-# INLINEABLE mindexPartial #-} mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) mindexPartial (M_Nest sh arr) i @@ -804,19 +813,23 @@ mgeneratePrim sh f = let g i = f (ixxFromLinear sh i) in mfromVector sh $ VS.generate (shxSize sh) g +{-# INLINEABLE msumOuter1PrimP #-} msumOuter1PrimP :: forall sh n a. (Storable a, NumElt a) => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a) msumOuter1PrimP (M_Primitive (n :$% sh) arr) = let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX in M_Primitive sh (X.sumOuter nssh (ssxFromShX sh) arr) +{-# INLINEABLE msumOuter1Prim #-} msumOuter1Prim :: forall sh n a. (NumElt a, PrimElt a) => Mixed (n : sh) a -> Mixed sh a msumOuter1Prim = fromPrimitive . msumOuter1PrimP @sh @n @a . toPrimitive +{-# INLINEABLE msumAllPrimP #-} msumAllPrimP :: (Storable a, NumElt a) => Mixed sh (Primitive a) -> a msumAllPrimP (M_Primitive sh arr) = X.sumFull (ssxFromShX sh) arr +{-# INLINEABLE msumAllPrim #-} msumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a msumAllPrim arr = msumAllPrimP (toPrimitive arr) @@ -837,15 +850,19 @@ mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2 => StaticShX sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (AddMaybe n m : sh ++ sh') b f ssh' = X.append (ssxAppend ssh ssh') +{-# INLINEABLE mfromVectorP #-} mfromVectorP :: forall sh a. Storable a => IShX sh -> VS.Vector a -> Mixed sh (Primitive a) mfromVectorP sh v = M_Primitive sh (X.fromVector sh v) +{-# INLINEABLE mfromVector #-} mfromVector :: forall sh a. PrimElt a => IShX sh -> VS.Vector a -> Mixed sh a mfromVector sh v = fromPrimitive (mfromVectorP sh v) +{-# INLINEABLE mtoVectorP #-} mtoVectorP :: Storable a => Mixed sh (Primitive a) -> VS.Vector a mtoVectorP (M_Primitive _ v) = X.toVector v +{-# INLINEABLE mtoVector #-} mtoVector :: PrimElt a => Mixed sh a -> VS.Vector a mtoVector arr = mtoVectorP (toPrimitive arr) @@ -1044,6 +1061,7 @@ mmaxIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh mmaxIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) = ixxFromList (ssxFromShX sh) (numEltMaxIndex (shxRank sh) (fromO arr)) +{-# INLINEABLE mdot1Inner #-} mdot1Inner :: forall sh n a. (PrimElt a, NumElt a) => Proxy n -> Mixed (sh ++ '[n]) a -> Mixed (sh ++ '[n]) a -> Mixed sh a mdot1Inner _ (toPrimitive -> M_Primitive sh1 (XArray a)) (toPrimitive -> M_Primitive sh2 (XArray b)) @@ -1059,6 +1077,7 @@ mdot1Inner _ (toPrimitive -> M_Primitive sh1 (XArray a)) (toPrimitive -> M_Primi -- | This has a temporary, suboptimal implementation in terms of 'mflatten'. -- Prefer 'mdot1Inner' if applicable. +{-# INLINEABLE mdot #-} mdot :: (PrimElt a, NumElt a) => Mixed sh a -> Mixed sh a -> a mdot a b = munScalar $ |
