diff options
Diffstat (limited to 'src/Data/Array')
| -rw-r--r-- | src/Data/Array/Nested/Mixed.hs | 19 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Mixed/Shape.hs | 9 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Ranked.hs | 10 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Ranked/Base.hs | 4 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Shaped.hs | 10 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Shaped/Base.hs | 4 | ||||
| -rw-r--r-- | src/Data/Array/XArray.hs | 1 |
7 files changed, 56 insertions, 1 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index 2b5c5b6..ffbc993 100644 --- a/src/Data/Array/Nested/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -408,6 +408,9 @@ class Elt a => KnownElt a where -- Arrays of scalars are basically just arrays of scalars. instance Storable a => Elt (Primitive a) where + -- Somehow, INLINE here can increase allocation with GHC 9.14.1. + -- Maybe that happens in void instances such as @Primitive ()@. + {-# INLINEABLE mshape #-} mshape (M_Primitive sh _) = sh {-# INLINEABLE mindex #-} mindex (M_Primitive _ a) i = Primitive (X.index a i) @@ -523,8 +526,11 @@ deriving via Primitive () instance KnownElt () -- Arrays of pairs are pairs of arrays. instance (Elt a, Elt b) => Elt (a, b) where + {-# INLINEABLE mshape #-} mshape (M_Tup2 a _) = mshape a + {-# INLINEABLE mindex #-} mindex (M_Tup2 a b) i = (mindex a i, mindex b i) + {-# INLINEABLE mindexPartial #-} mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i) mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y) mfromListOuterSN sn l = @@ -580,13 +586,16 @@ instance Elt a => Elt (Mixed sh' a) where -- TODO: this is quadratic in the nesting depth because it repeatedly -- truncates the shape vector to one a little shorter. Fix with a -- moverlongShape method, a prefix of which is mshape. + {-# INLINEABLE mshape #-} mshape :: forall sh. Mixed sh (Mixed sh' a) -> IShX sh mshape (M_Nest sh arr) = fst (shxSplitApp (Proxy @sh') (ssxFromShX sh) (mshape arr)) + {-# INLINEABLE mindex #-} mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a mindex (M_Nest _ arr) = mindexPartial arr + {-# INLINEABLE mindexPartial #-} mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) mindexPartial (M_Nest sh arr) i @@ -804,19 +813,23 @@ mgeneratePrim sh f = let g i = f (ixxFromLinear sh i) in mfromVector sh $ VS.generate (shxSize sh) g +{-# INLINEABLE msumOuter1PrimP #-} msumOuter1PrimP :: forall sh n a. (Storable a, NumElt a) => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a) msumOuter1PrimP (M_Primitive (n :$% sh) arr) = let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX in M_Primitive sh (X.sumOuter nssh (ssxFromShX sh) arr) +{-# INLINEABLE msumOuter1Prim #-} msumOuter1Prim :: forall sh n a. (NumElt a, PrimElt a) => Mixed (n : sh) a -> Mixed sh a msumOuter1Prim = fromPrimitive . msumOuter1PrimP @sh @n @a . toPrimitive +{-# INLINEABLE msumAllPrimP #-} msumAllPrimP :: (Storable a, NumElt a) => Mixed sh (Primitive a) -> a msumAllPrimP (M_Primitive sh arr) = X.sumFull (ssxFromShX sh) arr +{-# INLINEABLE msumAllPrim #-} msumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a msumAllPrim arr = msumAllPrimP (toPrimitive arr) @@ -837,15 +850,19 @@ mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2 => StaticShX sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (AddMaybe n m : sh ++ sh') b f ssh' = X.append (ssxAppend ssh ssh') +{-# INLINEABLE mfromVectorP #-} mfromVectorP :: forall sh a. Storable a => IShX sh -> VS.Vector a -> Mixed sh (Primitive a) mfromVectorP sh v = M_Primitive sh (X.fromVector sh v) +{-# INLINEABLE mfromVector #-} mfromVector :: forall sh a. PrimElt a => IShX sh -> VS.Vector a -> Mixed sh a mfromVector sh v = fromPrimitive (mfromVectorP sh v) +{-# INLINEABLE mtoVectorP #-} mtoVectorP :: Storable a => Mixed sh (Primitive a) -> VS.Vector a mtoVectorP (M_Primitive _ v) = X.toVector v +{-# INLINEABLE mtoVector #-} mtoVector :: PrimElt a => Mixed sh a -> VS.Vector a mtoVector arr = mtoVectorP (toPrimitive arr) @@ -1044,6 +1061,7 @@ mmaxIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh mmaxIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) = ixxFromList (ssxFromShX sh) (numEltMaxIndex (shxRank sh) (fromO arr)) +{-# INLINEABLE mdot1Inner #-} mdot1Inner :: forall sh n a. (PrimElt a, NumElt a) => Proxy n -> Mixed (sh ++ '[n]) a -> Mixed (sh ++ '[n]) a -> Mixed sh a mdot1Inner _ (toPrimitive -> M_Primitive sh1 (XArray a)) (toPrimitive -> M_Primitive sh2 (XArray b)) @@ -1059,6 +1077,7 @@ mdot1Inner _ (toPrimitive -> M_Primitive sh1 (XArray a)) (toPrimitive -> M_Primi -- | This has a temporary, suboptimal implementation in terms of 'mflatten'. -- Prefer 'mdot1Inner' if applicable. +{-# INLINEABLE mdot #-} mdot :: (PrimElt a, NumElt a) => Mixed sh a -> Mixed sh a -> a mdot a b = munScalar $ diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs index de1c770..b3f0c2f 100644 --- a/src/Data/Array/Nested/Mixed/Shape.hs +++ b/src/Data/Array/Nested/Mixed/Shape.hs @@ -355,9 +355,16 @@ data ListH sh i where -- TODO: bring this UNPACK back when GHC no longer crashes: -- ConsKnown :: forall n sh i. {-# UNPACK #-} SNat n -> ListH sh i -> ListH (Just n : sh) i ConsKnown :: forall n sh i. SNat n -> ListH sh i -> ListH (Just n : sh) i -deriving instance Eq i => Eq (ListH sh i) deriving instance Ord i => Ord (ListH sh i) +-- A manually defined instance and this INLINEABLE is needed to specialize +-- mdot1Inner (otherwise GHC warns specialization breaks down here). +instance Eq i => Eq (ListH sh i) where + {-# INLINEABLE (==) #-} + ZH == ZH = True + ConsUnknown i1 sh1 == ConsUnknown i2 sh2 = i1 == i2 && sh1 == sh2 + ConsKnown _ sh1 == ConsKnown _ sh2 = sh1 == sh2 + #ifdef OXAR_DEFAULT_SHOW_INSTANCES deriving instance Show i => Show (ListH sh i) #else diff --git a/src/Data/Array/Nested/Ranked.hs b/src/Data/Array/Nested/Ranked.hs index 8faff6d..b448685 100644 --- a/src/Data/Array/Nested/Ranked.hs +++ b/src/Data/Array/Nested/Ranked.hs @@ -94,6 +94,7 @@ rlift2 :: forall n1 n2 n3 a. Elt a -> Ranked n1 a -> Ranked n2 a -> Ranked n3 a rlift2 sn3 f (Ranked arr1) (Ranked arr2) = Ranked (mlift2 (ssxFromSNat sn3) f arr1 arr2) +{-# INLINE rsumOuter1PrimP #-} rsumOuter1PrimP :: forall n a. (Storable a, NumElt a) => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a) @@ -101,13 +102,16 @@ rsumOuter1PrimP (Ranked arr) | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) = Ranked (msumOuter1PrimP arr) +{-# INLINEABLE rsumOuter1Prim #-} rsumOuter1Prim :: forall n a. (NumElt a, PrimElt a) => Ranked (n + 1) a -> Ranked n a rsumOuter1Prim = rfromPrimitive . rsumOuter1PrimP . rtoPrimitive +{-# INLINE rsumAllPrimP #-} rsumAllPrimP :: (Storable a, NumElt a) => Ranked n (Primitive a) -> a rsumAllPrimP (Ranked arr) = msumAllPrimP arr +{-# INLINE rsumAllPrim #-} rsumAllPrim :: (PrimElt a, NumElt a) => Ranked n a -> a rsumAllPrim (Ranked arr) = msumAllPrim arr @@ -139,17 +143,21 @@ rappend arr1 arr2 rscalar :: Elt a => a -> Ranked 0 a rscalar x = Ranked (mscalar x) +{-# INLINEABLE rfromVectorP #-} rfromVectorP :: forall n a. Storable a => IShR n -> VS.Vector a -> Ranked n (Primitive a) rfromVectorP sh v | Dict <- lemKnownReplicate (shrRank sh) = Ranked (mfromVectorP (shxFromShR sh) v) +{-# INLINEABLE rfromVector #-} rfromVector :: forall n a. PrimElt a => IShR n -> VS.Vector a -> Ranked n a rfromVector sh v = rfromPrimitive (rfromVectorP sh v) +{-# INLINEABLE rtoVectorP #-} rtoVectorP :: Storable a => Ranked n (Primitive a) -> VS.Vector a rtoVectorP = coerce mtoVectorP +{-# INLINEABLE rtoVector #-} rtoVector :: PrimElt a => Ranked n a -> VS.Vector a rtoVector = coerce mtoVector @@ -335,6 +343,7 @@ rmaxIndexPrim rarr@(Ranked arr) | Refl <- lemRankReplicate (rrank (rtoPrimitive rarr)) = ixrFromIxX (mmaxIndexPrim arr) +{-# INLINEABLE rdot1Inner #-} rdot1Inner :: forall n a. (PrimElt a, NumElt a) => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked n a rdot1Inner arr1 arr2 | SNat <- rrank arr1 @@ -343,6 +352,7 @@ rdot1Inner arr1 arr2 -- | This has a temporary, suboptimal implementation in terms of 'mflatten'. -- Prefer 'rdot1Inner' if applicable. +{-# INLINE rdot #-} rdot :: (PrimElt a, NumElt a) => Ranked n a -> Ranked n a -> a rdot = coerce mdot diff --git a/src/Data/Array/Nested/Ranked/Base.hs b/src/Data/Array/Nested/Ranked/Base.hs index 9d88815..beedbcf 100644 --- a/src/Data/Array/Nested/Ranked/Base.hs +++ b/src/Data/Array/Nested/Ranked/Base.hs @@ -82,9 +82,12 @@ newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed -- these instances allow them to also be used as elements of arrays, thus -- making them first-class in the API. instance Elt a => Elt (Ranked n a) where + {-# INLINE mshape #-} mshape (M_Ranked arr) = mshape arr + {-# INLINE mindex #-} mindex (M_Ranked arr) i = Ranked (mindex arr i) + {-# INLINE mindexPartial #-} mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IIxX sh -> Mixed sh' (Ranked n a) mindexPartial (M_Ranked arr) i = coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $ @@ -260,6 +263,7 @@ ratan2Array :: (FloatElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a ratan2Array = liftRanked2 matan2Array +{-# INLINE rshape #-} rshape :: Elt a => Ranked n a -> IShR n rshape (Ranked arr) = coerce (mshape arr) diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs index 5c52220..36ef24a 100644 --- a/src/Data/Array/Nested/Shaped.hs +++ b/src/Data/Array/Nested/Shaped.hs @@ -96,17 +96,21 @@ slift2 :: forall sh1 sh2 sh3 a. Elt a -> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (ssxFromShX (shxFromShS sh3)) f arr1 arr2) +{-# INLINE ssumOuter1PrimP #-} ssumOuter1PrimP :: forall sh n a. (Storable a, NumElt a) => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a) ssumOuter1PrimP (Shaped arr) = Shaped (msumOuter1PrimP arr) +{-# INLINEABLE ssumOuter1Prim #-} ssumOuter1Prim :: forall sh n a. (NumElt a, PrimElt a) => Shaped (n : sh) a -> Shaped sh a ssumOuter1Prim = sfromPrimitive . ssumOuter1PrimP . stoPrimitive +{-# INLINE ssumAllPrimP #-} ssumAllPrimP :: (PrimElt a, NumElt a) => Shaped n (Primitive a) -> a ssumAllPrimP (Shaped arr) = msumAllPrimP arr +{-# INLINE ssumAllPrim #-} ssumAllPrim :: (PrimElt a, NumElt a) => Shaped n a -> a ssumAllPrim (Shaped arr) = msumAllPrim arr @@ -126,15 +130,19 @@ sappend = coerce mappend sscalar :: Elt a => a -> Shaped '[] a sscalar x = Shaped (mscalar x) +{-# INLINEABLE sfromVectorP #-} sfromVectorP :: Storable a => ShS sh -> VS.Vector a -> Shaped sh (Primitive a) sfromVectorP sh v = Shaped (mfromVectorP (shxFromShS sh) v) +{-# INLINEABLE sfromVector #-} sfromVector :: PrimElt a => ShS sh -> VS.Vector a -> Shaped sh a sfromVector sh v = sfromPrimitive (sfromVectorP sh v) +{-# INLINEABLE stoVectorP #-} stoVectorP :: Storable a => Shaped sh (Primitive a) -> VS.Vector a stoVectorP = coerce mtoVectorP +{-# INLINEABLE stoVector #-} stoVector :: PrimElt a => Shaped sh a -> VS.Vector a stoVector = coerce mtoVector @@ -261,6 +269,7 @@ sminIndexPrim (Shaped arr) = ixsFromIxX (mminIndexPrim arr) smaxIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh smaxIndexPrim (Shaped arr) = ixsFromIxX (mmaxIndexPrim arr) +{-# INLINEABLE sdot1Inner #-} sdot1Inner :: forall sh n a. (PrimElt a, NumElt a) => Proxy n -> Shaped (sh ++ '[n]) a -> Shaped (sh ++ '[n]) a -> Shaped sh a sdot1Inner Proxy sarr1@(Shaped arr1) (Shaped arr2) @@ -272,6 +281,7 @@ sdot1Inner Proxy sarr1@(Shaped arr1) (Shaped arr2) -> Shaped (mdot1Inner (Proxy @(Just n)) arr1 arr2) _ -> error "unreachable" +{-# INLINE sdot #-} -- | This has a temporary, suboptimal implementation in terms of 'mflatten'. -- Prefer 'sdot1Inner' if applicable. sdot :: (PrimElt a, NumElt a) => Shaped sh a -> Shaped sh a -> a diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs index 8ef61dd..4b119c4 100644 --- a/src/Data/Array/Nested/Shaped/Base.hs +++ b/src/Data/Array/Nested/Shaped/Base.hs @@ -79,9 +79,12 @@ deriving instance Eq (Mixed sh (Mixed (MapJust sh') a)) => Eq (Mixed sh (Shaped newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh') a)) instance Elt a => Elt (Shaped sh a) where + {-# INLINE mshape #-} mshape (M_Shaped arr) = mshape arr + {-# INLINE mindex #-} mindex (M_Shaped arr) i = Shaped (mindex arr i) + {-# INLINE mindexPartial #-} mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) mindexPartial (M_Shaped arr) i = coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ @@ -257,5 +260,6 @@ satan2Array :: (FloatElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped s satan2Array = liftShaped2 matan2Array +{-# INLINE sshape #-} sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh sshape (Shaped arr) = coerce (mshape arr) diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs index e8039f6..3f23478 100644 --- a/src/Data/Array/XArray.hs +++ b/src/Data/Array/XArray.hs @@ -62,6 +62,7 @@ shape = \ssh (XArray arr) -> go ssh (S.shapeL arr) go (n :!% ssh) (i : l) = fromSMayNat (\_ -> SUnknown i) SKnown n :$% go ssh l go _ _ = error "Invalid shapeL" +{-# INLINEABLE fromVector #-} fromVector :: forall sh a. Storable a => IShX sh -> VS.Vector a -> XArray sh a fromVector sh v | Dict <- lemKnownNatRank sh |
