From 8e7f9981f9b8ca17bb3c46c942116eaa4f7cb0d3 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Tue, 28 May 2024 13:17:35 +0200 Subject: invertPermutation: Provide Permutation evidence --- src/Data/Array/Mixed.hs | 64 +++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 57 insertions(+), 7 deletions(-) (limited to 'src') diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 065756d..24a8482 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -39,6 +39,7 @@ import Data.Monoid (Sum(..)) import Data.Proxy import Data.Type.Bool import Data.Type.Equality +import Data.Type.Ord import qualified Data.Vector.Storable as VS import Foreign.Storable (Storable) import GHC.Generics (Generic) @@ -715,6 +716,10 @@ foldHList :: Monoid m => (forall a. f a -> m) -> HList f list -> m foldHList _ HNil = mempty foldHList f (x `HCons` l) = f x <> foldHList f l +snatLengthHList :: HList f list -> SNat (Rank list) +snatLengthHList HNil = SNat +snatLengthHList (_ `HCons` l) | SNat <- snatLengthHList l = SNat + type family TakeLen ref l where TakeLen '[] l = '[] TakeLen (_ : ref) (x : xs) = x : TakeLen ref xs @@ -782,17 +787,23 @@ shPermutePrefix = coerce (listxPermutePrefix @(SMayNat Int SNat)) -- TODO: test this thing more properly invertPermutation :: HList SNat is -> (forall is'. - HList SNat is' + Permutation is' + => HList SNat is' -> (forall sh. Rank sh ~ Rank is => StaticShX sh -> Permute is' (Permute is sh) :~: sh) -> r) -> r invertPermutation = \perm k -> - genPerm perm $ \invperm -> - k invperm - (\ssh -> case provePermInverse perm invperm ssh of - Just eq -> eq - Nothing -> error $ "invertPermutation: did not generate inverse? perm = " ++ show perm - ++ " ; invperm = " ++ show invperm) + genPerm perm $ \(invperm :: HList SNat is') -> + let sn = snatLengthHList invperm + in case (provePerm1 (Proxy @is') sn invperm, provePerm2 (SNat @0) sn invperm) of + (Just Refl, Just Refl) -> + k invperm + (\ssh -> case provePermInverse perm invperm ssh of + Just eq -> eq + Nothing -> error $ "invertPermutation: did not generate inverse? perm = " ++ show perm + ++ " ; invperm = " ++ show invperm) + _ -> error $ "invertPermutation: did not generate permutation? perm = " ++ show perm + ++ " ; invperm = " ++ show invperm where genPerm :: HList SNat is -> (forall is'. HList SNat is' -> r) -> r genPerm perm = @@ -803,6 +814,45 @@ invertPermutation = \perm k -> toHList [] k = k HNil toHList (n : ns) k = toHList ns $ \l -> TypeNats.withSomeSNat n $ \sn -> k (HCons sn l) + lemElemCount :: (0 <= n, Compare n m ~ LT) => proxy n -> proxy m -> Elem n (Count 0 m) :~: True + lemElemCount _ _ = unsafeCoerce Refl + + lemCount :: (OrdCond (Compare i n) True False True ~ True) => proxy i -> proxy n -> Count i n :~: i : Count (i + 1) n + lemCount _ _ = unsafeCoerce Refl + + lemElem :: Elem x ys ~ True => proxy x -> proxy' (y : ys) -> Elem x (y : ys) :~: True + lemElem _ _ = unsafeCoerce Refl + + provePerm1 :: Proxy isTop -> SNat (Rank isTop) -> HList SNat is' + -> Maybe (AllElem' is' (Count 0 (Rank isTop)) :~: True) + provePerm1 _ _ HNil = Just (Refl) + provePerm1 p rtop@SNat (HCons sn@SNat perm) + | Just Refl <- provePerm1 p rtop perm + = case (cmpNat (SNat @0) sn, cmpNat sn rtop) of + (LTI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl + (EQI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl + _ -> Nothing + | otherwise + = Nothing + + provePerm2 :: SNat i -> SNat n -> HList SNat is' -> Maybe (AllElem' (Count i n) is' :~: True) + provePerm2 = \i@(SNat :: SNat i) n@SNat perm -> + case cmpNat i n of + EQI -> Just Refl + LTI | Refl <- lemCount i n + , Just Refl <- provePerm2 (SNat @(i + 1)) n perm + -> checkElem i perm + | otherwise -> Nothing + GTI -> error "unreachable" + where + checkElem :: SNat i -> HList SNat is' -> Maybe (Elem i is' :~: True) + checkElem _ HNil = Nothing + checkElem i@SNat (HCons k@SNat perm :: HList SNat is') = + case sameNat i k of + Just Refl -> Just Refl + Nothing | Just Refl <- checkElem i perm, Refl <- lemElem i (Proxy @is') -> Just Refl + | otherwise -> Nothing + provePermInverse :: HList SNat is -> HList SNat is' -> StaticShX sh -> Maybe (Permute is' (Permute is sh) :~: sh) provePermInverse perm perminv ssh = geqStaticShX (ssxPermute perminv (ssxPermute perm ssh)) ssh -- cgit v1.2.3-70-g09d2