aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMikolaj Konarski <mikolaj.konarski@funktory.com>2026-04-04 10:33:50 +0200
committerMikolaj Konarski <mikolaj.konarski@funktory.com>2026-04-04 10:33:50 +0200
commita9ac62f66e45e64f83043e0ebda04f0b4b80b913 (patch)
tree4de2974a7753e97c1f1040af72f49af904ad9570
parent2095a851760b6bb44ba92b70df1efceff1bad267 (diff)
Make ranked and shaped lists newtypes over mixed
-rw-r--r--src/Data/Array/Nested/Lemmas.hs6
-rw-r--r--src/Data/Array/Nested/Mixed/Shape.hs6
-rw-r--r--src/Data/Array/Nested/Ranked/Shape.hs77
-rw-r--r--src/Data/Array/Nested/Shaped/Shape.hs64
-rw-r--r--src/Data/Array/Nested/Types.hs5
5 files changed, 75 insertions, 83 deletions
diff --git a/src/Data/Array/Nested/Lemmas.hs b/src/Data/Array/Nested/Lemmas.hs
index fa5611b..e61b148 100644
--- a/src/Data/Array/Nested/Lemmas.hs
+++ b/src/Data/Array/Nested/Lemmas.hs
@@ -59,7 +59,7 @@ lemReplicatePlusApp _ _ _ = unsafeCoerceRefl
lemReplicateEmpty :: proxy n -> Replicate n (Nothing @Nat) :~: '[] -> n :~: 0
lemReplicateEmpty _ Refl = unsafeCoerceRefl
--- TODO: make less ad-hoc and rename these three:
+-- TODO: make less ad-hoc and rename the following few:
lemReplicateCons :: proxy sh -> proxy' n1 -> Nothing : sh :~: Replicate n1 Nothing -> n1 :~: Rank sh + 1
lemReplicateCons _ _ Refl = unsafeCoerceRefl
@@ -70,6 +70,10 @@ lemReplicateSucc2 :: forall n1 n proxy.
proxy n1 -> n + 1 :~: n1 -> Nothing @Nat : Replicate n Nothing :~: Replicate n1 Nothing
lemReplicateSucc2 _ _ = unsafeCoerceRefl
+-- TODO: simplify, but GHC doesn't consistently use congruence nor transitivity
+lemReplicateHead :: proxy x -> proxy' sh -> proxy'' t -> proxy''' n -> x : sh :~: Replicate n t -> x :~: t
+lemReplicateHead _ _ _ _ Refl = unsafeCoerceRefl
+
lemDropLenApp :: Rank l1 <= Rank l2
=> Proxy l1 -> Proxy l2 -> Proxy rest
-> DropLen l1 l2 ++ rest :~: DropLen l1 (l2 ++ rest)
diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs
index c2ab93f..5887f4e 100644
--- a/src/Data/Array/Nested/Mixed/Shape.hs
+++ b/src/Data/Array/Nested/Mixed/Shape.hs
@@ -75,12 +75,6 @@ instance NFData i => NFData (ListX sh i) where
rnf ZX = ()
rnf (x ::% l) = rnf x `seq` rnf l
-data UnconsListXRes i sh1 =
- forall n sh. (n : sh ~ sh1) => UnconsListXRes (ListX sh i) i
-listxUncons :: ListX sh1 f -> Maybe (UnconsListXRes f sh1)
-listxUncons (i ::% shl') = Just (UnconsListXRes shl' i)
-listxUncons ZX = Nothing
-
instance Functor (ListX l) where
{-# INLINE fmap #-}
fmap _ ZX = ZX
diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs
index b8b5a28..b0fb30d 100644
--- a/src/Data/Array/Nested/Ranked/Shape.hs
+++ b/src/Data/Array/Nested/Ranked/Shape.hs
@@ -16,6 +16,7 @@
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE TemplateHaskell #-}
+{-# LANGUAGE TypeAbstractions #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
@@ -49,13 +50,37 @@ import Data.Array.Nested.Types
type role ListR nominal representational
type ListR :: Nat -> Type -> Type
-data ListR n i where
- ZR :: ListR 0 i
- (:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i
-deriving instance Eq i => Eq (ListR n i)
-deriving instance Ord i => Ord (ListR n i)
+newtype ListR n i = ListR (ListX (Replicate n Nothing) i)
+ deriving (Eq, Ord, NFData, Functor, Foldable)
+
+pattern ZR :: forall n i. () => n ~ 0 => ListR n i
+pattern ZR <- ListR (matchZX @n -> Just Refl)
+ where ZR = ListR ZX
+
+matchZX :: forall n i. ListX (Replicate n Nothing) i -> Maybe (n :~: 0)
+matchZX ZX | Refl <- lemReplicateEmpty (Proxy @n) Refl = Just Refl
+matchZX _ = Nothing
+
+pattern (:::)
+ :: forall {n1} {i}.
+ forall n. (n + 1 ~ n1)
+ => i -> ListR n i -> ListR n1 i
+pattern i ::: sh <- (listrUncons -> Just (UnconsListRRes sh i))
+ where i ::: ListR sh | Refl <- lemReplicateSucc2 (Proxy @n1) Refl = ListR (i ::% sh)
infixr 3 :::
+data UnconsListRRes i n1 =
+ forall n. (n + 1 ~ n1) => UnconsListRRes (ListR n i) i
+listrUncons :: forall n1 i. ListR n1 i -> Maybe (UnconsListRRes i n1)
+listrUncons (ListR ((::%) @n' @sh' x sh'))
+ | Refl <- lemReplicateHead (Proxy @n') (Proxy @sh') (Proxy @Nothing) (Proxy @n1) Refl
+ , Refl <- lemReplicateCons (Proxy @sh') (Proxy @n1) Refl
+ , Refl <- lemReplicateCons2 (Proxy @sh') (Proxy @n1) Refl =
+ Just (UnconsListRRes (ListR @(Rank sh') sh') x)
+listrUncons (ListR _) = Nothing
+
+{-# COMPLETE ZR, (:::) #-}
+
#ifdef OXAR_DEFAULT_SHOW_INSTANCES
deriving instance Show i => Show (ListR n i)
#else
@@ -63,32 +88,6 @@ instance Show i => Show (ListR n i) where
showsPrec _ = listrShow shows
#endif
-instance NFData i => NFData (ListR n i) where
- rnf ZR = ()
- rnf (x ::: l) = rnf x `seq` rnf l
-
-instance Functor (ListR n) where
- {-# INLINE fmap #-}
- fmap _ ZR = ZR
- fmap f (x ::: xs) = f x ::: fmap f xs
-
-instance Foldable (ListR n) where
- {-# INLINE foldMap #-}
- foldMap _ ZR = mempty
- foldMap f (x ::: xs) = f x <> foldMap f xs
- {-# INLINE foldr #-}
- foldr _ z ZR = z
- foldr f z (x ::: xs) = f x (foldr f z xs)
- toList = listrToList
- null ZR = False
- null _ = True
-
-data UnconsListRRes i n1 =
- forall n. (n + 1 ~ n1) => UnconsListRRes (ListR n i) i
-listrUncons :: ListR n1 i -> Maybe (UnconsListRRes i n1)
-listrUncons (i ::: sh') = Just (UnconsListRRes sh' i)
-listrUncons ZR = Nothing
-
-- | This checks only whether the ranks are equal, not whether the actual
-- values are.
listrEqRank :: ListR n i -> ListR n' i -> Maybe (n :~: n')
@@ -122,7 +121,7 @@ listrRank :: ListR n i -> SNat n
listrRank ZR = SNat
listrRank (_ ::: sh) = snatSucc (listrRank sh)
-listrAppend :: ListR n i -> ListR m i -> ListR (n + m) i
+listrAppend :: forall n m i. ListR n i -> ListR m i -> ListR (n + m) i
listrAppend ZR sh = sh
listrAppend (x ::: xs) sh = x ::: listrAppend xs sh
@@ -185,7 +184,7 @@ listrSplitAt SZ sh = (ZR, sh)
listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh)
listrSplitAt SS{} ZR = error "m' + 1 <= 0"
-listrPermutePrefix :: forall i n. PermR -> ListR n i -> ListR n i
+listrPermutePrefix :: forall n i. PermR -> ListR n i -> ListR n i
listrPermutePrefix = \perm sh ->
TN.withSomeSNat (fromIntegral (length perm)) $ \permlen@SNat ->
case listrRank sh of { shlen@SNat ->
@@ -273,7 +272,7 @@ ixrCast :: SNat n' -> IxR n i -> IxR n' i
ixrCast n (IxR idx) = IxR (listrCastWithName "ixrCast" n idx)
ixrAppend :: forall n m i. IxR n i -> IxR m i -> IxR (n + m) i
-ixrAppend = coerce (listrAppend @_ @i)
+ixrAppend = coerce (listrAppend @n @m @i)
ixrZip :: IxR n i -> IxR n j -> IxR n (i, j)
ixrZip (IxR l1) (IxR l2) = IxR $ listrZip l1 l2
@@ -283,7 +282,7 @@ ixrZipWith :: (i -> j -> k) -> IxR n i -> IxR n j -> IxR n k
ixrZipWith f (IxR l1) (IxR l2) = IxR $ listrZipWith f l1 l2
ixrPermutePrefix :: forall n i. PermR -> IxR n i -> IxR n i
-ixrPermutePrefix = coerce (listrPermutePrefix @i)
+ixrPermutePrefix = coerce (listrPermutePrefix @n @i)
-- | Given a multidimensional index, get the corresponding linear
-- index into the buffer.
@@ -332,9 +331,9 @@ pattern (:$:)
:: forall {n1} {i}.
forall n. (n + 1 ~ n1)
=> i -> ShR n i -> ShR n1 i
-pattern i :$: shl <- (shrUncons -> Just (UnconsShRRes shl i))
- where i :$: ShR shl | Refl <- lemReplicateSucc2 (Proxy @n1) Refl
- = ShR (SUnknown i :$% shl)
+pattern i :$: sh <- (shrUncons -> Just (UnconsShRRes sh i))
+ where i :$: ShR sh | Refl <- lemReplicateSucc2 (Proxy @n1) Refl = ShR (SUnknown i :$% sh)
+infixr 3 :$:
data UnconsShRRes i n1 =
forall n. (n + 1 ~ n1) => UnconsShRRes (ShR n i) i
@@ -345,8 +344,6 @@ shrUncons (ShR (SUnknown x :$% (sh' :: ShX sh' i)))
= Just (UnconsShRRes (ShR sh') x)
shrUncons (ShR _) = Nothing
-infixr 3 :$:
-
{-# COMPLETE ZSR, (:$:) #-}
type IShR n = ShR n Int
diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs
index d59f65c..97d1559 100644
--- a/src/Data/Array/Nested/Shaped/Shape.hs
+++ b/src/Data/Array/Nested/Shaped/Shape.hs
@@ -47,14 +47,35 @@ import Data.Array.Nested.Types
type role ListS nominal representational
type ListS :: [Nat] -> Type -> Type
-data ListS sh i where
- ZS :: ListS '[] i
- (::$) :: forall n sh {i}. i -> ListS sh i -> ListS (n : sh) i
-deriving instance Eq i => Eq (ListS sh i)
-deriving instance Ord i => Ord (ListS sh i)
+newtype ListS sh i = ListS (ListX (MapJust sh) i)
+ deriving (Eq, Ord, NFData, Functor, Foldable)
+
+pattern ZS :: forall sh i. () => sh ~ '[] => ListS sh i
+pattern ZS <- ListS (matchZX -> Just Refl)
+ where ZS = ListS ZX
+
+matchZX :: forall sh i. ListX (MapJust sh) i -> Maybe (sh :~: '[])
+matchZX ZX | Refl <- lemMapJustEmpty @sh Refl = Just Refl
+matchZX _ = Nothing
+pattern (::$)
+ :: forall {sh1} {i}.
+ forall n sh. (n : sh ~ sh1)
+ => i -> ListS sh i -> ListS sh1 i
+pattern i ::$ sh <- (listsUncons -> Just (UnconsListSRes sh i))
+ where i ::$ ListS sh = ListS (i ::% sh)
infixr 3 ::$
+data UnconsListSRes i sh1 =
+ forall n sh. (n : sh ~ sh1) => UnconsListSRes (ListS sh i) i
+listsUncons :: forall sh1 i. ListS sh1 i -> Maybe (UnconsListSRes i sh1)
+listsUncons (ListS (x ::% sh')) | Refl <- lemMapJustHead (Proxy @sh1)
+ , Refl <- lemMapJustCons @sh1 Refl =
+ Just (UnconsListSRes (ListS sh') x)
+listsUncons (ListS _) = Nothing
+
+{-# COMPLETE ZS, (::$) #-}
+
#ifdef OXAR_DEFAULT_SHOW_INSTANCES
deriving instance Show i => Show (ListS sh i)
#else
@@ -62,16 +83,6 @@ instance Show i => Show (ListS sh i) where
showsPrec _ = listsShow shows
#endif
-instance NFData i => NFData (ListS n i) where
- rnf ZS = ()
- rnf (x ::$ l) = rnf x `seq` rnf l
-
-data UnconsListSRes i sh1 =
- forall n sh. (n : sh ~ sh1) => UnconsListSRes (ListS sh i) i
-listsUncons :: ListS sh1 i -> Maybe (UnconsListSRes i sh1)
-listsUncons (x ::$ sh') = Just (UnconsListSRes sh' x)
-listsUncons ZS = Nothing
-
listsShow :: forall sh i. (i -> ShowS) -> ListS sh i -> ShowS
listsShow f l = showString "[" . go "" l . showString "]"
where
@@ -79,22 +90,6 @@ listsShow f l = showString "[" . go "" l . showString "]"
go _ ZS = id
go prefix (x ::$ xs) = showString prefix . f x . go "," xs
-instance Functor (ListS l) where
- {-# INLINE fmap #-}
- fmap _ ZS = ZS
- fmap f (x ::$ xs) = f x ::$ fmap f xs
-
-instance Foldable (ListS l) where
- {-# INLINE foldMap #-}
- foldMap _ ZS = mempty
- foldMap f (x ::$ xs) = f x <> foldMap f xs
- {-# INLINE foldr #-}
- foldr _ z ZS = z
- foldr f z (x ::$ xs) = f x (foldr f z xs)
- toList = listsToList
- null ZS = False
- null _ = True
-
listsLength :: ListS sh i -> Int
listsLength = length
@@ -315,8 +310,9 @@ pattern (:$$)
:: forall {sh1}.
forall n sh. (n : sh ~ sh1)
=> SNat n -> ShS sh -> ShS sh1
-pattern i :$$ shl <- (shsUncons -> Just (UnconsShSRes i shl))
- where i :$$ ShS shl = ShS (SKnown i :$% shl)
+pattern i :$$ sh <- (shsUncons -> Just (UnconsShSRes i sh))
+ where i :$$ ShS sh = ShS (SKnown i :$% sh)
+infixr 3 :$$
data UnconsShSRes sh1 =
forall n sh. (n : sh ~ sh1) => UnconsShSRes (SNat n) (ShS sh)
@@ -326,8 +322,6 @@ shsUncons (ShS (SKnown x :$% sh'))
= Just (UnconsShSRes x (ShS sh'))
shsUncons (ShS _) = Nothing
-infixr 3 :$$
-
{-# COMPLETE ZSS, (:$$) #-}
#ifdef OXAR_DEFAULT_SHOW_INSTANCES
diff --git a/src/Data/Array/Nested/Types.hs b/src/Data/Array/Nested/Types.hs
index 8bb5b85..ec1b3dc 100644
--- a/src/Data/Array/Nested/Types.hs
+++ b/src/Data/Array/Nested/Types.hs
@@ -30,7 +30,7 @@ module Data.Array.Nested.Types (
Replicate,
lemReplicateSucc,
MapJust,
- lemMapJustEmpty, lemMapJustCons,
+ lemMapJustEmpty, lemMapJustCons, lemMapJustHead,
Head,
Tail,
Init,
@@ -123,6 +123,9 @@ lemMapJustEmpty Refl = unsafeCoerceRefl
lemMapJustCons :: MapJust sh :~: Just n : sh' -> sh :~: n : Tail sh
lemMapJustCons Refl = unsafeCoerceRefl
+lemMapJustHead :: proxy sh1 -> Head (MapJust sh1) :~: Just (Head sh1)
+lemMapJustHead _ = unsafeCoerceRefl
+
type family Head l where
Head (x : _) = x