diff options
| -rw-r--r-- | src/Data/Array/Mixed/Permutation.hs | 2 | ||||
| -rw-r--r-- | src/Data/Array/Mixed/Shape.hs | 63 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Shape.hs | 61 | 
3 files changed, 97 insertions, 29 deletions
| diff --git a/src/Data/Array/Mixed/Permutation.hs b/src/Data/Array/Mixed/Permutation.hs index 015a828..8efcbe8 100644 --- a/src/Data/Array/Mixed/Permutation.hs +++ b/src/Data/Array/Mixed/Permutation.hs @@ -241,7 +241,7 @@ permInverse = \perm k ->      provePermInverse :: Perm is -> Perm is' -> StaticShX sh                       -> Maybe (Permute is' (Permute is sh) :~: sh)      provePermInverse perm perminv ssh = -      ssxGeq (ssxPermute perminv (ssxPermute perm ssh)) ssh +      ssxEqType (ssxPermute perminv (ssxPermute perm ssh)) ssh  type family MapSucc is where    MapSucc '[] = '[] diff --git a/src/Data/Array/Mixed/Shape.hs b/src/Data/Array/Mixed/Shape.hs index e5f8b67..16f62fe 100644 --- a/src/Data/Array/Mixed/Shape.hs +++ b/src/Data/Array/Mixed/Shape.hs @@ -71,13 +71,28 @@ listxUncons :: ListX sh1 f -> Maybe (UnconsListXRes f sh1)  listxUncons (i ::% shl') = Just (UnconsListXRes shl' i)  listxUncons ZX = Nothing -listxEq :: TestEquality f => ListX sh f -> ListX sh' f -> Maybe (sh :~: sh') -listxEq ZX ZX = Just Refl -listxEq (n ::% sh) (m ::% sh') +-- | This checks only whether the types are equal; if the elements of the list +-- are not singletons, their values may still differ. This corresponds to +-- 'testEquality', except on the penultimate type parameter. +listxEqType :: TestEquality f => ListX sh f -> ListX sh' f -> Maybe (sh :~: sh') +listxEqType ZX ZX = Just Refl +listxEqType (n ::% sh) (m ::% sh')    | Just Refl <- testEquality n m -  , Just Refl <- listxEq sh sh' +  , Just Refl <- listxEqType sh sh'    = Just Refl -listxEq _ _ = Nothing +listxEqType _ _ = Nothing + +-- | This checks whether the two lists actually contain equal values. This is +-- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@ +-- in the @some@ package (except on the penultimate type parameter). +listxEqual :: (TestEquality f, forall n. Eq (f n)) => ListX sh f -> ListX sh' f -> Maybe (sh :~: sh') +listxEqual ZX ZX = Just Refl +listxEqual (n ::% sh) (m ::% sh') +  | Just Refl <- testEquality n m +  , n == m +  , Just Refl <- listxEqual sh sh' +  = Just Refl +listxEqual _ _ = Nothing  listxFmap :: (forall n. f n -> g n) -> ListX sh f -> ListX sh g  listxFmap _ ZX = ZX @@ -293,11 +308,26 @@ instance NFData i => NFData (ShX sh i) where  shxLength :: ShX sh i -> Int  shxLength (ShX l) = listxLength l -shxRank :: ShX sh f -> SNat (Rank sh) +shxRank :: ShX sh i -> SNat (Rank sh)  shxRank (ShX list) = listxRank list --- | This is more than @geq@: it also checks that the integers (the unknown --- dimensions) are the same. +-- | This checks only whether the types are equal; unknown dimensions might +-- still differ. This corresponds to 'testEquality', except on the penultimate +-- type parameter. +shxEqType :: ShX sh i -> ShX sh' i -> Maybe (sh :~: sh') +shxEqType ZSX ZSX = Just Refl +shxEqType (SKnown n@SNat :$% sh) (SKnown m@SNat :$% sh') +  | Just Refl <- sameNat n m +  , Just Refl <- shxEqType sh sh' +  = Just Refl +shxEqType (SUnknown _ :$% sh) (SUnknown _ :$% sh') +  | Just Refl <- shxEqType sh sh' +  = Just Refl +shxEqType _ _ = Nothing + +-- | This checks whether all dimensions have the same value. This is more than +-- 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@ in the +-- @some@ package (except on the penultimate type parameter).  shxEqual :: Eq i => ShX sh i -> ShX sh' i -> Maybe (sh :~: sh')  shxEqual ZSX ZSX = Just Refl  shxEqual (SKnown n@SNat :$% sh) (SKnown m@SNat :$% sh') @@ -417,23 +447,14 @@ instance Show (StaticShX sh) where    showsPrec _ (StaticShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l  instance TestEquality StaticShX where -  testEquality (StaticShX l1) (StaticShX l2) = listxEq l1 l2 +  testEquality (StaticShX l1) (StaticShX l2) = listxEqType l1 l2  ssxLength :: StaticShX sh -> Int  ssxLength (StaticShX l) = listxLength l --- | This suffices as an implementation of @geq@ in the @Data.GADT.Compare@ --- class of the @some@ package. -ssxGeq :: StaticShX sh -> StaticShX sh' -> Maybe (sh :~: sh') -ssxGeq ZKX ZKX = Just Refl -ssxGeq (SKnown n@SNat :!% sh) (SKnown m@SNat :!% sh') -  | Just Refl <- sameNat n m -  , Just Refl <- ssxGeq sh sh' -  = Just Refl -ssxGeq (SUnknown () :!% sh) (SUnknown () :!% sh') -  | Just Refl <- ssxGeq sh sh' -  = Just Refl -ssxGeq _ _ = Nothing +-- | @ssxEqType = 'testEquality'@. Provided for consistency. +ssxEqType :: StaticShX sh -> StaticShX sh' -> Maybe (sh :~: sh') +ssxEqType = testEquality  ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh')  ssxAppend ZKX sh' = sh' diff --git a/src/Data/Array/Nested/Internal/Shape.hs b/src/Data/Array/Nested/Internal/Shape.hs index 59f2c9a..5fb2e7f 100644 --- a/src/Data/Array/Nested/Internal/Shape.hs +++ b/src/Data/Array/Nested/Internal/Shape.hs @@ -65,6 +65,24 @@ listrUncons :: ListR n1 i -> Maybe (UnconsListRRes i n1)  listrUncons (i ::: sh') = Just (UnconsListRRes sh' i)  listrUncons ZR = Nothing +-- | This checks only whether the ranks are equal, not whether the actual +-- values are. +listrEqRank :: ListR n i -> ListR n' i -> Maybe (n :~: n') +listrEqRank ZR ZR = Just Refl +listrEqRank (_ ::: sh) (_ ::: sh') +  | Just Refl <- listrEqRank sh sh' +  = Just Refl +listrEqRank _ _ = Nothing + +-- | This compares the lists for value equality. +listrEqual :: Eq i => ListR n i -> ListR n' i -> Maybe (n :~: n') +listrEqual ZR ZR = Just Refl +listrEqual (i ::: sh) (j ::: sh') +  | Just Refl <- listrEqual sh sh' +  , i == j +  = Just Refl +listrEqual _ _ = Nothing +  listrShow :: forall n i. (i -> ShowS) -> ListR n i -> ShowS  listrShow f l = showString "[" . go "" l . showString "]"    where @@ -207,7 +225,7 @@ pattern (:$:)       forall n. (n + 1 ~ n1)    => i -> ShR n i -> ShR n1 i  pattern i :$: sh <- ShR (listrUncons -> Just (UnconsListRRes (ShR -> sh) i)) -  where i :$: (ShR sh) = ShR (i ::: sh) +  where i :$: ShR sh = ShR (i ::: sh)  infixr 3 :$:  {-# COMPLETE ZSR, (:$:) #-} @@ -241,6 +259,15 @@ shCvtRX (n :$: (idx :: ShR m Int)) =    castWith (subst2 @ShX @Int (lemReplicateSucc @(Nothing @Nat) @m))      (SUnknown n :$% shCvtRX idx) +-- | This checks only whether the ranks are equal, not whether the actual +-- values are. +shrEqRank :: ShR n i -> ShR n' i -> Maybe (n :~: n') +shrEqRank (ShR sh) (ShR sh') = listrEqRank sh sh' + +-- | This compares the shapes for value equality. +shrEqual :: Eq i => ShR n i -> ShR n' i -> Maybe (n :~: n') +shrEqual (ShR sh) (ShR sh') = listrEqual sh sh' +  -- | The number of elements in an array described by this shape.  shrSize :: IShR n -> Int  shrSize ZSR = 1 @@ -316,13 +343,28 @@ listsUncons :: ListS sh1 f -> Maybe (UnconsListSRes f sh1)  listsUncons (x ::$ sh') = Just (UnconsListSRes sh' x)  listsUncons ZS = Nothing -listsEq :: TestEquality f => ListS sh f -> ListS sh' f -> Maybe (sh :~: sh') -listsEq ZS ZS = Just Refl -listsEq (n ::$ sh) (m ::$ sh') +-- | This checks only whether the types are equal; if the elements of the list +-- are not singletons, their values may still differ. This corresponds to +-- 'testEquality', except on the penultimate type parameter. +listsEqType :: TestEquality f => ListS sh f -> ListS sh' f -> Maybe (sh :~: sh') +listsEqType ZS ZS = Just Refl +listsEqType (n ::$ sh) (m ::$ sh')    | Just Refl <- testEquality n m -  , Just Refl <- listsEq sh sh' +  , Just Refl <- listsEqType sh sh'    = Just Refl -listsEq _ _ = Nothing +listsEqType _ _ = Nothing + +-- | This checks whether the two lists actually contain equal values. This is +-- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@ +-- in the @some@ package (except on the penultimate type parameter). +listsEqual :: (TestEquality f, forall n. Eq (f n)) => ListS sh f -> ListS sh' f -> Maybe (sh :~: sh') +listsEqual ZS ZS = Just Refl +listsEqual (n ::$ sh) (m ::$ sh') +  | Just Refl <- testEquality n m +  , n == m +  , Just Refl <- listsEqual sh sh' +  = Just Refl +listsEqual _ _ = Nothing  listsFmap :: (forall n. f n -> g n) -> ListS sh f -> ListS sh g  listsFmap _ ZS = ZS @@ -484,7 +526,12 @@ instance Show (ShS sh) where    showsPrec _ (ShS l) = listsShow (shows . fromSNat) l  instance TestEquality ShS where -  testEquality (ShS l1) (ShS l2) = listsEq l1 l2 +  testEquality (ShS l1) (ShS l2) = listsEqType l1 l2 + +-- | @'shsEqual' = 'testEquality'@. (Because 'ShS' is a singleton, types are +-- equal if and only if values are equal.) +shsEqual :: ShS sh -> ShS sh' -> Maybe (sh :~: sh') +shsEqual = testEquality  shsLength :: ShS sh -> Int  shsLength (ShS l) = getSum (listsFold (\_ -> Sum 1) l) | 
