aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-03-11 21:52:45 +0100
committerTom Smeding <tom@tomsmeding.com>2025-03-11 21:52:45 +0100
commita36d23048be6e2ad0e4516965f1e8b48756ef78b (patch)
treea21552fd2405debd7d93ca16b62e49477cc522b9 /src
parenteff6b7ba64fbe4e6e260ce3266109fd9fee27ae2 (diff)
More consistent equality functions on shapesHEADmaster
Diffstat (limited to 'src')
-rw-r--r--src/Data/Array/Mixed/Permutation.hs2
-rw-r--r--src/Data/Array/Mixed/Shape.hs63
-rw-r--r--src/Data/Array/Nested/Internal/Shape.hs61
3 files changed, 97 insertions, 29 deletions
diff --git a/src/Data/Array/Mixed/Permutation.hs b/src/Data/Array/Mixed/Permutation.hs
index 015a828..8efcbe8 100644
--- a/src/Data/Array/Mixed/Permutation.hs
+++ b/src/Data/Array/Mixed/Permutation.hs
@@ -241,7 +241,7 @@ permInverse = \perm k ->
provePermInverse :: Perm is -> Perm is' -> StaticShX sh
-> Maybe (Permute is' (Permute is sh) :~: sh)
provePermInverse perm perminv ssh =
- ssxGeq (ssxPermute perminv (ssxPermute perm ssh)) ssh
+ ssxEqType (ssxPermute perminv (ssxPermute perm ssh)) ssh
type family MapSucc is where
MapSucc '[] = '[]
diff --git a/src/Data/Array/Mixed/Shape.hs b/src/Data/Array/Mixed/Shape.hs
index e5f8b67..16f62fe 100644
--- a/src/Data/Array/Mixed/Shape.hs
+++ b/src/Data/Array/Mixed/Shape.hs
@@ -71,13 +71,28 @@ listxUncons :: ListX sh1 f -> Maybe (UnconsListXRes f sh1)
listxUncons (i ::% shl') = Just (UnconsListXRes shl' i)
listxUncons ZX = Nothing
-listxEq :: TestEquality f => ListX sh f -> ListX sh' f -> Maybe (sh :~: sh')
-listxEq ZX ZX = Just Refl
-listxEq (n ::% sh) (m ::% sh')
+-- | This checks only whether the types are equal; if the elements of the list
+-- are not singletons, their values may still differ. This corresponds to
+-- 'testEquality', except on the penultimate type parameter.
+listxEqType :: TestEquality f => ListX sh f -> ListX sh' f -> Maybe (sh :~: sh')
+listxEqType ZX ZX = Just Refl
+listxEqType (n ::% sh) (m ::% sh')
| Just Refl <- testEquality n m
- , Just Refl <- listxEq sh sh'
+ , Just Refl <- listxEqType sh sh'
= Just Refl
-listxEq _ _ = Nothing
+listxEqType _ _ = Nothing
+
+-- | This checks whether the two lists actually contain equal values. This is
+-- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@
+-- in the @some@ package (except on the penultimate type parameter).
+listxEqual :: (TestEquality f, forall n. Eq (f n)) => ListX sh f -> ListX sh' f -> Maybe (sh :~: sh')
+listxEqual ZX ZX = Just Refl
+listxEqual (n ::% sh) (m ::% sh')
+ | Just Refl <- testEquality n m
+ , n == m
+ , Just Refl <- listxEqual sh sh'
+ = Just Refl
+listxEqual _ _ = Nothing
listxFmap :: (forall n. f n -> g n) -> ListX sh f -> ListX sh g
listxFmap _ ZX = ZX
@@ -293,11 +308,26 @@ instance NFData i => NFData (ShX sh i) where
shxLength :: ShX sh i -> Int
shxLength (ShX l) = listxLength l
-shxRank :: ShX sh f -> SNat (Rank sh)
+shxRank :: ShX sh i -> SNat (Rank sh)
shxRank (ShX list) = listxRank list
--- | This is more than @geq@: it also checks that the integers (the unknown
--- dimensions) are the same.
+-- | This checks only whether the types are equal; unknown dimensions might
+-- still differ. This corresponds to 'testEquality', except on the penultimate
+-- type parameter.
+shxEqType :: ShX sh i -> ShX sh' i -> Maybe (sh :~: sh')
+shxEqType ZSX ZSX = Just Refl
+shxEqType (SKnown n@SNat :$% sh) (SKnown m@SNat :$% sh')
+ | Just Refl <- sameNat n m
+ , Just Refl <- shxEqType sh sh'
+ = Just Refl
+shxEqType (SUnknown _ :$% sh) (SUnknown _ :$% sh')
+ | Just Refl <- shxEqType sh sh'
+ = Just Refl
+shxEqType _ _ = Nothing
+
+-- | This checks whether all dimensions have the same value. This is more than
+-- 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@ in the
+-- @some@ package (except on the penultimate type parameter).
shxEqual :: Eq i => ShX sh i -> ShX sh' i -> Maybe (sh :~: sh')
shxEqual ZSX ZSX = Just Refl
shxEqual (SKnown n@SNat :$% sh) (SKnown m@SNat :$% sh')
@@ -417,23 +447,14 @@ instance Show (StaticShX sh) where
showsPrec _ (StaticShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l
instance TestEquality StaticShX where
- testEquality (StaticShX l1) (StaticShX l2) = listxEq l1 l2
+ testEquality (StaticShX l1) (StaticShX l2) = listxEqType l1 l2
ssxLength :: StaticShX sh -> Int
ssxLength (StaticShX l) = listxLength l
--- | This suffices as an implementation of @geq@ in the @Data.GADT.Compare@
--- class of the @some@ package.
-ssxGeq :: StaticShX sh -> StaticShX sh' -> Maybe (sh :~: sh')
-ssxGeq ZKX ZKX = Just Refl
-ssxGeq (SKnown n@SNat :!% sh) (SKnown m@SNat :!% sh')
- | Just Refl <- sameNat n m
- , Just Refl <- ssxGeq sh sh'
- = Just Refl
-ssxGeq (SUnknown () :!% sh) (SUnknown () :!% sh')
- | Just Refl <- ssxGeq sh sh'
- = Just Refl
-ssxGeq _ _ = Nothing
+-- | @ssxEqType = 'testEquality'@. Provided for consistency.
+ssxEqType :: StaticShX sh -> StaticShX sh' -> Maybe (sh :~: sh')
+ssxEqType = testEquality
ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh')
ssxAppend ZKX sh' = sh'
diff --git a/src/Data/Array/Nested/Internal/Shape.hs b/src/Data/Array/Nested/Internal/Shape.hs
index 59f2c9a..5fb2e7f 100644
--- a/src/Data/Array/Nested/Internal/Shape.hs
+++ b/src/Data/Array/Nested/Internal/Shape.hs
@@ -65,6 +65,24 @@ listrUncons :: ListR n1 i -> Maybe (UnconsListRRes i n1)
listrUncons (i ::: sh') = Just (UnconsListRRes sh' i)
listrUncons ZR = Nothing
+-- | This checks only whether the ranks are equal, not whether the actual
+-- values are.
+listrEqRank :: ListR n i -> ListR n' i -> Maybe (n :~: n')
+listrEqRank ZR ZR = Just Refl
+listrEqRank (_ ::: sh) (_ ::: sh')
+ | Just Refl <- listrEqRank sh sh'
+ = Just Refl
+listrEqRank _ _ = Nothing
+
+-- | This compares the lists for value equality.
+listrEqual :: Eq i => ListR n i -> ListR n' i -> Maybe (n :~: n')
+listrEqual ZR ZR = Just Refl
+listrEqual (i ::: sh) (j ::: sh')
+ | Just Refl <- listrEqual sh sh'
+ , i == j
+ = Just Refl
+listrEqual _ _ = Nothing
+
listrShow :: forall n i. (i -> ShowS) -> ListR n i -> ShowS
listrShow f l = showString "[" . go "" l . showString "]"
where
@@ -207,7 +225,7 @@ pattern (:$:)
forall n. (n + 1 ~ n1)
=> i -> ShR n i -> ShR n1 i
pattern i :$: sh <- ShR (listrUncons -> Just (UnconsListRRes (ShR -> sh) i))
- where i :$: (ShR sh) = ShR (i ::: sh)
+ where i :$: ShR sh = ShR (i ::: sh)
infixr 3 :$:
{-# COMPLETE ZSR, (:$:) #-}
@@ -241,6 +259,15 @@ shCvtRX (n :$: (idx :: ShR m Int)) =
castWith (subst2 @ShX @Int (lemReplicateSucc @(Nothing @Nat) @m))
(SUnknown n :$% shCvtRX idx)
+-- | This checks only whether the ranks are equal, not whether the actual
+-- values are.
+shrEqRank :: ShR n i -> ShR n' i -> Maybe (n :~: n')
+shrEqRank (ShR sh) (ShR sh') = listrEqRank sh sh'
+
+-- | This compares the shapes for value equality.
+shrEqual :: Eq i => ShR n i -> ShR n' i -> Maybe (n :~: n')
+shrEqual (ShR sh) (ShR sh') = listrEqual sh sh'
+
-- | The number of elements in an array described by this shape.
shrSize :: IShR n -> Int
shrSize ZSR = 1
@@ -316,13 +343,28 @@ listsUncons :: ListS sh1 f -> Maybe (UnconsListSRes f sh1)
listsUncons (x ::$ sh') = Just (UnconsListSRes sh' x)
listsUncons ZS = Nothing
-listsEq :: TestEquality f => ListS sh f -> ListS sh' f -> Maybe (sh :~: sh')
-listsEq ZS ZS = Just Refl
-listsEq (n ::$ sh) (m ::$ sh')
+-- | This checks only whether the types are equal; if the elements of the list
+-- are not singletons, their values may still differ. This corresponds to
+-- 'testEquality', except on the penultimate type parameter.
+listsEqType :: TestEquality f => ListS sh f -> ListS sh' f -> Maybe (sh :~: sh')
+listsEqType ZS ZS = Just Refl
+listsEqType (n ::$ sh) (m ::$ sh')
+ | Just Refl <- testEquality n m
+ , Just Refl <- listsEqType sh sh'
+ = Just Refl
+listsEqType _ _ = Nothing
+
+-- | This checks whether the two lists actually contain equal values. This is
+-- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@
+-- in the @some@ package (except on the penultimate type parameter).
+listsEqual :: (TestEquality f, forall n. Eq (f n)) => ListS sh f -> ListS sh' f -> Maybe (sh :~: sh')
+listsEqual ZS ZS = Just Refl
+listsEqual (n ::$ sh) (m ::$ sh')
| Just Refl <- testEquality n m
- , Just Refl <- listsEq sh sh'
+ , n == m
+ , Just Refl <- listsEqual sh sh'
= Just Refl
-listsEq _ _ = Nothing
+listsEqual _ _ = Nothing
listsFmap :: (forall n. f n -> g n) -> ListS sh f -> ListS sh g
listsFmap _ ZS = ZS
@@ -484,7 +526,12 @@ instance Show (ShS sh) where
showsPrec _ (ShS l) = listsShow (shows . fromSNat) l
instance TestEquality ShS where
- testEquality (ShS l1) (ShS l2) = listsEq l1 l2
+ testEquality (ShS l1) (ShS l2) = listsEqType l1 l2
+
+-- | @'shsEqual' = 'testEquality'@. (Because 'ShS' is a singleton, types are
+-- equal if and only if values are equal.)
+shsEqual :: ShS sh -> ShS sh' -> Maybe (sh :~: sh')
+shsEqual = testEquality
shsLength :: ShS sh -> Int
shsLength (ShS l) = getSum (listsFold (\_ -> Sum 1) l)