aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array')
-rw-r--r--src/Data/Array/Mixed.hs64
1 files changed, 57 insertions, 7 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
index 065756d..24a8482 100644
--- a/src/Data/Array/Mixed.hs
+++ b/src/Data/Array/Mixed.hs
@@ -39,6 +39,7 @@ import Data.Monoid (Sum(..))
import Data.Proxy
import Data.Type.Bool
import Data.Type.Equality
+import Data.Type.Ord
import qualified Data.Vector.Storable as VS
import Foreign.Storable (Storable)
import GHC.Generics (Generic)
@@ -715,6 +716,10 @@ foldHList :: Monoid m => (forall a. f a -> m) -> HList f list -> m
foldHList _ HNil = mempty
foldHList f (x `HCons` l) = f x <> foldHList f l
+snatLengthHList :: HList f list -> SNat (Rank list)
+snatLengthHList HNil = SNat
+snatLengthHList (_ `HCons` l) | SNat <- snatLengthHList l = SNat
+
type family TakeLen ref l where
TakeLen '[] l = '[]
TakeLen (_ : ref) (x : xs) = x : TakeLen ref xs
@@ -782,17 +787,23 @@ shPermutePrefix = coerce (listxPermutePrefix @(SMayNat Int SNat))
-- TODO: test this thing more properly
invertPermutation :: HList SNat is
-> (forall is'.
- HList SNat is'
+ Permutation is'
+ => HList SNat is'
-> (forall sh. Rank sh ~ Rank is => StaticShX sh -> Permute is' (Permute is sh) :~: sh)
-> r)
-> r
invertPermutation = \perm k ->
- genPerm perm $ \invperm ->
- k invperm
- (\ssh -> case provePermInverse perm invperm ssh of
- Just eq -> eq
- Nothing -> error $ "invertPermutation: did not generate inverse? perm = " ++ show perm
- ++ " ; invperm = " ++ show invperm)
+ genPerm perm $ \(invperm :: HList SNat is') ->
+ let sn = snatLengthHList invperm
+ in case (provePerm1 (Proxy @is') sn invperm, provePerm2 (SNat @0) sn invperm) of
+ (Just Refl, Just Refl) ->
+ k invperm
+ (\ssh -> case provePermInverse perm invperm ssh of
+ Just eq -> eq
+ Nothing -> error $ "invertPermutation: did not generate inverse? perm = " ++ show perm
+ ++ " ; invperm = " ++ show invperm)
+ _ -> error $ "invertPermutation: did not generate permutation? perm = " ++ show perm
+ ++ " ; invperm = " ++ show invperm
where
genPerm :: HList SNat is -> (forall is'. HList SNat is' -> r) -> r
genPerm perm =
@@ -803,6 +814,45 @@ invertPermutation = \perm k ->
toHList [] k = k HNil
toHList (n : ns) k = toHList ns $ \l -> TypeNats.withSomeSNat n $ \sn -> k (HCons sn l)
+ lemElemCount :: (0 <= n, Compare n m ~ LT) => proxy n -> proxy m -> Elem n (Count 0 m) :~: True
+ lemElemCount _ _ = unsafeCoerce Refl
+
+ lemCount :: (OrdCond (Compare i n) True False True ~ True) => proxy i -> proxy n -> Count i n :~: i : Count (i + 1) n
+ lemCount _ _ = unsafeCoerce Refl
+
+ lemElem :: Elem x ys ~ True => proxy x -> proxy' (y : ys) -> Elem x (y : ys) :~: True
+ lemElem _ _ = unsafeCoerce Refl
+
+ provePerm1 :: Proxy isTop -> SNat (Rank isTop) -> HList SNat is'
+ -> Maybe (AllElem' is' (Count 0 (Rank isTop)) :~: True)
+ provePerm1 _ _ HNil = Just (Refl)
+ provePerm1 p rtop@SNat (HCons sn@SNat perm)
+ | Just Refl <- provePerm1 p rtop perm
+ = case (cmpNat (SNat @0) sn, cmpNat sn rtop) of
+ (LTI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl
+ (EQI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl
+ _ -> Nothing
+ | otherwise
+ = Nothing
+
+ provePerm2 :: SNat i -> SNat n -> HList SNat is' -> Maybe (AllElem' (Count i n) is' :~: True)
+ provePerm2 = \i@(SNat :: SNat i) n@SNat perm ->
+ case cmpNat i n of
+ EQI -> Just Refl
+ LTI | Refl <- lemCount i n
+ , Just Refl <- provePerm2 (SNat @(i + 1)) n perm
+ -> checkElem i perm
+ | otherwise -> Nothing
+ GTI -> error "unreachable"
+ where
+ checkElem :: SNat i -> HList SNat is' -> Maybe (Elem i is' :~: True)
+ checkElem _ HNil = Nothing
+ checkElem i@SNat (HCons k@SNat perm :: HList SNat is') =
+ case sameNat i k of
+ Just Refl -> Just Refl
+ Nothing | Just Refl <- checkElem i perm, Refl <- lemElem i (Proxy @is') -> Just Refl
+ | otherwise -> Nothing
+
provePermInverse :: HList SNat is -> HList SNat is' -> StaticShX sh -> Maybe (Permute is' (Permute is sh) :~: sh)
provePermInverse perm perminv ssh = geqStaticShX (ssxPermute perminv (ssxPermute perm ssh)) ssh