diff options
| author | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2025-11-27 14:45:17 +0100 |
|---|---|---|
| committer | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2025-11-30 20:51:03 +0100 |
| commit | a06c6416bab1639e5c3bd99b3c10de4dcf6c32f9 (patch) | |
| tree | c6a3806bc6b04b6efca5836bc173de01d9d41cfa | |
| parent | 4d762c306901e694286363ad0846d69a770acd63 (diff) | |
Inline all higher order shape functions
| -rw-r--r-- | src/Data/Array/Nested/Mixed/Shape.hs | 25 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Ranked/Shape.hs | 23 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Shaped/Shape.hs | 21 |
3 files changed, 55 insertions, 14 deletions
diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs index 23c3abf..a9ed2d0 100644 --- a/src/Data/Array/Nested/Mixed/Shape.hs +++ b/src/Data/Array/Nested/Mixed/Shape.hs @@ -105,21 +105,24 @@ listxEqual (n ::% sh) (m ::% sh') = Just Refl listxEqual _ _ = Nothing +{-# INLINE listxFmap #-} listxFmap :: (forall n. f n -> g n) -> ListX sh f -> ListX sh g listxFmap _ ZX = ZX listxFmap f (x ::% xs) = f x ::% listxFmap f xs -listxFold :: Monoid m => (forall n. f n -> m) -> ListX sh f -> m -listxFold _ ZX = mempty -listxFold f (x ::% xs) = f x <> listxFold f xs +{-# INLINE listxFoldMap #-} +listxFoldMap :: Monoid m => (forall n. f n -> m) -> ListX sh f -> m +listxFoldMap _ ZX = mempty +listxFoldMap f (x ::% xs) = f x <> listxFoldMap f xs listxLength :: ListX sh f -> Int -listxLength = getSum . listxFold (\_ -> Sum 1) +listxLength = getSum . listxFoldMap (\_ -> Sum 1) listxRank :: ListX sh f -> SNat (Rank sh) listxRank ZX = SNat listxRank (_ ::% l) | SNat <- listxRank l = SNat +{-# INLINE listxShow #-} listxShow :: forall sh f. (forall n. f n -> ShowS) -> ListX sh f -> ShowS listxShow f l = showString "[" . go "" l . showString "]" where @@ -167,6 +170,7 @@ listxZip :: ListX sh f -> ListX sh g -> ListX sh (Product f g) listxZip ZX ZX = ZX listxZip (i ::% irest) (j ::% jrest) = Pair i j ::% listxZip irest jrest +{-# INLINE listxZipWith #-} listxZipWith :: (forall a. f a -> g a -> h a) -> ListX sh f -> ListX sh g -> ListX sh h listxZipWith _ ZX ZX = ZX @@ -206,10 +210,17 @@ instance Show i => Show (IxX sh i) where #endif instance Functor (IxX sh) where + {-# INLINE fmap #-} fmap f (IxX l) = IxX (listxFmap (Const . f . getConst) l) instance Foldable (IxX sh) where - foldMap f (IxX l) = listxFold (f . getConst) l + {-# INLINE foldMap #-} + foldMap f (IxX l) = listxFoldMap (f . getConst) l + {-# INLINE foldr #-} + foldr _ z ZIX = z + foldr f z (x :.% xs) = f x (foldr f z xs) + null ZIX = False + null _ = True instance NFData i => NFData (IxX sh i) @@ -257,6 +268,7 @@ ixxZip :: IxX sh i -> IxX sh j -> IxX sh (i, j) ixxZip ZIX ZIX = ZIX ixxZip (i :.% is) (j :.% js) = (i, j) :.% ixxZip is js +{-# INLINE ixxZipWith #-} ixxZipWith :: (i -> j -> k) -> IxX sh i -> IxX sh j -> IxX sh k ixxZipWith _ ZIX ZIX = ZIX ixxZipWith f (i :.% is) (j :.% js) = f i j :.% ixxZipWith f is js @@ -317,6 +329,7 @@ instance TestEquality f => TestEquality (SMayNat i f) where testEquality (SKnown n) (SKnown m) | Just Refl <- testEquality n m = Just Refl testEquality _ _ = Nothing +{-# INLINE fromSMayNat #-} fromSMayNat :: (n ~ Nothing => i -> r) -> (forall m. n ~ Just m => f m -> r) -> SMayNat i f n -> r @@ -366,6 +379,7 @@ instance Show i => Show (ShX sh i) where #endif instance Functor (ShX sh) where + {-# INLINE fmap #-} fmap f (ShX l) = ShX (listxFmap (fromSMayNat (SUnknown . f) SKnown) l) instance NFData i => NFData (ShX sh i) where @@ -472,6 +486,7 @@ shxTakeSSX :: forall sh sh' i proxy. proxy sh' -> StaticShX sh -> ShX (sh ++ sh' shxTakeSSX _ ZKX _ = ZSX shxTakeSSX p (_ :!% ssh1) (n :$% sh) = n :$% shxTakeSSX p ssh1 sh +{-# INLINE shxZipWith #-} shxZipWith :: (forall n. SMayNat i SNat n -> SMayNat j SNat n -> SMayNat k SNat n) -> ShX sh i -> ShX sh j -> ShX sh k shxZipWith _ ZSX ZSX = ZSX diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs index 50338d2..989d7d1 100644 --- a/src/Data/Array/Nested/Ranked/Shape.hs +++ b/src/Data/Array/Nested/Ranked/Shape.hs @@ -1,8 +1,6 @@ {-# LANGUAGE BangPatterns #-} {-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveFoldable #-} -{-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE FlexibleInstances #-} @@ -56,8 +54,6 @@ data ListR n i where (:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i deriving instance Eq i => Eq (ListR n i) deriving instance Ord i => Ord (ListR n i) -deriving instance Functor (ListR n) -deriving instance Foldable (ListR n) infixr 3 ::: #ifdef OXAR_DEFAULT_SHOW_INSTANCES @@ -71,6 +67,21 @@ instance NFData i => NFData (ListR n i) where rnf ZR = () rnf (x ::: l) = rnf x `seq` rnf l +instance Functor (ListR n) where + {-# INLINE fmap #-} + fmap _ ZR = ZR + fmap f (x ::: xs) = f x ::: fmap f xs + +instance Foldable (ListR n) where + {-# INLINE foldMap #-} + foldMap _ ZR = mempty + foldMap f (x ::: xs) = f x <> foldMap f xs + {-# INLINE foldr #-} + foldr _ z ZR = z + foldr f z (x ::: xs) = f x (foldr f z xs) + null ZR = False + null _ = True + data UnconsListRRes i n1 = forall n. (n + 1 ~ n1) => UnconsListRRes (ListR n i) i listrUncons :: ListR n1 i -> Maybe (UnconsListRRes i n1) @@ -95,6 +106,7 @@ listrEqual (i ::: sh) (j ::: sh') = Just Refl listrEqual _ _ = Nothing +{-# INLINE listrShow #-} listrShow :: forall n i. (i -> ShowS) -> ListR n i -> ShowS listrShow f l = showString "[" . go "" l . showString "]" where @@ -145,6 +157,7 @@ listrZip ZR ZR = ZR listrZip (i ::: irest) (j ::: jrest) = (i, j) ::: listrZip irest jrest listrZip _ _ = error "listrZip: impossible pattern needlessly required" +{-# INLINE listrZipWith #-} listrZipWith :: (i -> j -> k) -> ListR n i -> ListR n j -> ListR n k listrZipWith _ ZR ZR = ZR listrZipWith f (i ::: irest) (j ::: jrest) = @@ -244,6 +257,7 @@ ixrAppend = coerce (listrAppend @_ @i) ixrZip :: IxR n i -> IxR n j -> IxR n (i, j) ixrZip (IxR l1) (IxR l2) = IxR $ listrZip l1 l2 +{-# INLINE ixrZipWith #-} ixrZipWith :: (i -> j -> k) -> IxR n i -> IxR n j -> IxR n k ixrZipWith f (IxR l1) (IxR l2) = IxR $ listrZipWith f l1 l2 @@ -328,6 +342,7 @@ shrAppend = coerce (listrAppend @_ @i) shrZip :: ShR n i -> ShR n j -> ShR n (i, j) shrZip (ShR l1) (ShR l2) = ShR $ listrZip l1 l2 +{-# INLINE shrZipWith #-} shrZipWith :: (i -> j -> k) -> ShR n i -> ShR n j -> ShR n k shrZipWith f (ShR l1) (ShR l2) = ShR $ listrZipWith f l1 l2 diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs index 218caaa..c1e687a 100644 --- a/src/Data/Array/Nested/Shaped/Shape.hs +++ b/src/Data/Array/Nested/Shaped/Shape.hs @@ -102,13 +102,15 @@ listsEqual (n ::$ sh) (m ::$ sh') = Just Refl listsEqual _ _ = Nothing +{-# INLINE listsFmap #-} listsFmap :: (forall n. f n -> g n) -> ListS sh f -> ListS sh g listsFmap _ ZS = ZS listsFmap f (x ::$ xs) = f x ::$ listsFmap f xs -listsFold :: Monoid m => (forall n. f n -> m) -> ListS sh f -> m -listsFold _ ZS = mempty -listsFold f (x ::$ xs) = f x <> listsFold f xs +{-# INLINE listsFoldMap #-} +listsFoldMap :: Monoid m => (forall n. f n -> m) -> ListS sh f -> m +listsFoldMap _ ZS = mempty +listsFoldMap f (x ::$ xs) = f x <> listsFoldMap f xs listsShow :: forall sh f. (forall n. f n -> ShowS) -> ListS sh f -> ShowS listsShow f l = showString "[" . go "" l . showString "]" @@ -118,7 +120,7 @@ listsShow f l = showString "[" . go "" l . showString "]" go prefix (x ::$ xs) = showString prefix . f x . go "," xs listsLength :: ListS sh f -> Int -listsLength = getSum . listsFold (\_ -> Sum 1) +listsLength = getSum . listsFoldMap (\_ -> Sum 1) listsRank :: ListS sh f -> SNat (Rank sh) listsRank ZS = SNat @@ -150,6 +152,7 @@ listsZip :: ListS sh f -> ListS sh g -> ListS sh (Fun.Product f g) listsZip ZS ZS = ZS listsZip (i ::$ is) (j ::$ js) = Fun.Pair i j ::$ listsZip is js +{-# INLINE listsZipWith #-} listsZipWith :: (forall a. f a -> g a -> h a) -> ListS sh f -> ListS sh g -> ListS sh h listsZipWith _ ZS ZS = ZS @@ -217,10 +220,17 @@ instance Show i => Show (IxS sh i) where #endif instance Functor (IxS sh) where + {-# INLINE fmap #-} fmap f (IxS l) = IxS (listsFmap (Const . f . getConst) l) instance Foldable (IxS sh) where - foldMap f (IxS l) = listsFold (f . getConst) l + {-# INLINE foldMap #-} + foldMap f (IxS l) = listsFoldMap (f . getConst) l + {-# INLINE foldr #-} + foldr _ z ZIS = z + foldr f z (x :.$ xs) = f x (foldr f z xs) + null ZIS = False + null _ = True instance NFData i => NFData (IxS sh i) @@ -259,6 +269,7 @@ ixsZip :: IxS sh i -> IxS sh j -> IxS sh (i, j) ixsZip ZIS ZIS = ZIS ixsZip (i :.$ is) (j :.$ js) = (i, j) :.$ ixsZip is js +{-# INLINE ixsZipWith #-} ixsZipWith :: (i -> j -> k) -> IxS sh i -> IxS sh j -> IxS sh k ixsZipWith _ ZIS ZIS = ZIS ixsZipWith f (i :.$ is) (j :.$ js) = f i j :.$ ixsZipWith f is js |
