diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-06-03 18:07:32 +0200 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-06-03 18:07:32 +0200 | 
| commit | 20626932cca57d0787e8464dcfd88944eb6336ec (patch) | |
| tree | 142bcc797a1effd64937f3204b821714bf702863 /src/Data/Array | |
| parent | 0cde8fb6cf80f3606ece7b47981ff017eb90d00c (diff) | |
Separate permCheckPermutation from permInverse
Diffstat (limited to 'src/Data/Array')
| -rw-r--r-- | src/Data/Array/Mixed/Permutation.hs | 112 | 
1 files changed, 62 insertions, 50 deletions
| diff --git a/src/Data/Array/Mixed/Permutation.hs b/src/Data/Array/Mixed/Permutation.hs index e1e5c44..7c77fc4 100644 --- a/src/Data/Array/Mixed/Permutation.hs +++ b/src/Data/Array/Mixed/Permutation.hs @@ -20,6 +20,7 @@ module Data.Array.Mixed.Permutation where  import Data.Coerce (coerce)  import Data.Functor.Const  import Data.List (sort) +import Data.Maybe (fromMaybe)  import Data.Proxy  import Data.Type.Bool  import Data.Type.Equality @@ -61,6 +62,58 @@ permToList (x `PCons` l) = TN.fromSNat x : permToList l  permToList' :: Perm list -> [Int]  permToList' = map fromIntegral . permToList +-- | When called as @permCheckPermutation p k@, if @p@ is a permutation of +-- @[0 .. 'length' ('permToList' p) - 1]@, @Just k@ is returned. If it isn't, +-- then @Nothing@ is returned. +permCheckPermutation :: forall list r. Perm list -> (IsPermutation list => r) -> Maybe r +permCheckPermutation = \p k -> +  let n = permLengthSNat p +  in case (provePerm1 (Proxy @list) n p, provePerm2 (SNat @0) n p) of +       (Just Refl, Just Refl) -> Just k +       _ -> Nothing +  where +    lemElemCount :: (0 <= n, Compare n m ~ LT) +                 => proxy n -> proxy m -> Elem n (Count 0 m) :~: True +    lemElemCount _ _ = unsafeCoerceRefl + +    lemCount :: (OrdCond (Compare i n) True False True ~ True) +             => proxy i -> proxy n -> Count i n :~: i : Count (i + 1) n +    lemCount _ _ = unsafeCoerceRefl + +    lemElem :: Elem x ys ~ True => proxy x -> proxy' (y : ys) -> Elem x (y : ys) :~: True +    lemElem _ _ = unsafeCoerceRefl + +    provePerm1 :: Proxy isTop -> SNat (Rank isTop) -> Perm is' +               -> Maybe (AllElem' is' (Count 0 (Rank isTop)) :~: True) +    provePerm1 _ _ PNil = Just (Refl) +    provePerm1 p rtop@SNat (PCons sn@SNat perm) +      | Just Refl <- provePerm1 p rtop perm +      = case (cmpNat (SNat @0) sn, cmpNat sn rtop) of +          (LTI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl +          (EQI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl +          _ -> Nothing +      | otherwise +      = Nothing + +    provePerm2 :: SNat i -> SNat n -> Perm is' +               -> Maybe (AllElem' (Count i n) is' :~: True) +    provePerm2 = \i@(SNat :: SNat i) n@SNat perm -> +      case cmpNat i n of +        EQI -> Just Refl +        LTI | Refl <- lemCount i n +            , Just Refl <- provePerm2 (SNat @(i + 1)) n perm +            -> checkElem i perm +            | otherwise -> Nothing +        GTI -> error "unreachable" +      where +        checkElem :: SNat i -> Perm is' -> Maybe (Elem i is' :~: True) +        checkElem _ PNil = Nothing +        checkElem i@SNat (PCons k@SNat perm :: Perm is') = +          case sameNat i k of +            Just Refl -> Just Refl +            Nothing | Just Refl <- checkElem i perm, Refl <- lemElem i (Proxy @is') -> Just Refl +                    | otherwise -> Nothing +  -- | Utility class for generating permutations from type class information.  class KnownPerm l where makePerm :: Perm l  instance KnownPerm '[] where makePerm = PNil @@ -167,16 +220,15 @@ permInverse :: Perm is              -> r  permInverse = \perm k ->    genPerm perm $ \(invperm :: Perm is') -> -    let sn = permLengthSNat invperm -    in case (provePerm1 (Proxy @is') sn invperm, provePerm2 (SNat @0) sn invperm) of -         (Just Refl, Just Refl) -> -           k invperm -             (\ssh -> case provePermInverse perm invperm ssh of -                        Just eq -> eq -                        Nothing -> error $ "permInverse: did not generate inverse? perm = " ++ show perm -                                           ++ " ; invperm = " ++ show invperm) -         _ -> error $ "permInverse: did not generate permutation? perm = " ++ show perm -                      ++ " ; invperm = " ++ show invperm +    fromMaybe +      (error $ "permInverse: did not generate permutation? perm = " ++ show perm +               ++ " ; invperm = " ++ show invperm) +      (permCheckPermutation invperm +        (k invperm +           (\ssh -> case provePermInverse perm invperm ssh of +                      Just eq -> eq +                      Nothing -> error $ "permInverse: did not generate inverse? perm = " ++ show perm +                                             ++ " ; invperm = " ++ show invperm)))    where      genPerm :: Perm is -> (forall is'. Perm is' -> r) -> r      genPerm perm = @@ -187,46 +239,6 @@ permInverse = \perm k ->          toHList [] k = k PNil          toHList (n : ns) k = toHList ns $ \l -> TN.withSomeSNat n $ \sn -> k (PCons sn l) -    lemElemCount :: (0 <= n, Compare n m ~ LT) => proxy n -> proxy m -> Elem n (Count 0 m) :~: True -    lemElemCount _ _ = unsafeCoerceRefl - -    lemCount :: (OrdCond (Compare i n) True False True ~ True) => proxy i -> proxy n -> Count i n :~: i : Count (i + 1) n -    lemCount _ _ = unsafeCoerceRefl - -    lemElem :: Elem x ys ~ True => proxy x -> proxy' (y : ys) -> Elem x (y : ys) :~: True -    lemElem _ _ = unsafeCoerceRefl - -    provePerm1 :: Proxy isTop -> SNat (Rank isTop) -> Perm is' -               -> Maybe (AllElem' is' (Count 0 (Rank isTop)) :~: True) -    provePerm1 _ _ PNil = Just (Refl) -    provePerm1 p rtop@SNat (PCons sn@SNat perm) -      | Just Refl <- provePerm1 p rtop perm -      = case (cmpNat (SNat @0) sn, cmpNat sn rtop) of -          (LTI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl -          (EQI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl -          _ -> Nothing -      | otherwise -      = Nothing - -    provePerm2 :: SNat i -> SNat n -> Perm is' -               -> Maybe (AllElem' (Count i n) is' :~: True) -    provePerm2 = \i@(SNat :: SNat i) n@SNat perm -> -      case cmpNat i n of -        EQI -> Just Refl -        LTI | Refl <- lemCount i n -            , Just Refl <- provePerm2 (SNat @(i + 1)) n perm -            -> checkElem i perm -            | otherwise -> Nothing -        GTI -> error "unreachable" -      where -        checkElem :: SNat i -> Perm is' -> Maybe (Elem i is' :~: True) -        checkElem _ PNil = Nothing -        checkElem i@SNat (PCons k@SNat perm :: Perm is') = -          case sameNat i k of -            Just Refl -> Just Refl -            Nothing | Just Refl <- checkElem i perm, Refl <- lemElem i (Proxy @is') -> Just Refl -                    | otherwise -> Nothing -      provePermInverse :: Perm is -> Perm is' -> StaticShX sh                       -> Maybe (Permute is' (Permute is sh) :~: sh)      provePermInverse perm perminv ssh = | 
