From 20626932cca57d0787e8464dcfd88944eb6336ec Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Mon, 3 Jun 2024 18:07:32 +0200 Subject: Separate permCheckPermutation from permInverse --- src/Data/Array/Mixed/Permutation.hs | 112 ++++++++++++++++++++---------------- 1 file changed, 62 insertions(+), 50 deletions(-) (limited to 'src') diff --git a/src/Data/Array/Mixed/Permutation.hs b/src/Data/Array/Mixed/Permutation.hs index e1e5c44..7c77fc4 100644 --- a/src/Data/Array/Mixed/Permutation.hs +++ b/src/Data/Array/Mixed/Permutation.hs @@ -20,6 +20,7 @@ module Data.Array.Mixed.Permutation where import Data.Coerce (coerce) import Data.Functor.Const import Data.List (sort) +import Data.Maybe (fromMaybe) import Data.Proxy import Data.Type.Bool import Data.Type.Equality @@ -61,6 +62,58 @@ permToList (x `PCons` l) = TN.fromSNat x : permToList l permToList' :: Perm list -> [Int] permToList' = map fromIntegral . permToList +-- | When called as @permCheckPermutation p k@, if @p@ is a permutation of +-- @[0 .. 'length' ('permToList' p) - 1]@, @Just k@ is returned. If it isn't, +-- then @Nothing@ is returned. +permCheckPermutation :: forall list r. Perm list -> (IsPermutation list => r) -> Maybe r +permCheckPermutation = \p k -> + let n = permLengthSNat p + in case (provePerm1 (Proxy @list) n p, provePerm2 (SNat @0) n p) of + (Just Refl, Just Refl) -> Just k + _ -> Nothing + where + lemElemCount :: (0 <= n, Compare n m ~ LT) + => proxy n -> proxy m -> Elem n (Count 0 m) :~: True + lemElemCount _ _ = unsafeCoerceRefl + + lemCount :: (OrdCond (Compare i n) True False True ~ True) + => proxy i -> proxy n -> Count i n :~: i : Count (i + 1) n + lemCount _ _ = unsafeCoerceRefl + + lemElem :: Elem x ys ~ True => proxy x -> proxy' (y : ys) -> Elem x (y : ys) :~: True + lemElem _ _ = unsafeCoerceRefl + + provePerm1 :: Proxy isTop -> SNat (Rank isTop) -> Perm is' + -> Maybe (AllElem' is' (Count 0 (Rank isTop)) :~: True) + provePerm1 _ _ PNil = Just (Refl) + provePerm1 p rtop@SNat (PCons sn@SNat perm) + | Just Refl <- provePerm1 p rtop perm + = case (cmpNat (SNat @0) sn, cmpNat sn rtop) of + (LTI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl + (EQI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl + _ -> Nothing + | otherwise + = Nothing + + provePerm2 :: SNat i -> SNat n -> Perm is' + -> Maybe (AllElem' (Count i n) is' :~: True) + provePerm2 = \i@(SNat :: SNat i) n@SNat perm -> + case cmpNat i n of + EQI -> Just Refl + LTI | Refl <- lemCount i n + , Just Refl <- provePerm2 (SNat @(i + 1)) n perm + -> checkElem i perm + | otherwise -> Nothing + GTI -> error "unreachable" + where + checkElem :: SNat i -> Perm is' -> Maybe (Elem i is' :~: True) + checkElem _ PNil = Nothing + checkElem i@SNat (PCons k@SNat perm :: Perm is') = + case sameNat i k of + Just Refl -> Just Refl + Nothing | Just Refl <- checkElem i perm, Refl <- lemElem i (Proxy @is') -> Just Refl + | otherwise -> Nothing + -- | Utility class for generating permutations from type class information. class KnownPerm l where makePerm :: Perm l instance KnownPerm '[] where makePerm = PNil @@ -167,16 +220,15 @@ permInverse :: Perm is -> r permInverse = \perm k -> genPerm perm $ \(invperm :: Perm is') -> - let sn = permLengthSNat invperm - in case (provePerm1 (Proxy @is') sn invperm, provePerm2 (SNat @0) sn invperm) of - (Just Refl, Just Refl) -> - k invperm - (\ssh -> case provePermInverse perm invperm ssh of - Just eq -> eq - Nothing -> error $ "permInverse: did not generate inverse? perm = " ++ show perm - ++ " ; invperm = " ++ show invperm) - _ -> error $ "permInverse: did not generate permutation? perm = " ++ show perm - ++ " ; invperm = " ++ show invperm + fromMaybe + (error $ "permInverse: did not generate permutation? perm = " ++ show perm + ++ " ; invperm = " ++ show invperm) + (permCheckPermutation invperm + (k invperm + (\ssh -> case provePermInverse perm invperm ssh of + Just eq -> eq + Nothing -> error $ "permInverse: did not generate inverse? perm = " ++ show perm + ++ " ; invperm = " ++ show invperm))) where genPerm :: Perm is -> (forall is'. Perm is' -> r) -> r genPerm perm = @@ -187,46 +239,6 @@ permInverse = \perm k -> toHList [] k = k PNil toHList (n : ns) k = toHList ns $ \l -> TN.withSomeSNat n $ \sn -> k (PCons sn l) - lemElemCount :: (0 <= n, Compare n m ~ LT) => proxy n -> proxy m -> Elem n (Count 0 m) :~: True - lemElemCount _ _ = unsafeCoerceRefl - - lemCount :: (OrdCond (Compare i n) True False True ~ True) => proxy i -> proxy n -> Count i n :~: i : Count (i + 1) n - lemCount _ _ = unsafeCoerceRefl - - lemElem :: Elem x ys ~ True => proxy x -> proxy' (y : ys) -> Elem x (y : ys) :~: True - lemElem _ _ = unsafeCoerceRefl - - provePerm1 :: Proxy isTop -> SNat (Rank isTop) -> Perm is' - -> Maybe (AllElem' is' (Count 0 (Rank isTop)) :~: True) - provePerm1 _ _ PNil = Just (Refl) - provePerm1 p rtop@SNat (PCons sn@SNat perm) - | Just Refl <- provePerm1 p rtop perm - = case (cmpNat (SNat @0) sn, cmpNat sn rtop) of - (LTI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl - (EQI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl - _ -> Nothing - | otherwise - = Nothing - - provePerm2 :: SNat i -> SNat n -> Perm is' - -> Maybe (AllElem' (Count i n) is' :~: True) - provePerm2 = \i@(SNat :: SNat i) n@SNat perm -> - case cmpNat i n of - EQI -> Just Refl - LTI | Refl <- lemCount i n - , Just Refl <- provePerm2 (SNat @(i + 1)) n perm - -> checkElem i perm - | otherwise -> Nothing - GTI -> error "unreachable" - where - checkElem :: SNat i -> Perm is' -> Maybe (Elem i is' :~: True) - checkElem _ PNil = Nothing - checkElem i@SNat (PCons k@SNat perm :: Perm is') = - case sameNat i k of - Just Refl -> Just Refl - Nothing | Just Refl <- checkElem i perm, Refl <- lemElem i (Proxy @is') -> Just Refl - | otherwise -> Nothing - provePermInverse :: Perm is -> Perm is' -> StaticShX sh -> Maybe (Permute is' (Permute is sh) :~: sh) provePermInverse perm perminv ssh = -- cgit v1.2.3-70-g09d2