From 21d3d6190bbdaf6ca626ac550dcee26e02318442 Mon Sep 17 00:00:00 2001 From: Mikolaj Konarski Date: Mon, 15 Dec 2025 10:47:11 +0100 Subject: Optimize the representation of ListH --- src/Data/Array/Nested/Mixed/Shape.hs | 89 ++++++++++++++++++++++++------------ src/Data/Array/Nested/Permutation.hs | 20 +++++--- 2 files changed, 73 insertions(+), 36 deletions(-) (limited to 'src/Data') diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs index c8c9a7b..dc4063c 100644 --- a/src/Data/Array/Nested/Mixed/Shape.hs +++ b/src/Data/Array/Nested/Mixed/Shape.hs @@ -301,7 +301,7 @@ ixxToLinear = \sh i -> go sh i 0 data SMayNat i n where SUnknown :: i -> SMayNat i Nothing - SKnown :: {-# UNPACK #-} SNat n -> SMayNat i (Just n) + SKnown :: SNat n -> SMayNat i (Just n) deriving instance Show i => Show (SMayNat i n) deriving instance Eq i => Eq (SMayNat i n) deriving instance Ord i => Ord (SMayNat i n) @@ -340,10 +340,10 @@ type role ListH nominal representational type ListH :: [Maybe Nat] -> Type -> Type data ListH sh i where ZH :: ListH '[] i - (::#) :: forall n sh i. SMayNat i n -> ListH sh i -> ListH (n : sh) i + ConsUnknown :: forall sh i. i -> ListH sh i -> ListH (Nothing : sh) i + ConsKnown :: forall n sh i. {-# UNPACK #-} SNat n -> ListH sh i -> ListH (Just n : sh) i deriving instance Eq i => Eq (ListH sh i) deriving instance Ord i => Ord (ListH sh i) -infixr 3 ::# #ifdef OXAR_DEFAULT_SHOW_INSTANCES deriving instance Show i => Show (ListH sh i) @@ -352,14 +352,16 @@ instance Show i => Show (ListH sh i) where showsPrec _ = listhShow shows #endif -instance (forall n. NFData (SMayNat i n)) => NFData (ListH sh i) where +instance NFData i => NFData (ListH sh i) where rnf ZH = () - rnf (x ::# l) = rnf x `seq` rnf l + rnf (x `ConsUnknown` l) = rnf x `seq` rnf l + rnf (SNat `ConsKnown` l) = rnf l data UnconsListHRes i sh1 = forall n sh. (n : sh ~ sh1) => UnconsListHRes (ListH sh i) (SMayNat i n) listhUncons :: ListH sh1 i -> Maybe (UnconsListHRes i sh1) -listhUncons (i ::# shl') = Just (UnconsListHRes shl' i) +listhUncons (i `ConsUnknown` shl') = Just (UnconsListHRes shl' (SUnknown i)) +listhUncons (i `ConsKnown` shl') = Just (UnconsListHRes shl' (SKnown i)) listhUncons ZH = Nothing -- | This checks only whether the types are equal; if the elements of the list @@ -367,7 +369,10 @@ listhUncons ZH = Nothing -- 'testEquality', except on the penultimate type parameter. listhEqType :: ListH sh i -> ListH sh' i -> Maybe (sh :~: sh') listhEqType ZH ZH = Just Refl -listhEqType (n ::# sh) (m ::# sh') +listhEqType (_ `ConsUnknown` sh) (_ `ConsUnknown` sh') + | Just Refl <- listhEqType sh sh' + = Just Refl +listhEqType (n `ConsKnown` sh) (m `ConsKnown` sh') | Just Refl <- testEquality n m , Just Refl <- listhEqType sh sh' = Just Refl @@ -378,9 +383,12 @@ listhEqType _ _ = Nothing -- in the @some@ package (except on the penultimate type parameter). listhEqual :: Eq i => ListH sh i -> ListH sh' i -> Maybe (sh :~: sh') listhEqual ZH ZH = Just Refl -listhEqual (n ::# sh) (m ::# sh') +listhEqual (n `ConsUnknown` sh) (m `ConsUnknown` sh') + | n == m + , Just Refl <- listhEqual sh sh' + = Just Refl +listhEqual (n `ConsKnown` sh) (m `ConsKnown` sh') | Just Refl <- testEquality n m - , n == m , Just Refl <- listhEqual sh sh' = Just Refl listhEqual _ _ = Nothing @@ -388,19 +396,24 @@ listhEqual _ _ = Nothing {-# INLINE listhFmap #-} listhFmap :: (forall n. SMayNat i n -> SMayNat j n) -> ListH sh i -> ListH sh j listhFmap _ ZH = ZH -listhFmap f (x ::# xs) = f x ::# listhFmap f xs +listhFmap f (x `ConsUnknown` xs) = case f (SUnknown x) of + SUnknown y -> y `ConsUnknown` listhFmap f xs +listhFmap f (x `ConsKnown` xs) = case f (SKnown x) of + SKnown y -> y `ConsKnown` listhFmap f xs {-# INLINE listhFoldMap #-} listhFoldMap :: Monoid m => (forall n. SMayNat i n -> m) -> ListH sh i -> m listhFoldMap _ ZH = mempty -listhFoldMap f (x ::# xs) = f x <> listhFoldMap f xs +listhFoldMap f (x `ConsUnknown` xs) = f (SUnknown x) <> listhFoldMap f xs +listhFoldMap f (x `ConsKnown` xs) = f (SKnown x) <> listhFoldMap f xs listhLength :: ListH sh i -> Int listhLength = getSum . listhFoldMap (\_ -> Sum 1) listhRank :: ListH sh i -> SNat (Rank sh) listhRank ZH = SNat -listhRank (_ ::# l) | SNat <- listhRank l = SNat +listhRank (_ `ConsUnknown` l) | SNat <- listhRank l = SNat +listhRank (_ `ConsKnown` l) | SNat <- listhRank l = SNat {-# INLINE listhShow #-} listhShow :: forall sh i. (forall n. SMayNat i n -> ShowS) -> ListH sh i -> ShowS @@ -408,29 +421,44 @@ listhShow f l = showString "[" . go "" l . showString "]" where go :: String -> ListH sh' i -> ShowS go _ ZH = id - go prefix (x ::# xs) = showString prefix . f x . go "," xs + go prefix (x `ConsUnknown` xs) = showString prefix . f (SUnknown x) . go "," xs + go prefix (x `ConsKnown` xs) = showString prefix . f (SKnown x) . go "," xs listhHead :: ListH (mn ': sh) i -> SMayNat i mn -listhHead (i ::# _) = i +listhHead (i `ConsUnknown` _) = SUnknown i +listhHead (i `ConsKnown` _) = SKnown i listhTail :: ListH (n : sh) i -> ListH sh i -listhTail (_ ::# sh) = sh +listhTail (_ `ConsUnknown` sh) = sh +listhTail (_ `ConsKnown` sh) = sh listhAppend :: ListH sh i -> ListH sh' i -> ListH (sh ++ sh') i listhAppend ZH idx' = idx' -listhAppend (i ::# idx) idx' = i ::# listhAppend idx idx' +listhAppend (i `ConsUnknown` idx) idx' = i `ConsUnknown` listhAppend idx idx' +listhAppend (i `ConsKnown` idx) idx' = i `ConsKnown` listhAppend idx idx' listhDrop :: forall i j sh sh'. ListH sh j -> ListH (sh ++ sh') i -> ListH sh' i listhDrop ZH long = long -listhDrop (_ ::# short) long = case long of _ ::# long' -> listhDrop short long' +listhDrop (_ `ConsUnknown` short) long = case long of + _ `ConsUnknown` long' -> listhDrop short long' +listhDrop (_ `ConsKnown` short) long = case long of + _ `ConsKnown` long' -> listhDrop short long' listhInit :: forall i n sh. ListH (n : sh) i -> ListH (Init (n : sh)) i -listhInit (i ::# sh@(_ ::# _)) = i ::# listhInit sh -listhInit (_ ::# ZH) = ZH +listhInit (i `ConsUnknown` sh@(_ `ConsUnknown` _)) = i `ConsUnknown` listhInit sh +listhInit (i `ConsUnknown` sh@(_ `ConsKnown` _)) = i `ConsUnknown` listhInit sh +listhInit (_ `ConsUnknown` ZH) = ZH +listhInit (i `ConsKnown` sh@(_ `ConsUnknown` _)) = i `ConsKnown` listhInit sh +listhInit (i `ConsKnown` sh@(_ `ConsKnown` _)) = i `ConsKnown` listhInit sh +listhInit (_ `ConsKnown` ZH) = ZH listhLast :: forall i n sh. ListH (n : sh) i -> SMayNat i (Last (n : sh)) -listhLast (_ ::# sh@(_ ::# _)) = listhLast sh -listhLast (x ::# ZH) = x +listhLast (_ `ConsUnknown` sh@(_ `ConsUnknown` _)) = listhLast sh +listhLast (_ `ConsUnknown` sh@(_ `ConsKnown` _)) = listhLast sh +listhLast (x `ConsUnknown` ZH) = SUnknown x +listhLast (_ `ConsKnown` sh@(_ `ConsUnknown` _)) = listhLast sh +listhLast (_ `ConsKnown` sh@(_ `ConsKnown` _)) = listhLast sh +listhLast (x `ConsKnown` ZH) = SKnown x -- * Mixed shapes @@ -448,7 +476,7 @@ pattern (:$%) forall n sh. (n : sh ~ sh1) => SMayNat i n -> ShX sh i -> ShX sh1 i pattern i :$% shl <- ShX (listhUncons -> Just (UnconsListHRes (ShX -> shl) i)) - where i :$% ShX shl = ShX (i ::# shl) + where i :$% ShX shl = case i of; SUnknown x -> ShX (x `ConsUnknown` shl); SKnown x -> ShX (x `ConsKnown` shl) infixr 3 :$% {-# COMPLETE ZSX, (:$%) #-} @@ -468,8 +496,8 @@ instance Functor (ShX sh) where instance NFData i => NFData (ShX sh i) where rnf (ShX ZH) = () - rnf (ShX (SUnknown i ::# l)) = rnf i `seq` rnf (ShX l) - rnf (ShX (SKnown SNat ::# l)) = rnf (ShX l) + rnf (ShX (i `ConsUnknown` l)) = rnf i `seq` rnf (ShX l) + rnf (ShX (SNat `ConsKnown` l)) = rnf (ShX l) -- | This checks only whether the types are equal; unknown dimensions might -- still differ. This corresponds to 'testEquality', except on the penultimate @@ -561,7 +589,7 @@ shxDropSSX = coerce (listhDrop @i @()) shxDropIx :: forall sh sh' i j. IxX sh j -> ShX (sh ++ sh') i -> ShX sh' i shxDropIx (IxX ZX) long = long -shxDropIx (IxX (_ ::% short)) long = case long of ShX (_ ::# long') -> shxDropIx (IxX short) (ShX long') +shxDropIx (IxX (_ ::% short)) long = case long of _ :$% long' -> shxDropIx (IxX short) long' shxDropSh :: forall sh sh' i. ShX sh i -> ShX (sh ++ sh') i -> ShX sh' i shxDropSh = coerce (listhDrop @i @i) @@ -638,7 +666,8 @@ pattern (:!%) forall n sh. (n : sh ~ sh1) => SMayNat () n -> StaticShX sh -> StaticShX sh1 pattern i :!% shl <- StaticShX (listhUncons -> Just (UnconsListHRes (StaticShX -> shl) i)) - where i :!% StaticShX shl = StaticShX (i ::# shl) + where i :!% StaticShX shl = case i of; SUnknown () -> StaticShX (() `ConsUnknown` shl); SKnown x -> StaticShX (x `ConsKnown` shl) + infixr 3 :!% {-# COMPLETE ZKX, (:!%) #-} @@ -652,8 +681,8 @@ instance Show (StaticShX sh) where instance NFData (StaticShX sh) where rnf (StaticShX ZH) = () - rnf (StaticShX (SUnknown () ::# l)) = rnf (StaticShX l) - rnf (StaticShX (SKnown SNat ::# l)) = rnf (StaticShX l) + rnf (StaticShX (() `ConsUnknown` l)) = rnf (StaticShX l) + rnf (StaticShX (SNat `ConsKnown` l)) = rnf (StaticShX l) instance TestEquality StaticShX where testEquality (StaticShX l1) (StaticShX l2) = listhEqType l1 l2 @@ -680,11 +709,11 @@ ssxTail (_ :!% ssh) = ssh ssxTakeIx :: forall sh sh' i. Proxy sh' -> IxX sh i -> StaticShX (sh ++ sh') -> StaticShX sh ssxTakeIx _ (IxX ZX) _ = ZKX -ssxTakeIx proxy (IxX (_ ::% long)) short = case short of StaticShX (i ::# short') -> i :!% ssxTakeIx proxy (IxX long) (StaticShX short') +ssxTakeIx proxy (IxX (_ ::% long)) short = case short of i :!% short' -> i :!% ssxTakeIx proxy (IxX long) short' ssxDropIx :: forall sh sh' i. IxX sh i -> StaticShX (sh ++ sh') -> StaticShX sh' ssxDropIx (IxX ZX) long = long -ssxDropIx (IxX (_ ::% short)) long = case long of StaticShX (_ ::# long') -> ssxDropIx (IxX short) (StaticShX long') +ssxDropIx (IxX (_ ::% short)) long = case long of _ :!% long' -> ssxDropIx (IxX short) long' ssxDropSh :: forall sh sh' i. ShX sh i -> StaticShX (sh ++ sh') -> StaticShX sh' ssxDropSh = coerce (listhDrop @() @i) diff --git a/src/Data/Array/Nested/Permutation.hs b/src/Data/Array/Nested/Permutation.hs index 93c46ed..2e0c1ca 100644 --- a/src/Data/Array/Nested/Permutation.hs +++ b/src/Data/Array/Nested/Permutation.hs @@ -174,23 +174,31 @@ type family DropLen ref l where listhTakeLen :: forall i is sh. Perm is -> ListH sh i -> ListH (TakeLen is sh) i listhTakeLen PNil _ = ZH -listhTakeLen (_ `PCons` is) (n ::# sh) = n ::# listhTakeLen is sh +listhTakeLen (_ `PCons` is) (n `ConsUnknown` sh) = n `ConsUnknown` listhTakeLen is sh +listhTakeLen (_ `PCons` is) (n `ConsKnown` sh) = n `ConsKnown` listhTakeLen is sh listhTakeLen (_ `PCons` _) ZH = error "Permutation longer than shape" listhDropLen :: forall i is sh. Perm is -> ListH sh i -> ListH (DropLen is sh) i listhDropLen PNil sh = sh -listhDropLen (_ `PCons` is) (_ ::# sh) = listhDropLen is sh +listhDropLen (_ `PCons` is) (_ `ConsUnknown` sh) = listhDropLen is sh +listhDropLen (_ `PCons` is) (_ `ConsKnown` sh) = listhDropLen is sh listhDropLen (_ `PCons` _) ZH = error "Permutation longer than shape" listhPermute :: forall i is sh. Perm is -> ListH sh i -> ListH (Permute is sh) i listhPermute PNil _ = ZH listhPermute (i `PCons` (is :: Perm is')) (sh :: ListH sh i) = - listhIndex i sh ::# listhPermute is sh + case listhIndex i sh of + SUnknown x -> x `ConsUnknown` listhPermute is sh + SKnown x -> x `ConsKnown` listhPermute is sh listhIndex :: forall i k sh. SNat k -> ListH sh i -> SMayNat i (Index k sh) -listhIndex SZ (n ::# _) = n -listhIndex (SS (i :: SNat k')) ((_ :: SMayNat i n) ::# (sh :: ListH sh' i)) - | Refl <- lemIndexSucc (Proxy @k') (Proxy @n) (Proxy @sh') +listhIndex SZ (n `ConsUnknown` _) = SUnknown n +listhIndex SZ (n `ConsKnown` _) = SKnown n +listhIndex (SS (i :: SNat k')) ((_ :: i) `ConsUnknown` (sh :: ListH sh' i)) + | Refl <- lemIndexSucc (Proxy @k') (Proxy @Nothing) (Proxy @sh') + = listhIndex i sh +listhIndex (SS (i :: SNat k')) ((_ :: SNat n) `ConsKnown` (sh :: ListH sh' i)) + | Refl <- lemIndexSucc (Proxy @k') (Proxy @(Just n)) (Proxy @sh') = listhIndex i sh listhIndex _ ZH = error "Index into empty shape" -- cgit v1.2.3-70-g09d2