aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Mixed/Permutation.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Mixed/Permutation.hs')
-rw-r--r--src/Data/Array/Mixed/Permutation.hs252
1 files changed, 252 insertions, 0 deletions
diff --git a/src/Data/Array/Mixed/Permutation.hs b/src/Data/Array/Mixed/Permutation.hs
new file mode 100644
index 0000000..2710018
--- /dev/null
+++ b/src/Data/Array/Mixed/Permutation.hs
@@ -0,0 +1,252 @@
+{-# LANGUAGE ConstraintKinds #-}
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE QuantifiedConstraints #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Data.Array.Mixed.Permutation where
+
+import Data.Coerce (coerce)
+import Data.Functor.Const
+import Data.List (sort)
+import Data.Proxy
+import Data.Type.Bool
+import Data.Type.Equality
+import Data.Type.Ord
+import GHC.TypeError
+import GHC.TypeLits
+import qualified GHC.TypeNats as TN
+
+import Data.Array.Mixed.Shape
+import Data.Array.Mixed.Types
+
+
+-- * Permutations
+
+-- | A "backward" permutation of a dimension list. The operation on the
+-- dimension list is most similar to 'Data.Vector.backpermute'; see 'Permute'
+-- for code that implements this.
+data Perm list where
+ PNil :: Perm '[]
+ PCons :: SNat a -> Perm l -> Perm (a : l)
+infixr 5 `PCons`
+deriving instance Show (Perm list)
+deriving instance Eq (Perm list)
+
+permLengthSNat :: Perm list -> SNat (Rank list)
+permLengthSNat PNil = SNat
+permLengthSNat (_ `PCons` l) | SNat <- permLengthSNat l = SNat
+
+permFromList :: [Int] -> (forall list. Perm list -> r) -> r
+permFromList [] k = k PNil
+permFromList (x : xs) k = withSomeSNat (fromIntegral x) $ \case
+ Just sn -> permFromList xs $ \list -> k (sn `PCons` list)
+ Nothing -> error $ "Data.Array.Mixed.permFromList: negative number in list: " ++ show x
+
+permToList :: Perm list -> [Natural]
+permToList PNil = mempty
+permToList (x `PCons` l) = TN.fromSNat x : permToList l
+
+permToList' :: Perm list -> [Int]
+permToList' = map fromIntegral . permToList
+
+
+-- ** Applying permutations
+
+type family Elem x l where
+ Elem x '[] = 'False
+ Elem x (x : _) = 'True
+ Elem x (_ : ys) = Elem x ys
+
+type family AllElem' as bs where
+ AllElem' '[] bs = 'True
+ AllElem' (a : as) bs = Elem a bs && AllElem' as bs
+
+type AllElem as bs = Assert (AllElem' as bs)
+ (TypeError (Text "The elements of " :<>: ShowType as :<>: Text " are not all in " :<>: ShowType bs))
+
+type family Count i n where
+ Count n n = '[]
+ Count i n = i : Count (i + 1) n
+
+type IsPermutation as = (AllElem as (Count 0 (Rank as)), AllElem (Count 0 (Rank as)) as)
+
+type family Index i sh where
+ Index 0 (n : sh) = n
+ Index i (_ : sh) = Index (i - 1) sh
+
+type family Permute is sh where
+ Permute '[] sh = '[]
+ Permute (i : is) sh = Index i sh : Permute is sh
+
+type PermutePrefix is sh = Permute is (TakeLen is sh) ++ DropLen is sh
+
+type family TakeLen ref l where
+ TakeLen '[] l = '[]
+ TakeLen (_ : ref) (x : xs) = x : TakeLen ref xs
+
+type family DropLen ref l where
+ DropLen '[] l = l
+ DropLen (_ : ref) (_ : xs) = DropLen ref xs
+
+listxTakeLen :: forall f is sh. Perm is -> ListX sh f -> ListX (TakeLen is sh) f
+listxTakeLen PNil _ = ZX
+listxTakeLen (_ `PCons` is) (n ::% sh) = n ::% listxTakeLen is sh
+listxTakeLen (_ `PCons` _) ZX = error "IsPermutation longer than shape"
+
+listxDropLen :: forall f is sh. Perm is -> ListX sh f -> ListX (DropLen is sh) f
+listxDropLen PNil sh = sh
+listxDropLen (_ `PCons` is) (_ ::% sh) = listxDropLen is sh
+listxDropLen (_ `PCons` _) ZX = error "IsPermutation longer than shape"
+
+listxPermute :: forall f is sh. Perm is -> ListX sh f -> ListX (Permute is sh) f
+listxPermute PNil _ = ZX
+listxPermute (i `PCons` (is :: Perm is')) (sh :: ListX sh f) =
+ listxIndex (Proxy @is') (Proxy @sh) i sh (listxPermute is sh)
+
+listxIndex :: forall f is shT i sh. Proxy is -> Proxy shT -> SNat i -> ListX sh f -> ListX (Permute is shT) f -> ListX (Index i sh : Permute is shT) f
+listxIndex _ _ SZ (n ::% _) rest = n ::% rest
+listxIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::% (sh :: ListX sh' f)) rest
+ | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
+ = listxIndex p pT i sh rest
+listxIndex _ _ _ ZX _ = error "Index into empty shape"
+
+listxPermutePrefix :: forall f is sh. Perm is -> ListX sh f -> ListX (PermutePrefix is sh) f
+listxPermutePrefix perm sh = listxAppend (listxPermute perm (listxTakeLen perm sh)) (listxDropLen perm sh)
+
+ixxPermutePrefix :: forall i is sh. Perm is -> IxX sh i -> IxX (PermutePrefix is sh) i
+ixxPermutePrefix = coerce (listxPermutePrefix @(Const i))
+
+ssxTakeLen :: forall is sh. Perm is -> StaticShX sh -> StaticShX (TakeLen is sh)
+ssxTakeLen = coerce (listxTakeLen @(SMayNat () SNat))
+
+ssxDropLen :: Perm is -> StaticShX sh -> StaticShX (DropLen is sh)
+ssxDropLen = coerce (listxDropLen @(SMayNat () SNat))
+
+ssxPermute :: Perm is -> StaticShX sh -> StaticShX (Permute is sh)
+ssxPermute = coerce (listxPermute @(SMayNat () SNat))
+
+ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> StaticShX (Permute is shT) -> StaticShX (Index i sh : Permute is shT)
+ssxIndex p1 p2 = coerce (listxIndex @(SMayNat () SNat) p1 p2)
+
+ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh)
+ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat () SNat))
+
+shxPermutePrefix :: Perm is -> IShX sh -> IShX (PermutePrefix is sh)
+shxPermutePrefix = coerce (listxPermutePrefix @(SMayNat Int SNat))
+
+
+-- * Operations on permutations
+
+-- TODO: test this thing more properly
+permInverse :: Perm is
+ -> (forall is'.
+ IsPermutation is'
+ => Perm is'
+ -> (forall sh. Rank sh ~ Rank is => StaticShX sh -> Permute is' (Permute is sh) :~: sh)
+ -> r)
+ -> r
+permInverse = \perm k ->
+ genPerm perm $ \(invperm :: Perm is') ->
+ let sn = permLengthSNat invperm
+ in case (provePerm1 (Proxy @is') sn invperm, provePerm2 (SNat @0) sn invperm) of
+ (Just Refl, Just Refl) ->
+ k invperm
+ (\ssh -> case provePermInverse perm invperm ssh of
+ Just eq -> eq
+ Nothing -> error $ "permInverse: did not generate inverse? perm = " ++ show perm
+ ++ " ; invperm = " ++ show invperm)
+ _ -> error $ "permInverse: did not generate permutation? perm = " ++ show perm
+ ++ " ; invperm = " ++ show invperm
+ where
+ genPerm :: Perm is -> (forall is'. Perm is' -> r) -> r
+ genPerm perm =
+ let permList = permToList' perm
+ in toHList $ map snd (sort (zip permList [0..]))
+ where
+ toHList :: [Natural] -> (forall is'. Perm is' -> r) -> r
+ toHList [] k = k PNil
+ toHList (n : ns) k = toHList ns $ \l -> TN.withSomeSNat n $ \sn -> k (PCons sn l)
+
+ lemElemCount :: (0 <= n, Compare n m ~ LT) => proxy n -> proxy m -> Elem n (Count 0 m) :~: True
+ lemElemCount _ _ = unsafeCoerceRefl
+
+ lemCount :: (OrdCond (Compare i n) True False True ~ True) => proxy i -> proxy n -> Count i n :~: i : Count (i + 1) n
+ lemCount _ _ = unsafeCoerceRefl
+
+ lemElem :: Elem x ys ~ True => proxy x -> proxy' (y : ys) -> Elem x (y : ys) :~: True
+ lemElem _ _ = unsafeCoerceRefl
+
+ provePerm1 :: Proxy isTop -> SNat (Rank isTop) -> Perm is'
+ -> Maybe (AllElem' is' (Count 0 (Rank isTop)) :~: True)
+ provePerm1 _ _ PNil = Just (Refl)
+ provePerm1 p rtop@SNat (PCons sn@SNat perm)
+ | Just Refl <- provePerm1 p rtop perm
+ = case (cmpNat (SNat @0) sn, cmpNat sn rtop) of
+ (LTI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl
+ (EQI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl
+ _ -> Nothing
+ | otherwise
+ = Nothing
+
+ provePerm2 :: SNat i -> SNat n -> Perm is'
+ -> Maybe (AllElem' (Count i n) is' :~: True)
+ provePerm2 = \i@(SNat :: SNat i) n@SNat perm ->
+ case cmpNat i n of
+ EQI -> Just Refl
+ LTI | Refl <- lemCount i n
+ , Just Refl <- provePerm2 (SNat @(i + 1)) n perm
+ -> checkElem i perm
+ | otherwise -> Nothing
+ GTI -> error "unreachable"
+ where
+ checkElem :: SNat i -> Perm is' -> Maybe (Elem i is' :~: True)
+ checkElem _ PNil = Nothing
+ checkElem i@SNat (PCons k@SNat perm :: Perm is') =
+ case sameNat i k of
+ Just Refl -> Just Refl
+ Nothing | Just Refl <- checkElem i perm, Refl <- lemElem i (Proxy @is') -> Just Refl
+ | otherwise -> Nothing
+
+ provePermInverse :: Perm is -> Perm is' -> StaticShX sh
+ -> Maybe (Permute is' (Permute is sh) :~: sh)
+ provePermInverse perm perminv ssh =
+ ssxGeq (ssxPermute perminv (ssxPermute perm ssh)) ssh
+
+type family MapSucc is where
+ MapSucc '[] = '[]
+ MapSucc (i : is) = i + 1 : MapSucc is
+
+permShift1 :: Perm l -> Perm (0 : MapSucc l)
+permShift1 = (SNat @0 `PCons`) . permMapSucc
+ where
+ permMapSucc :: Perm l -> Perm (MapSucc l)
+ permMapSucc PNil = PNil
+ permMapSucc ((SNat :: SNat i) `PCons` ns) = SNat @(i + 1) `PCons` permMapSucc ns
+
+
+-- * Lemmas
+
+lemRankPermute :: Proxy sh -> Perm is -> Rank (Permute is sh) :~: Rank is
+lemRankPermute _ PNil = Refl
+lemRankPermute p (_ `PCons` is) | Refl <- lemRankPermute p is = Refl
+
+lemRankDropLen :: forall is sh. (Rank is <= Rank sh)
+ => StaticShX sh -> Perm is -> Rank (DropLen is sh) :~: Rank sh - Rank is
+lemRankDropLen ZKX PNil = Refl
+lemRankDropLen (_ :!% sh) (_ `PCons` is) | Refl <- lemRankDropLen sh is = Refl
+lemRankDropLen (_ :!% _) PNil = Refl
+lemRankDropLen ZKX (_ `PCons` _) = error "1 <= 0"
+
+lemIndexSucc :: Proxy i -> Proxy a -> Proxy l
+ -> Index (i + 1) (a : l) :~: Index i l
+lemIndexSucc _ _ _ = unsafeCoerceRefl