diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-05-28 13:17:35 +0200 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-05-28 13:17:35 +0200 | 
| commit | 8e7f9981f9b8ca17bb3c46c942116eaa4f7cb0d3 (patch) | |
| tree | eb8aac717b73b9e7d91f1c7da9383c7414099ac3 | |
| parent | 95544b35615f6714fbef914cb6f2935a088e4d06 (diff) | |
invertPermutation: Provide Permutation evidence
| -rw-r--r-- | src/Data/Array/Mixed.hs | 64 | 
1 files changed, 57 insertions, 7 deletions
| diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 065756d..24a8482 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -39,6 +39,7 @@ import Data.Monoid (Sum(..))  import Data.Proxy  import Data.Type.Bool  import Data.Type.Equality +import Data.Type.Ord  import qualified Data.Vector.Storable as VS  import Foreign.Storable (Storable)  import GHC.Generics (Generic) @@ -715,6 +716,10 @@ foldHList :: Monoid m => (forall a. f a -> m) -> HList f list -> m  foldHList _ HNil = mempty  foldHList f (x `HCons` l) = f x <> foldHList f l +snatLengthHList :: HList f list -> SNat (Rank list) +snatLengthHList HNil = SNat +snatLengthHList (_ `HCons` l) | SNat <- snatLengthHList l = SNat +  type family TakeLen ref l where    TakeLen '[] l = '[]    TakeLen (_ : ref) (x : xs) = x : TakeLen ref xs @@ -782,17 +787,23 @@ shPermutePrefix = coerce (listxPermutePrefix @(SMayNat Int SNat))  -- TODO: test this thing more properly  invertPermutation :: HList SNat is                    -> (forall is'. -                           HList SNat is' +                           Permutation is' +                        => HList SNat is'                          -> (forall sh. Rank sh ~ Rank is => StaticShX sh -> Permute is' (Permute is sh) :~: sh)                          -> r)                    -> r  invertPermutation = \perm k -> -  genPerm perm $ \invperm -> -    k invperm -      (\ssh -> case provePermInverse perm invperm ssh of -                 Just eq -> eq -                 Nothing -> error $ "invertPermutation: did not generate inverse? perm = " ++ show perm -                                    ++ " ; invperm = " ++ show invperm) +  genPerm perm $ \(invperm :: HList SNat is') -> +    let sn = snatLengthHList invperm +    in case (provePerm1 (Proxy @is') sn invperm, provePerm2 (SNat @0) sn invperm) of +         (Just Refl, Just Refl) -> +           k invperm +             (\ssh -> case provePermInverse perm invperm ssh of +                        Just eq -> eq +                        Nothing -> error $ "invertPermutation: did not generate inverse? perm = " ++ show perm +                                           ++ " ; invperm = " ++ show invperm) +         _ -> error $ "invertPermutation: did not generate permutation? perm = " ++ show perm +                      ++ " ; invperm = " ++ show invperm    where      genPerm :: HList SNat is -> (forall is'. HList SNat is' -> r) -> r      genPerm perm = @@ -803,6 +814,45 @@ invertPermutation = \perm k ->          toHList [] k = k HNil          toHList (n : ns) k = toHList ns $ \l -> TypeNats.withSomeSNat n $ \sn -> k (HCons sn l) +    lemElemCount :: (0 <= n, Compare n m ~ LT) => proxy n -> proxy m -> Elem n (Count 0 m) :~: True +    lemElemCount _ _ = unsafeCoerce Refl + +    lemCount :: (OrdCond (Compare i n) True False True ~ True) => proxy i -> proxy n -> Count i n :~: i : Count (i + 1) n +    lemCount _ _ = unsafeCoerce Refl + +    lemElem :: Elem x ys ~ True => proxy x -> proxy' (y : ys) -> Elem x (y : ys) :~: True +    lemElem _ _ = unsafeCoerce Refl + +    provePerm1 :: Proxy isTop -> SNat (Rank isTop) -> HList SNat is' +               -> Maybe (AllElem' is' (Count 0 (Rank isTop)) :~: True) +    provePerm1 _ _ HNil = Just (Refl) +    provePerm1 p rtop@SNat (HCons sn@SNat perm) +      | Just Refl <- provePerm1 p rtop perm +      = case (cmpNat (SNat @0) sn, cmpNat sn rtop) of +          (LTI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl +          (EQI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl +          _ -> Nothing +      | otherwise +      = Nothing + +    provePerm2 :: SNat i -> SNat n -> HList SNat is' -> Maybe (AllElem' (Count i n) is' :~: True) +    provePerm2 = \i@(SNat :: SNat i) n@SNat perm -> +      case cmpNat i n of +        EQI -> Just Refl +        LTI | Refl <- lemCount i n +            , Just Refl <- provePerm2 (SNat @(i + 1)) n perm +            -> checkElem i perm +            | otherwise -> Nothing +        GTI -> error "unreachable" +      where +        checkElem :: SNat i -> HList SNat is' -> Maybe (Elem i is' :~: True) +        checkElem _ HNil = Nothing +        checkElem i@SNat (HCons k@SNat perm :: HList SNat is') = +          case sameNat i k of +            Just Refl -> Just Refl +            Nothing | Just Refl <- checkElem i perm, Refl <- lemElem i (Proxy @is') -> Just Refl +                    | otherwise -> Nothing +      provePermInverse :: HList SNat is -> HList SNat is' -> StaticShX sh -> Maybe (Permute is' (Permute is sh) :~: sh)      provePermInverse perm perminv ssh = geqStaticShX (ssxPermute perminv (ssxPermute perm ssh)) ssh | 
