aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Mixed
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-06-03 18:07:32 +0200
committerTom Smeding <tom@tomsmeding.com>2024-06-03 18:07:32 +0200
commit20626932cca57d0787e8464dcfd88944eb6336ec (patch)
tree142bcc797a1effd64937f3204b821714bf702863 /src/Data/Array/Mixed
parent0cde8fb6cf80f3606ece7b47981ff017eb90d00c (diff)
Separate permCheckPermutation from permInverse
Diffstat (limited to 'src/Data/Array/Mixed')
-rw-r--r--src/Data/Array/Mixed/Permutation.hs112
1 files changed, 62 insertions, 50 deletions
diff --git a/src/Data/Array/Mixed/Permutation.hs b/src/Data/Array/Mixed/Permutation.hs
index e1e5c44..7c77fc4 100644
--- a/src/Data/Array/Mixed/Permutation.hs
+++ b/src/Data/Array/Mixed/Permutation.hs
@@ -20,6 +20,7 @@ module Data.Array.Mixed.Permutation where
import Data.Coerce (coerce)
import Data.Functor.Const
import Data.List (sort)
+import Data.Maybe (fromMaybe)
import Data.Proxy
import Data.Type.Bool
import Data.Type.Equality
@@ -61,6 +62,58 @@ permToList (x `PCons` l) = TN.fromSNat x : permToList l
permToList' :: Perm list -> [Int]
permToList' = map fromIntegral . permToList
+-- | When called as @permCheckPermutation p k@, if @p@ is a permutation of
+-- @[0 .. 'length' ('permToList' p) - 1]@, @Just k@ is returned. If it isn't,
+-- then @Nothing@ is returned.
+permCheckPermutation :: forall list r. Perm list -> (IsPermutation list => r) -> Maybe r
+permCheckPermutation = \p k ->
+ let n = permLengthSNat p
+ in case (provePerm1 (Proxy @list) n p, provePerm2 (SNat @0) n p) of
+ (Just Refl, Just Refl) -> Just k
+ _ -> Nothing
+ where
+ lemElemCount :: (0 <= n, Compare n m ~ LT)
+ => proxy n -> proxy m -> Elem n (Count 0 m) :~: True
+ lemElemCount _ _ = unsafeCoerceRefl
+
+ lemCount :: (OrdCond (Compare i n) True False True ~ True)
+ => proxy i -> proxy n -> Count i n :~: i : Count (i + 1) n
+ lemCount _ _ = unsafeCoerceRefl
+
+ lemElem :: Elem x ys ~ True => proxy x -> proxy' (y : ys) -> Elem x (y : ys) :~: True
+ lemElem _ _ = unsafeCoerceRefl
+
+ provePerm1 :: Proxy isTop -> SNat (Rank isTop) -> Perm is'
+ -> Maybe (AllElem' is' (Count 0 (Rank isTop)) :~: True)
+ provePerm1 _ _ PNil = Just (Refl)
+ provePerm1 p rtop@SNat (PCons sn@SNat perm)
+ | Just Refl <- provePerm1 p rtop perm
+ = case (cmpNat (SNat @0) sn, cmpNat sn rtop) of
+ (LTI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl
+ (EQI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl
+ _ -> Nothing
+ | otherwise
+ = Nothing
+
+ provePerm2 :: SNat i -> SNat n -> Perm is'
+ -> Maybe (AllElem' (Count i n) is' :~: True)
+ provePerm2 = \i@(SNat :: SNat i) n@SNat perm ->
+ case cmpNat i n of
+ EQI -> Just Refl
+ LTI | Refl <- lemCount i n
+ , Just Refl <- provePerm2 (SNat @(i + 1)) n perm
+ -> checkElem i perm
+ | otherwise -> Nothing
+ GTI -> error "unreachable"
+ where
+ checkElem :: SNat i -> Perm is' -> Maybe (Elem i is' :~: True)
+ checkElem _ PNil = Nothing
+ checkElem i@SNat (PCons k@SNat perm :: Perm is') =
+ case sameNat i k of
+ Just Refl -> Just Refl
+ Nothing | Just Refl <- checkElem i perm, Refl <- lemElem i (Proxy @is') -> Just Refl
+ | otherwise -> Nothing
+
-- | Utility class for generating permutations from type class information.
class KnownPerm l where makePerm :: Perm l
instance KnownPerm '[] where makePerm = PNil
@@ -167,16 +220,15 @@ permInverse :: Perm is
-> r
permInverse = \perm k ->
genPerm perm $ \(invperm :: Perm is') ->
- let sn = permLengthSNat invperm
- in case (provePerm1 (Proxy @is') sn invperm, provePerm2 (SNat @0) sn invperm) of
- (Just Refl, Just Refl) ->
- k invperm
- (\ssh -> case provePermInverse perm invperm ssh of
- Just eq -> eq
- Nothing -> error $ "permInverse: did not generate inverse? perm = " ++ show perm
- ++ " ; invperm = " ++ show invperm)
- _ -> error $ "permInverse: did not generate permutation? perm = " ++ show perm
- ++ " ; invperm = " ++ show invperm
+ fromMaybe
+ (error $ "permInverse: did not generate permutation? perm = " ++ show perm
+ ++ " ; invperm = " ++ show invperm)
+ (permCheckPermutation invperm
+ (k invperm
+ (\ssh -> case provePermInverse perm invperm ssh of
+ Just eq -> eq
+ Nothing -> error $ "permInverse: did not generate inverse? perm = " ++ show perm
+ ++ " ; invperm = " ++ show invperm)))
where
genPerm :: Perm is -> (forall is'. Perm is' -> r) -> r
genPerm perm =
@@ -187,46 +239,6 @@ permInverse = \perm k ->
toHList [] k = k PNil
toHList (n : ns) k = toHList ns $ \l -> TN.withSomeSNat n $ \sn -> k (PCons sn l)
- lemElemCount :: (0 <= n, Compare n m ~ LT) => proxy n -> proxy m -> Elem n (Count 0 m) :~: True
- lemElemCount _ _ = unsafeCoerceRefl
-
- lemCount :: (OrdCond (Compare i n) True False True ~ True) => proxy i -> proxy n -> Count i n :~: i : Count (i + 1) n
- lemCount _ _ = unsafeCoerceRefl
-
- lemElem :: Elem x ys ~ True => proxy x -> proxy' (y : ys) -> Elem x (y : ys) :~: True
- lemElem _ _ = unsafeCoerceRefl
-
- provePerm1 :: Proxy isTop -> SNat (Rank isTop) -> Perm is'
- -> Maybe (AllElem' is' (Count 0 (Rank isTop)) :~: True)
- provePerm1 _ _ PNil = Just (Refl)
- provePerm1 p rtop@SNat (PCons sn@SNat perm)
- | Just Refl <- provePerm1 p rtop perm
- = case (cmpNat (SNat @0) sn, cmpNat sn rtop) of
- (LTI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl
- (EQI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl
- _ -> Nothing
- | otherwise
- = Nothing
-
- provePerm2 :: SNat i -> SNat n -> Perm is'
- -> Maybe (AllElem' (Count i n) is' :~: True)
- provePerm2 = \i@(SNat :: SNat i) n@SNat perm ->
- case cmpNat i n of
- EQI -> Just Refl
- LTI | Refl <- lemCount i n
- , Just Refl <- provePerm2 (SNat @(i + 1)) n perm
- -> checkElem i perm
- | otherwise -> Nothing
- GTI -> error "unreachable"
- where
- checkElem :: SNat i -> Perm is' -> Maybe (Elem i is' :~: True)
- checkElem _ PNil = Nothing
- checkElem i@SNat (PCons k@SNat perm :: Perm is') =
- case sameNat i k of
- Just Refl -> Just Refl
- Nothing | Just Refl <- checkElem i perm, Refl <- lemElem i (Proxy @is') -> Just Refl
- | otherwise -> Nothing
-
provePermInverse :: Perm is -> Perm is' -> StaticShX sh
-> Maybe (Permute is' (Permute is sh) :~: sh)
provePermInverse perm perminv ssh =