diff options
Diffstat (limited to 'src/Data/Array/Nested/Internal/Shape.hs')
-rw-r--r-- | src/Data/Array/Nested/Internal/Shape.hs | 61 |
1 files changed, 54 insertions, 7 deletions
diff --git a/src/Data/Array/Nested/Internal/Shape.hs b/src/Data/Array/Nested/Internal/Shape.hs index 59f2c9a..5fb2e7f 100644 --- a/src/Data/Array/Nested/Internal/Shape.hs +++ b/src/Data/Array/Nested/Internal/Shape.hs @@ -65,6 +65,24 @@ listrUncons :: ListR n1 i -> Maybe (UnconsListRRes i n1) listrUncons (i ::: sh') = Just (UnconsListRRes sh' i) listrUncons ZR = Nothing +-- | This checks only whether the ranks are equal, not whether the actual +-- values are. +listrEqRank :: ListR n i -> ListR n' i -> Maybe (n :~: n') +listrEqRank ZR ZR = Just Refl +listrEqRank (_ ::: sh) (_ ::: sh') + | Just Refl <- listrEqRank sh sh' + = Just Refl +listrEqRank _ _ = Nothing + +-- | This compares the lists for value equality. +listrEqual :: Eq i => ListR n i -> ListR n' i -> Maybe (n :~: n') +listrEqual ZR ZR = Just Refl +listrEqual (i ::: sh) (j ::: sh') + | Just Refl <- listrEqual sh sh' + , i == j + = Just Refl +listrEqual _ _ = Nothing + listrShow :: forall n i. (i -> ShowS) -> ListR n i -> ShowS listrShow f l = showString "[" . go "" l . showString "]" where @@ -207,7 +225,7 @@ pattern (:$:) forall n. (n + 1 ~ n1) => i -> ShR n i -> ShR n1 i pattern i :$: sh <- ShR (listrUncons -> Just (UnconsListRRes (ShR -> sh) i)) - where i :$: (ShR sh) = ShR (i ::: sh) + where i :$: ShR sh = ShR (i ::: sh) infixr 3 :$: {-# COMPLETE ZSR, (:$:) #-} @@ -241,6 +259,15 @@ shCvtRX (n :$: (idx :: ShR m Int)) = castWith (subst2 @ShX @Int (lemReplicateSucc @(Nothing @Nat) @m)) (SUnknown n :$% shCvtRX idx) +-- | This checks only whether the ranks are equal, not whether the actual +-- values are. +shrEqRank :: ShR n i -> ShR n' i -> Maybe (n :~: n') +shrEqRank (ShR sh) (ShR sh') = listrEqRank sh sh' + +-- | This compares the shapes for value equality. +shrEqual :: Eq i => ShR n i -> ShR n' i -> Maybe (n :~: n') +shrEqual (ShR sh) (ShR sh') = listrEqual sh sh' + -- | The number of elements in an array described by this shape. shrSize :: IShR n -> Int shrSize ZSR = 1 @@ -316,13 +343,28 @@ listsUncons :: ListS sh1 f -> Maybe (UnconsListSRes f sh1) listsUncons (x ::$ sh') = Just (UnconsListSRes sh' x) listsUncons ZS = Nothing -listsEq :: TestEquality f => ListS sh f -> ListS sh' f -> Maybe (sh :~: sh') -listsEq ZS ZS = Just Refl -listsEq (n ::$ sh) (m ::$ sh') +-- | This checks only whether the types are equal; if the elements of the list +-- are not singletons, their values may still differ. This corresponds to +-- 'testEquality', except on the penultimate type parameter. +listsEqType :: TestEquality f => ListS sh f -> ListS sh' f -> Maybe (sh :~: sh') +listsEqType ZS ZS = Just Refl +listsEqType (n ::$ sh) (m ::$ sh') + | Just Refl <- testEquality n m + , Just Refl <- listsEqType sh sh' + = Just Refl +listsEqType _ _ = Nothing + +-- | This checks whether the two lists actually contain equal values. This is +-- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@ +-- in the @some@ package (except on the penultimate type parameter). +listsEqual :: (TestEquality f, forall n. Eq (f n)) => ListS sh f -> ListS sh' f -> Maybe (sh :~: sh') +listsEqual ZS ZS = Just Refl +listsEqual (n ::$ sh) (m ::$ sh') | Just Refl <- testEquality n m - , Just Refl <- listsEq sh sh' + , n == m + , Just Refl <- listsEqual sh sh' = Just Refl -listsEq _ _ = Nothing +listsEqual _ _ = Nothing listsFmap :: (forall n. f n -> g n) -> ListS sh f -> ListS sh g listsFmap _ ZS = ZS @@ -484,7 +526,12 @@ instance Show (ShS sh) where showsPrec _ (ShS l) = listsShow (shows . fromSNat) l instance TestEquality ShS where - testEquality (ShS l1) (ShS l2) = listsEq l1 l2 + testEquality (ShS l1) (ShS l2) = listsEqType l1 l2 + +-- | @'shsEqual' = 'testEquality'@. (Because 'ShS' is a singleton, types are +-- equal if and only if values are equal.) +shsEqual :: ShS sh -> ShS sh' -> Maybe (sh :~: sh') +shsEqual = testEquality shsLength :: ShS sh -> Int shsLength (ShS l) = getSum (listsFold (\_ -> Sum 1) l) |