diff options
| author | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2026-04-04 10:33:50 +0200 |
|---|---|---|
| committer | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2026-04-04 10:33:50 +0200 |
| commit | a9ac62f66e45e64f83043e0ebda04f0b4b80b913 (patch) | |
| tree | 4de2974a7753e97c1f1040af72f49af904ad9570 | |
| parent | 2095a851760b6bb44ba92b70df1efceff1bad267 (diff) | |
Make ranked and shaped lists newtypes over mixed
| -rw-r--r-- | src/Data/Array/Nested/Lemmas.hs | 6 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Mixed/Shape.hs | 6 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Ranked/Shape.hs | 77 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Shaped/Shape.hs | 64 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Types.hs | 5 |
5 files changed, 75 insertions, 83 deletions
diff --git a/src/Data/Array/Nested/Lemmas.hs b/src/Data/Array/Nested/Lemmas.hs index fa5611b..e61b148 100644 --- a/src/Data/Array/Nested/Lemmas.hs +++ b/src/Data/Array/Nested/Lemmas.hs @@ -59,7 +59,7 @@ lemReplicatePlusApp _ _ _ = unsafeCoerceRefl lemReplicateEmpty :: proxy n -> Replicate n (Nothing @Nat) :~: '[] -> n :~: 0 lemReplicateEmpty _ Refl = unsafeCoerceRefl --- TODO: make less ad-hoc and rename these three: +-- TODO: make less ad-hoc and rename the following few: lemReplicateCons :: proxy sh -> proxy' n1 -> Nothing : sh :~: Replicate n1 Nothing -> n1 :~: Rank sh + 1 lemReplicateCons _ _ Refl = unsafeCoerceRefl @@ -70,6 +70,10 @@ lemReplicateSucc2 :: forall n1 n proxy. proxy n1 -> n + 1 :~: n1 -> Nothing @Nat : Replicate n Nothing :~: Replicate n1 Nothing lemReplicateSucc2 _ _ = unsafeCoerceRefl +-- TODO: simplify, but GHC doesn't consistently use congruence nor transitivity +lemReplicateHead :: proxy x -> proxy' sh -> proxy'' t -> proxy''' n -> x : sh :~: Replicate n t -> x :~: t +lemReplicateHead _ _ _ _ Refl = unsafeCoerceRefl + lemDropLenApp :: Rank l1 <= Rank l2 => Proxy l1 -> Proxy l2 -> Proxy rest -> DropLen l1 l2 ++ rest :~: DropLen l1 (l2 ++ rest) diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs index c2ab93f..5887f4e 100644 --- a/src/Data/Array/Nested/Mixed/Shape.hs +++ b/src/Data/Array/Nested/Mixed/Shape.hs @@ -75,12 +75,6 @@ instance NFData i => NFData (ListX sh i) where rnf ZX = () rnf (x ::% l) = rnf x `seq` rnf l -data UnconsListXRes i sh1 = - forall n sh. (n : sh ~ sh1) => UnconsListXRes (ListX sh i) i -listxUncons :: ListX sh1 f -> Maybe (UnconsListXRes f sh1) -listxUncons (i ::% shl') = Just (UnconsListXRes shl' i) -listxUncons ZX = Nothing - instance Functor (ListX l) where {-# INLINE fmap #-} fmap _ ZX = ZX diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs index b8b5a28..b0fb30d 100644 --- a/src/Data/Array/Nested/Ranked/Shape.hs +++ b/src/Data/Array/Nested/Ranked/Shape.hs @@ -16,6 +16,7 @@ {-# LANGUAGE StandaloneKindSignatures #-} {-# LANGUAGE StrictData #-} {-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeAbstractions #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} @@ -49,13 +50,37 @@ import Data.Array.Nested.Types type role ListR nominal representational type ListR :: Nat -> Type -> Type -data ListR n i where - ZR :: ListR 0 i - (:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i -deriving instance Eq i => Eq (ListR n i) -deriving instance Ord i => Ord (ListR n i) +newtype ListR n i = ListR (ListX (Replicate n Nothing) i) + deriving (Eq, Ord, NFData, Functor, Foldable) + +pattern ZR :: forall n i. () => n ~ 0 => ListR n i +pattern ZR <- ListR (matchZX @n -> Just Refl) + where ZR = ListR ZX + +matchZX :: forall n i. ListX (Replicate n Nothing) i -> Maybe (n :~: 0) +matchZX ZX | Refl <- lemReplicateEmpty (Proxy @n) Refl = Just Refl +matchZX _ = Nothing + +pattern (:::) + :: forall {n1} {i}. + forall n. (n + 1 ~ n1) + => i -> ListR n i -> ListR n1 i +pattern i ::: sh <- (listrUncons -> Just (UnconsListRRes sh i)) + where i ::: ListR sh | Refl <- lemReplicateSucc2 (Proxy @n1) Refl = ListR (i ::% sh) infixr 3 ::: +data UnconsListRRes i n1 = + forall n. (n + 1 ~ n1) => UnconsListRRes (ListR n i) i +listrUncons :: forall n1 i. ListR n1 i -> Maybe (UnconsListRRes i n1) +listrUncons (ListR ((::%) @n' @sh' x sh')) + | Refl <- lemReplicateHead (Proxy @n') (Proxy @sh') (Proxy @Nothing) (Proxy @n1) Refl + , Refl <- lemReplicateCons (Proxy @sh') (Proxy @n1) Refl + , Refl <- lemReplicateCons2 (Proxy @sh') (Proxy @n1) Refl = + Just (UnconsListRRes (ListR @(Rank sh') sh') x) +listrUncons (ListR _) = Nothing + +{-# COMPLETE ZR, (:::) #-} + #ifdef OXAR_DEFAULT_SHOW_INSTANCES deriving instance Show i => Show (ListR n i) #else @@ -63,32 +88,6 @@ instance Show i => Show (ListR n i) where showsPrec _ = listrShow shows #endif -instance NFData i => NFData (ListR n i) where - rnf ZR = () - rnf (x ::: l) = rnf x `seq` rnf l - -instance Functor (ListR n) where - {-# INLINE fmap #-} - fmap _ ZR = ZR - fmap f (x ::: xs) = f x ::: fmap f xs - -instance Foldable (ListR n) where - {-# INLINE foldMap #-} - foldMap _ ZR = mempty - foldMap f (x ::: xs) = f x <> foldMap f xs - {-# INLINE foldr #-} - foldr _ z ZR = z - foldr f z (x ::: xs) = f x (foldr f z xs) - toList = listrToList - null ZR = False - null _ = True - -data UnconsListRRes i n1 = - forall n. (n + 1 ~ n1) => UnconsListRRes (ListR n i) i -listrUncons :: ListR n1 i -> Maybe (UnconsListRRes i n1) -listrUncons (i ::: sh') = Just (UnconsListRRes sh' i) -listrUncons ZR = Nothing - -- | This checks only whether the ranks are equal, not whether the actual -- values are. listrEqRank :: ListR n i -> ListR n' i -> Maybe (n :~: n') @@ -122,7 +121,7 @@ listrRank :: ListR n i -> SNat n listrRank ZR = SNat listrRank (_ ::: sh) = snatSucc (listrRank sh) -listrAppend :: ListR n i -> ListR m i -> ListR (n + m) i +listrAppend :: forall n m i. ListR n i -> ListR m i -> ListR (n + m) i listrAppend ZR sh = sh listrAppend (x ::: xs) sh = x ::: listrAppend xs sh @@ -185,7 +184,7 @@ listrSplitAt SZ sh = (ZR, sh) listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh) listrSplitAt SS{} ZR = error "m' + 1 <= 0" -listrPermutePrefix :: forall i n. PermR -> ListR n i -> ListR n i +listrPermutePrefix :: forall n i. PermR -> ListR n i -> ListR n i listrPermutePrefix = \perm sh -> TN.withSomeSNat (fromIntegral (length perm)) $ \permlen@SNat -> case listrRank sh of { shlen@SNat -> @@ -273,7 +272,7 @@ ixrCast :: SNat n' -> IxR n i -> IxR n' i ixrCast n (IxR idx) = IxR (listrCastWithName "ixrCast" n idx) ixrAppend :: forall n m i. IxR n i -> IxR m i -> IxR (n + m) i -ixrAppend = coerce (listrAppend @_ @i) +ixrAppend = coerce (listrAppend @n @m @i) ixrZip :: IxR n i -> IxR n j -> IxR n (i, j) ixrZip (IxR l1) (IxR l2) = IxR $ listrZip l1 l2 @@ -283,7 +282,7 @@ ixrZipWith :: (i -> j -> k) -> IxR n i -> IxR n j -> IxR n k ixrZipWith f (IxR l1) (IxR l2) = IxR $ listrZipWith f l1 l2 ixrPermutePrefix :: forall n i. PermR -> IxR n i -> IxR n i -ixrPermutePrefix = coerce (listrPermutePrefix @i) +ixrPermutePrefix = coerce (listrPermutePrefix @n @i) -- | Given a multidimensional index, get the corresponding linear -- index into the buffer. @@ -332,9 +331,9 @@ pattern (:$:) :: forall {n1} {i}. forall n. (n + 1 ~ n1) => i -> ShR n i -> ShR n1 i -pattern i :$: shl <- (shrUncons -> Just (UnconsShRRes shl i)) - where i :$: ShR shl | Refl <- lemReplicateSucc2 (Proxy @n1) Refl - = ShR (SUnknown i :$% shl) +pattern i :$: sh <- (shrUncons -> Just (UnconsShRRes sh i)) + where i :$: ShR sh | Refl <- lemReplicateSucc2 (Proxy @n1) Refl = ShR (SUnknown i :$% sh) +infixr 3 :$: data UnconsShRRes i n1 = forall n. (n + 1 ~ n1) => UnconsShRRes (ShR n i) i @@ -345,8 +344,6 @@ shrUncons (ShR (SUnknown x :$% (sh' :: ShX sh' i))) = Just (UnconsShRRes (ShR sh') x) shrUncons (ShR _) = Nothing -infixr 3 :$: - {-# COMPLETE ZSR, (:$:) #-} type IShR n = ShR n Int diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs index d59f65c..97d1559 100644 --- a/src/Data/Array/Nested/Shaped/Shape.hs +++ b/src/Data/Array/Nested/Shaped/Shape.hs @@ -47,14 +47,35 @@ import Data.Array.Nested.Types type role ListS nominal representational type ListS :: [Nat] -> Type -> Type -data ListS sh i where - ZS :: ListS '[] i - (::$) :: forall n sh {i}. i -> ListS sh i -> ListS (n : sh) i -deriving instance Eq i => Eq (ListS sh i) -deriving instance Ord i => Ord (ListS sh i) +newtype ListS sh i = ListS (ListX (MapJust sh) i) + deriving (Eq, Ord, NFData, Functor, Foldable) + +pattern ZS :: forall sh i. () => sh ~ '[] => ListS sh i +pattern ZS <- ListS (matchZX -> Just Refl) + where ZS = ListS ZX + +matchZX :: forall sh i. ListX (MapJust sh) i -> Maybe (sh :~: '[]) +matchZX ZX | Refl <- lemMapJustEmpty @sh Refl = Just Refl +matchZX _ = Nothing +pattern (::$) + :: forall {sh1} {i}. + forall n sh. (n : sh ~ sh1) + => i -> ListS sh i -> ListS sh1 i +pattern i ::$ sh <- (listsUncons -> Just (UnconsListSRes sh i)) + where i ::$ ListS sh = ListS (i ::% sh) infixr 3 ::$ +data UnconsListSRes i sh1 = + forall n sh. (n : sh ~ sh1) => UnconsListSRes (ListS sh i) i +listsUncons :: forall sh1 i. ListS sh1 i -> Maybe (UnconsListSRes i sh1) +listsUncons (ListS (x ::% sh')) | Refl <- lemMapJustHead (Proxy @sh1) + , Refl <- lemMapJustCons @sh1 Refl = + Just (UnconsListSRes (ListS sh') x) +listsUncons (ListS _) = Nothing + +{-# COMPLETE ZS, (::$) #-} + #ifdef OXAR_DEFAULT_SHOW_INSTANCES deriving instance Show i => Show (ListS sh i) #else @@ -62,16 +83,6 @@ instance Show i => Show (ListS sh i) where showsPrec _ = listsShow shows #endif -instance NFData i => NFData (ListS n i) where - rnf ZS = () - rnf (x ::$ l) = rnf x `seq` rnf l - -data UnconsListSRes i sh1 = - forall n sh. (n : sh ~ sh1) => UnconsListSRes (ListS sh i) i -listsUncons :: ListS sh1 i -> Maybe (UnconsListSRes i sh1) -listsUncons (x ::$ sh') = Just (UnconsListSRes sh' x) -listsUncons ZS = Nothing - listsShow :: forall sh i. (i -> ShowS) -> ListS sh i -> ShowS listsShow f l = showString "[" . go "" l . showString "]" where @@ -79,22 +90,6 @@ listsShow f l = showString "[" . go "" l . showString "]" go _ ZS = id go prefix (x ::$ xs) = showString prefix . f x . go "," xs -instance Functor (ListS l) where - {-# INLINE fmap #-} - fmap _ ZS = ZS - fmap f (x ::$ xs) = f x ::$ fmap f xs - -instance Foldable (ListS l) where - {-# INLINE foldMap #-} - foldMap _ ZS = mempty - foldMap f (x ::$ xs) = f x <> foldMap f xs - {-# INLINE foldr #-} - foldr _ z ZS = z - foldr f z (x ::$ xs) = f x (foldr f z xs) - toList = listsToList - null ZS = False - null _ = True - listsLength :: ListS sh i -> Int listsLength = length @@ -315,8 +310,9 @@ pattern (:$$) :: forall {sh1}. forall n sh. (n : sh ~ sh1) => SNat n -> ShS sh -> ShS sh1 -pattern i :$$ shl <- (shsUncons -> Just (UnconsShSRes i shl)) - where i :$$ ShS shl = ShS (SKnown i :$% shl) +pattern i :$$ sh <- (shsUncons -> Just (UnconsShSRes i sh)) + where i :$$ ShS sh = ShS (SKnown i :$% sh) +infixr 3 :$$ data UnconsShSRes sh1 = forall n sh. (n : sh ~ sh1) => UnconsShSRes (SNat n) (ShS sh) @@ -326,8 +322,6 @@ shsUncons (ShS (SKnown x :$% sh')) = Just (UnconsShSRes x (ShS sh')) shsUncons (ShS _) = Nothing -infixr 3 :$$ - {-# COMPLETE ZSS, (:$$) #-} #ifdef OXAR_DEFAULT_SHOW_INSTANCES diff --git a/src/Data/Array/Nested/Types.hs b/src/Data/Array/Nested/Types.hs index 8bb5b85..ec1b3dc 100644 --- a/src/Data/Array/Nested/Types.hs +++ b/src/Data/Array/Nested/Types.hs @@ -30,7 +30,7 @@ module Data.Array.Nested.Types ( Replicate, lemReplicateSucc, MapJust, - lemMapJustEmpty, lemMapJustCons, + lemMapJustEmpty, lemMapJustCons, lemMapJustHead, Head, Tail, Init, @@ -123,6 +123,9 @@ lemMapJustEmpty Refl = unsafeCoerceRefl lemMapJustCons :: MapJust sh :~: Just n : sh' -> sh :~: n : Tail sh lemMapJustCons Refl = unsafeCoerceRefl +lemMapJustHead :: proxy sh1 -> Head (MapJust sh1) :~: Just (Head sh1) +lemMapJustHead _ = unsafeCoerceRefl + type family Head l where Head (x : _) = x |
