aboutsummaryrefslogtreecommitdiff
path: root/src/Data
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data')
-rw-r--r--src/Data/Array/Nested/Mixed.hs19
-rw-r--r--src/Data/Array/Nested/Mixed/Shape.hs9
-rw-r--r--src/Data/Array/Nested/Ranked.hs10
-rw-r--r--src/Data/Array/Nested/Ranked/Base.hs4
-rw-r--r--src/Data/Array/Nested/Shaped.hs10
-rw-r--r--src/Data/Array/Nested/Shaped/Base.hs4
-rw-r--r--src/Data/Array/XArray.hs1
7 files changed, 56 insertions, 1 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs
index 2b5c5b6..ffbc993 100644
--- a/src/Data/Array/Nested/Mixed.hs
+++ b/src/Data/Array/Nested/Mixed.hs
@@ -408,6 +408,9 @@ class Elt a => KnownElt a where
-- Arrays of scalars are basically just arrays of scalars.
instance Storable a => Elt (Primitive a) where
+ -- Somehow, INLINE here can increase allocation with GHC 9.14.1.
+ -- Maybe that happens in void instances such as @Primitive ()@.
+ {-# INLINEABLE mshape #-}
mshape (M_Primitive sh _) = sh
{-# INLINEABLE mindex #-}
mindex (M_Primitive _ a) i = Primitive (X.index a i)
@@ -523,8 +526,11 @@ deriving via Primitive () instance KnownElt ()
-- Arrays of pairs are pairs of arrays.
instance (Elt a, Elt b) => Elt (a, b) where
+ {-# INLINEABLE mshape #-}
mshape (M_Tup2 a _) = mshape a
+ {-# INLINEABLE mindex #-}
mindex (M_Tup2 a b) i = (mindex a i, mindex b i)
+ {-# INLINEABLE mindexPartial #-}
mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i)
mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y)
mfromListOuterSN sn l =
@@ -580,13 +586,16 @@ instance Elt a => Elt (Mixed sh' a) where
-- TODO: this is quadratic in the nesting depth because it repeatedly
-- truncates the shape vector to one a little shorter. Fix with a
-- moverlongShape method, a prefix of which is mshape.
+ {-# INLINEABLE mshape #-}
mshape :: forall sh. Mixed sh (Mixed sh' a) -> IShX sh
mshape (M_Nest sh arr)
= fst (shxSplitApp (Proxy @sh') (ssxFromShX sh) (mshape arr))
+ {-# INLINEABLE mindex #-}
mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a
mindex (M_Nest _ arr) = mindexPartial arr
+ {-# INLINEABLE mindexPartial #-}
mindexPartial :: forall sh1 sh2.
Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)
mindexPartial (M_Nest sh arr) i
@@ -804,19 +813,23 @@ mgeneratePrim sh f =
let g i = f (ixxFromLinear sh i)
in mfromVector sh $ VS.generate (shxSize sh) g
+{-# INLINEABLE msumOuter1PrimP #-}
msumOuter1PrimP :: forall sh n a. (Storable a, NumElt a)
=> Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a)
msumOuter1PrimP (M_Primitive (n :$% sh) arr) =
let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX
in M_Primitive sh (X.sumOuter nssh (ssxFromShX sh) arr)
+{-# INLINEABLE msumOuter1Prim #-}
msumOuter1Prim :: forall sh n a. (NumElt a, PrimElt a)
=> Mixed (n : sh) a -> Mixed sh a
msumOuter1Prim = fromPrimitive . msumOuter1PrimP @sh @n @a . toPrimitive
+{-# INLINEABLE msumAllPrimP #-}
msumAllPrimP :: (Storable a, NumElt a) => Mixed sh (Primitive a) -> a
msumAllPrimP (M_Primitive sh arr) = X.sumFull (ssxFromShX sh) arr
+{-# INLINEABLE msumAllPrim #-}
msumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a
msumAllPrim arr = msumAllPrimP (toPrimitive arr)
@@ -837,15 +850,19 @@ mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2
=> StaticShX sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (AddMaybe n m : sh ++ sh') b
f ssh' = X.append (ssxAppend ssh ssh')
+{-# INLINEABLE mfromVectorP #-}
mfromVectorP :: forall sh a. Storable a => IShX sh -> VS.Vector a -> Mixed sh (Primitive a)
mfromVectorP sh v = M_Primitive sh (X.fromVector sh v)
+{-# INLINEABLE mfromVector #-}
mfromVector :: forall sh a. PrimElt a => IShX sh -> VS.Vector a -> Mixed sh a
mfromVector sh v = fromPrimitive (mfromVectorP sh v)
+{-# INLINEABLE mtoVectorP #-}
mtoVectorP :: Storable a => Mixed sh (Primitive a) -> VS.Vector a
mtoVectorP (M_Primitive _ v) = X.toVector v
+{-# INLINEABLE mtoVector #-}
mtoVector :: PrimElt a => Mixed sh a -> VS.Vector a
mtoVector arr = mtoVectorP (toPrimitive arr)
@@ -1044,6 +1061,7 @@ mmaxIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh
mmaxIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) =
ixxFromList (ssxFromShX sh) (numEltMaxIndex (shxRank sh) (fromO arr))
+{-# INLINEABLE mdot1Inner #-}
mdot1Inner :: forall sh n a. (PrimElt a, NumElt a)
=> Proxy n -> Mixed (sh ++ '[n]) a -> Mixed (sh ++ '[n]) a -> Mixed sh a
mdot1Inner _ (toPrimitive -> M_Primitive sh1 (XArray a)) (toPrimitive -> M_Primitive sh2 (XArray b))
@@ -1059,6 +1077,7 @@ mdot1Inner _ (toPrimitive -> M_Primitive sh1 (XArray a)) (toPrimitive -> M_Primi
-- | This has a temporary, suboptimal implementation in terms of 'mflatten'.
-- Prefer 'mdot1Inner' if applicable.
+{-# INLINEABLE mdot #-}
mdot :: (PrimElt a, NumElt a) => Mixed sh a -> Mixed sh a -> a
mdot a b =
munScalar $
diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs
index de1c770..b3f0c2f 100644
--- a/src/Data/Array/Nested/Mixed/Shape.hs
+++ b/src/Data/Array/Nested/Mixed/Shape.hs
@@ -355,9 +355,16 @@ data ListH sh i where
-- TODO: bring this UNPACK back when GHC no longer crashes:
-- ConsKnown :: forall n sh i. {-# UNPACK #-} SNat n -> ListH sh i -> ListH (Just n : sh) i
ConsKnown :: forall n sh i. SNat n -> ListH sh i -> ListH (Just n : sh) i
-deriving instance Eq i => Eq (ListH sh i)
deriving instance Ord i => Ord (ListH sh i)
+-- A manually defined instance and this INLINEABLE is needed to specialize
+-- mdot1Inner (otherwise GHC warns specialization breaks down here).
+instance Eq i => Eq (ListH sh i) where
+ {-# INLINEABLE (==) #-}
+ ZH == ZH = True
+ ConsUnknown i1 sh1 == ConsUnknown i2 sh2 = i1 == i2 && sh1 == sh2
+ ConsKnown _ sh1 == ConsKnown _ sh2 = sh1 == sh2
+
#ifdef OXAR_DEFAULT_SHOW_INSTANCES
deriving instance Show i => Show (ListH sh i)
#else
diff --git a/src/Data/Array/Nested/Ranked.hs b/src/Data/Array/Nested/Ranked.hs
index 8faff6d..b448685 100644
--- a/src/Data/Array/Nested/Ranked.hs
+++ b/src/Data/Array/Nested/Ranked.hs
@@ -94,6 +94,7 @@ rlift2 :: forall n1 n2 n3 a. Elt a
-> Ranked n1 a -> Ranked n2 a -> Ranked n3 a
rlift2 sn3 f (Ranked arr1) (Ranked arr2) = Ranked (mlift2 (ssxFromSNat sn3) f arr1 arr2)
+{-# INLINE rsumOuter1PrimP #-}
rsumOuter1PrimP :: forall n a.
(Storable a, NumElt a)
=> Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a)
@@ -101,13 +102,16 @@ rsumOuter1PrimP (Ranked arr)
| Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
= Ranked (msumOuter1PrimP arr)
+{-# INLINEABLE rsumOuter1Prim #-}
rsumOuter1Prim :: forall n a. (NumElt a, PrimElt a)
=> Ranked (n + 1) a -> Ranked n a
rsumOuter1Prim = rfromPrimitive . rsumOuter1PrimP . rtoPrimitive
+{-# INLINE rsumAllPrimP #-}
rsumAllPrimP :: (Storable a, NumElt a) => Ranked n (Primitive a) -> a
rsumAllPrimP (Ranked arr) = msumAllPrimP arr
+{-# INLINE rsumAllPrim #-}
rsumAllPrim :: (PrimElt a, NumElt a) => Ranked n a -> a
rsumAllPrim (Ranked arr) = msumAllPrim arr
@@ -139,17 +143,21 @@ rappend arr1 arr2
rscalar :: Elt a => a -> Ranked 0 a
rscalar x = Ranked (mscalar x)
+{-# INLINEABLE rfromVectorP #-}
rfromVectorP :: forall n a. Storable a => IShR n -> VS.Vector a -> Ranked n (Primitive a)
rfromVectorP sh v
| Dict <- lemKnownReplicate (shrRank sh)
= Ranked (mfromVectorP (shxFromShR sh) v)
+{-# INLINEABLE rfromVector #-}
rfromVector :: forall n a. PrimElt a => IShR n -> VS.Vector a -> Ranked n a
rfromVector sh v = rfromPrimitive (rfromVectorP sh v)
+{-# INLINEABLE rtoVectorP #-}
rtoVectorP :: Storable a => Ranked n (Primitive a) -> VS.Vector a
rtoVectorP = coerce mtoVectorP
+{-# INLINEABLE rtoVector #-}
rtoVector :: PrimElt a => Ranked n a -> VS.Vector a
rtoVector = coerce mtoVector
@@ -335,6 +343,7 @@ rmaxIndexPrim rarr@(Ranked arr)
| Refl <- lemRankReplicate (rrank (rtoPrimitive rarr))
= ixrFromIxX (mmaxIndexPrim arr)
+{-# INLINEABLE rdot1Inner #-}
rdot1Inner :: forall n a. (PrimElt a, NumElt a) => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked n a
rdot1Inner arr1 arr2
| SNat <- rrank arr1
@@ -343,6 +352,7 @@ rdot1Inner arr1 arr2
-- | This has a temporary, suboptimal implementation in terms of 'mflatten'.
-- Prefer 'rdot1Inner' if applicable.
+{-# INLINE rdot #-}
rdot :: (PrimElt a, NumElt a) => Ranked n a -> Ranked n a -> a
rdot = coerce mdot
diff --git a/src/Data/Array/Nested/Ranked/Base.hs b/src/Data/Array/Nested/Ranked/Base.hs
index 9d88815..beedbcf 100644
--- a/src/Data/Array/Nested/Ranked/Base.hs
+++ b/src/Data/Array/Nested/Ranked/Base.hs
@@ -82,9 +82,12 @@ newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed
-- these instances allow them to also be used as elements of arrays, thus
-- making them first-class in the API.
instance Elt a => Elt (Ranked n a) where
+ {-# INLINE mshape #-}
mshape (M_Ranked arr) = mshape arr
+ {-# INLINE mindex #-}
mindex (M_Ranked arr) i = Ranked (mindex arr i)
+ {-# INLINE mindexPartial #-}
mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IIxX sh -> Mixed sh' (Ranked n a)
mindexPartial (M_Ranked arr) i =
coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $
@@ -260,6 +263,7 @@ ratan2Array :: (FloatElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a
ratan2Array = liftRanked2 matan2Array
+{-# INLINE rshape #-}
rshape :: Elt a => Ranked n a -> IShR n
rshape (Ranked arr) = coerce (mshape arr)
diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs
index 5c52220..36ef24a 100644
--- a/src/Data/Array/Nested/Shaped.hs
+++ b/src/Data/Array/Nested/Shaped.hs
@@ -96,17 +96,21 @@ slift2 :: forall sh1 sh2 sh3 a. Elt a
-> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a
slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (ssxFromShX (shxFromShS sh3)) f arr1 arr2)
+{-# INLINE ssumOuter1PrimP #-}
ssumOuter1PrimP :: forall sh n a. (Storable a, NumElt a)
=> Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a)
ssumOuter1PrimP (Shaped arr) = Shaped (msumOuter1PrimP arr)
+{-# INLINEABLE ssumOuter1Prim #-}
ssumOuter1Prim :: forall sh n a. (NumElt a, PrimElt a)
=> Shaped (n : sh) a -> Shaped sh a
ssumOuter1Prim = sfromPrimitive . ssumOuter1PrimP . stoPrimitive
+{-# INLINE ssumAllPrimP #-}
ssumAllPrimP :: (PrimElt a, NumElt a) => Shaped n (Primitive a) -> a
ssumAllPrimP (Shaped arr) = msumAllPrimP arr
+{-# INLINE ssumAllPrim #-}
ssumAllPrim :: (PrimElt a, NumElt a) => Shaped n a -> a
ssumAllPrim (Shaped arr) = msumAllPrim arr
@@ -126,15 +130,19 @@ sappend = coerce mappend
sscalar :: Elt a => a -> Shaped '[] a
sscalar x = Shaped (mscalar x)
+{-# INLINEABLE sfromVectorP #-}
sfromVectorP :: Storable a => ShS sh -> VS.Vector a -> Shaped sh (Primitive a)
sfromVectorP sh v = Shaped (mfromVectorP (shxFromShS sh) v)
+{-# INLINEABLE sfromVector #-}
sfromVector :: PrimElt a => ShS sh -> VS.Vector a -> Shaped sh a
sfromVector sh v = sfromPrimitive (sfromVectorP sh v)
+{-# INLINEABLE stoVectorP #-}
stoVectorP :: Storable a => Shaped sh (Primitive a) -> VS.Vector a
stoVectorP = coerce mtoVectorP
+{-# INLINEABLE stoVector #-}
stoVector :: PrimElt a => Shaped sh a -> VS.Vector a
stoVector = coerce mtoVector
@@ -261,6 +269,7 @@ sminIndexPrim (Shaped arr) = ixsFromIxX (mminIndexPrim arr)
smaxIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh
smaxIndexPrim (Shaped arr) = ixsFromIxX (mmaxIndexPrim arr)
+{-# INLINEABLE sdot1Inner #-}
sdot1Inner :: forall sh n a. (PrimElt a, NumElt a)
=> Proxy n -> Shaped (sh ++ '[n]) a -> Shaped (sh ++ '[n]) a -> Shaped sh a
sdot1Inner Proxy sarr1@(Shaped arr1) (Shaped arr2)
@@ -272,6 +281,7 @@ sdot1Inner Proxy sarr1@(Shaped arr1) (Shaped arr2)
-> Shaped (mdot1Inner (Proxy @(Just n)) arr1 arr2)
_ -> error "unreachable"
+{-# INLINE sdot #-}
-- | This has a temporary, suboptimal implementation in terms of 'mflatten'.
-- Prefer 'sdot1Inner' if applicable.
sdot :: (PrimElt a, NumElt a) => Shaped sh a -> Shaped sh a -> a
diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs
index 8ef61dd..4b119c4 100644
--- a/src/Data/Array/Nested/Shaped/Base.hs
+++ b/src/Data/Array/Nested/Shaped/Base.hs
@@ -79,9 +79,12 @@ deriving instance Eq (Mixed sh (Mixed (MapJust sh') a)) => Eq (Mixed sh (Shaped
newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh') a))
instance Elt a => Elt (Shaped sh a) where
+ {-# INLINE mshape #-}
mshape (M_Shaped arr) = mshape arr
+ {-# INLINE mindex #-}
mindex (M_Shaped arr) i = Shaped (mindex arr i)
+ {-# INLINE mindexPartial #-}
mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IIxX sh1 -> Mixed sh2 (Shaped sh a)
mindexPartial (M_Shaped arr) i =
coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
@@ -257,5 +260,6 @@ satan2Array :: (FloatElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped s
satan2Array = liftShaped2 matan2Array
+{-# INLINE sshape #-}
sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh
sshape (Shaped arr) = coerce (mshape arr)
diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs
index e8039f6..3f23478 100644
--- a/src/Data/Array/XArray.hs
+++ b/src/Data/Array/XArray.hs
@@ -62,6 +62,7 @@ shape = \ssh (XArray arr) -> go ssh (S.shapeL arr)
go (n :!% ssh) (i : l) = fromSMayNat (\_ -> SUnknown i) SKnown n :$% go ssh l
go _ _ = error "Invalid shapeL"
+{-# INLINEABLE fromVector #-}
fromVector :: forall sh a. Storable a => IShX sh -> VS.Vector a -> XArray sh a
fromVector sh v
| Dict <- lemKnownNatRank sh