{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
-- | Type-indexed permutations of dimension lists ('Perm'), runtime checks
-- that a 'Perm' really is a permutation, and functions applying a
-- permutation to shape-like lists.
module Data.Array.Mixed.Permutation where

import Data.Coerce (coerce)
import Data.Functor.Const
import Data.List (sort)
import Data.Maybe (fromMaybe)
import Data.Proxy
import Data.Type.Bool
import Data.Type.Equality
import GHC.TypeError
import Numeric.Natural

import Data.Array.Mixed.Shape
import Data.Array.Mixed.Types
import Data.SNat.Peano


-- * Permutations

-- | A "backward" permutation of a dimension list. The operation on the
-- dimension list is most similar to 'Data.Vector.backpermute'; see 'Permute'
-- for code that implements this.
--
-- The type index @list@ records the entries of the permutation as a
-- type-level list of Peano naturals; each 'PCons' carries the singleton
-- for its entry.
data Perm list where
  PNil :: Perm '[]
  PCons :: SNat a -> Perm l -> Perm (a : l)
infixr 5 `PCons`
deriving instance Show (Perm list)
deriving instance Eq (Perm list)

-- | The number of entries in the permutation, as a singleton natural.
permRank :: Perm list -> SNat (Rank list)
permRank PNil = SZ
permRank (_ `PCons` l) = SS (permRank l)

-- | Convert a value-level list of indices to a 'Perm'. The resulting
-- type-level list is not statically known, so the 'Perm' is passed to a
-- continuation.
--
-- No permutation check is performed here (see 'permCheckPermutation').
-- How an individual 'Int' becomes an 'SNat' (including what happens for
-- negative entries) is decided by 'withSomeSNat'', which is defined
-- elsewhere.
permFromList :: [Int] -> (forall list. Perm list -> r) -> r
permFromList [] k = k PNil
permFromList (x : xs) k = withSomeSNat' x $ \sn -> permFromList xs $ \list -> k (sn `PCons` list)

-- | The entries of the permutation as value-level naturals, in order.
permToList :: Perm list -> [Natural]
permToList PNil = mempty
permToList (x `PCons` l) = fromSNat x : permToList l

-- | Like 'permToList', but with the entries converted to 'Int'.
permToList' :: Perm list -> [Int]
permToList' = map fromIntegral . permToList

-- | When called as @permCheckPermutation p k@, if @p@ is a permutation of
-- @[0 .. 'length' ('permToList' p) - 1]@, @Just k@ is returned. If it isn't,
-- then @Nothing@ is returned.
--
-- In the 'Just' case, @k@ is evaluated with the 'IsPermutation' constraint
-- for @list@ in scope. The check runs in two directions:
--
-- * every entry of @p@ lies in @[0, n)@ (@provePerm1@), and
-- * every index in @[0, n)@ occurs in @p@ (@provePerm2@),
--
-- where @n@ is the length of @p@. Together these are exactly the two halves
-- of 'IsPermutation'.
permCheckPermutation :: forall r list. Perm list -> (IsPermutation list => r) -> Maybe r
permCheckPermutation = \p k ->
  let n = permRank p
  in case (provePerm1 (Proxy @list) n p, provePerm2 SZ n p) of
       (Just Refl, Just Refl) -> Just k
       _ -> Nothing
  where
    -- The three lemmas below are asserted, not proven: 'unsafeCoerceRefl'
    -- (defined elsewhere) presumably manufactures 'Refl' without any check.
    -- GHC cannot reduce 'Elem' / 'Count' applications on abstract type
    -- variables, because it cannot decide which closed-family equation
    -- matches. Each lemma is sound only under its stated constraints, so
    -- callers must establish those constraints first.

    -- If 0 <= n < m, then n occurs in Count Z m = [0, 1, .., m-1].
    lemElemCount :: (Z <= n, n < m) => proxy n -> proxy m -> Elem n (Count Z m) :~: True
    lemElemCount _ _ = unsafeCoerceRefl

    -- If i < n, the first equation of 'Count' (Count n n = '[]) cannot
    -- apply, so Count unfolds one step.
    lemCount :: i < n => proxy i -> proxy n -> Count i n :~: i : Count (S i) n
    lemCount _ _ = unsafeCoerceRefl

    -- If x is in ys, it is also in (y : ys), whether or not x equals y.
    lemElem :: Elem x ys ~ True => proxy x -> proxy' (y : ys) -> Elem x (y : ys) :~: True
    lemElem _ _ = unsafeCoerceRefl

    -- Check that every entry of the permutation satisfies 0 <= entry < rtop.
    -- On success, each entry is proven to be an element of
    -- Count Z (Rank isTop). The tail is checked first, so its proof is in
    -- scope when the head's proof is combined with it.
    provePerm1 :: Proxy isTop -> SNat (Rank isTop) -> Perm is' -> Maybe (AllElem' is' (Count Z (Rank isTop)) :~: True)
    provePerm1 _ _ PNil = Just (Refl)
    provePerm1 p rtop (PCons sn perm)
      | Just Refl <- provePerm1 p rtop perm
      = case (snatCompare SZ sn, snatCompare sn rtop) of
          (SLT, SLT) | Refl <- lemElemCount sn rtop -> Just Refl
          (SEQ, SLT) | Refl <- lemElemCount sn rtop -> Just Refl
          _ -> Nothing
      | otherwise = Nothing

    -- Check that every index in [i, n) occurs somewhere in the permutation,
    -- counting upward from i. The only caller above starts at SZ and only
    -- recurses in the SLT branch, so i never exceeds n; hence the SGT case
    -- is unreachable.
    provePerm2 :: SNat i -> SNat n -> Perm is' -> Maybe (AllElem' (Count i n) is' :~: True)
    provePerm2 = \i n perm ->
      case snatCompare i n of
        SEQ -> Just Refl
        SLT | Refl <- lemCount i n
            , Just Refl <- provePerm2 (SS i) n perm
            -> checkElem i perm
            | otherwise -> Nothing
        SGT -> error "unreachable"
      where
        -- Linear search for i in the permutation, returning a membership proof.
        checkElem :: SNat i -> Perm is' -> Maybe (Elem i is' :~: True)
        checkElem _ PNil = Nothing
        checkElem i (PCons k perm :: Perm is') =
          case testEquality i k of
            Just Refl -> Just Refl
            Nothing | Just Refl <- checkElem i perm, Refl <- lemElem i (Proxy @is') -> Just Refl
                    | otherwise -> Nothing

-- | Utility class for generating permutations from type class information.
class KnownPerm l where
  -- | The permutation whose entries are the type-level naturals in @l@,
  -- built from the 'KnownNat' evidence of each entry.
  makePerm :: Perm l

instance KnownPerm '[] where
  makePerm = PNil

instance (KnownNat n, KnownPerm l) => KnownPerm (n : l) where
  makePerm = knownNat `PCons` makePerm

-- | Untyped permutations for ranked arrays: a plain list of indices.
type PermR = [Int]


-- ** Applying permutations

-- | Type-level membership test: 'True iff @x@ occurs in @l@.
type family Elem x l where
  Elem x '[] = 'False
  Elem x (x : _) = 'True
  Elem x (_ : ys) = Elem x ys

-- | 'True iff every element of @as@ occurs in @bs@.
type family AllElem' as bs where
  AllElem' '[] bs = 'True
  AllElem' (a : as) bs = Elem a bs && AllElem' as bs

-- | Constraint form of 'AllElem''. It emits a custom type error naming
-- both lists when the check fails.
type AllElem as bs = Assert (AllElem' as bs)
  (TypeError (Text "The elements of " :<>: ShowType as :<>: Text " are not all in " :<>: ShowType bs))

-- | The list @[i, S i, .., n - 1]@ of Peano naturals. Intended for
-- @i <= n@; for @i > n@ the second equation would never reach the first.
type family Count i n where
  Count n n = '[]
  Count i n = i : Count (S i) n

-- | @as@ is a permutation of @[0 .. Rank as - 1]@. Every entry is in range,
-- and every index in range occurs.
type IsPermutation as = (AllElem as (Count Z (Rank as)), AllElem (Count Z (Rank as)) as)

-- | The element at (0-based, Peano) position @i@ of @sh@. It is stuck when
-- out of range.
type family Index i sh where
  Index Z (n : sh) = n
  Index (S i) (_ : sh) = Index i sh

-- | Backpermute: the result's @k@-th element is @Index (is !! k) sh@.
type family Permute is sh where
  Permute '[] sh = '[]
  Permute (i : is) sh = Index i sh : Permute is sh

-- | Apply @is@ to the first @length is@ elements of @sh@ and keep the
-- remaining elements unchanged.
type PermutePrefix is sh = Permute is (TakeLen is sh) ++ DropLen is sh

-- | The first @length ref@ elements of @l@.
type family TakeLen ref l where
  TakeLen '[] l = '[]
  TakeLen (_ : ref) (x : xs) = x : TakeLen ref xs

-- | @l@ without its first @length ref@ elements.
type family DropLen ref l where
  DropLen '[] l = l
  DropLen (_ : ref) (_ : xs) = DropLen ref xs

-- | Value-level 'TakeLen'. Calls 'error' if the shape is shorter than the
-- permutation.
listxTakeLen :: forall f is sh. Perm is -> ListX sh f -> ListX (TakeLen is sh) f
listxTakeLen PNil _ = ZX
listxTakeLen (_ `PCons` is) (n ::% sh) = n ::% listxTakeLen is sh
-- Reached only when the permutation has more entries than the shape.
listxTakeLen (_ `PCons` _) ZX = error "Permutation longer than shape"

-- | Value-level 'DropLen'. Calls 'error' if the shape is shorter than the
-- permutation.
listxDropLen :: forall f is sh. Perm is -> ListX sh f -> ListX (DropLen is sh) f
listxDropLen PNil sh = sh
listxDropLen (_ `PCons` is) (_ ::% sh) = listxDropLen is sh
-- Reached only when the permutation has more entries than the shape.
listxDropLen (_ `PCons` _) ZX = error "Permutation longer than shape"

-- | Value-level 'Permute': the @k@-th result element is the element of the
-- input at the @k@-th permutation entry.
listxPermute :: forall f is sh. Perm is -> ListX sh f -> ListX (Permute is sh) f
listxPermute PNil _ = ZX
listxPermute (i `PCons` (is :: Perm is')) (sh :: ListX sh f) =
  listxIndex (Proxy @is') (Proxy @sh) i sh ::% listxPermute is sh

-- | Value-level 'Index'. Calls 'error' when indexing past the end. The two
-- 'Proxy' arguments are never inspected; they are only threaded through the
-- recursion.
listxIndex :: forall f is shT i sh. Proxy is -> Proxy shT -> SNat i -> ListX sh f -> f (Index i sh)
listxIndex _ _ SZ (n ::% _) = n
listxIndex p pT (SS i) (_ ::% sh) = listxIndex p pT i sh
listxIndex _ _ _ ZX = error "Index into empty shape"

-- | Value-level 'PermutePrefix': permute the first @length is@ elements and
-- append the untouched rest.
listxPermutePrefix :: forall f is sh. Perm is -> ListX sh f -> ListX (PermutePrefix is sh) f
listxPermutePrefix perm sh = listxAppend (listxPermute perm (listxTakeLen perm sh)) (listxDropLen perm sh)

-- The following wrappers reuse the 'ListX' functions via 'coerce'. This
-- requires each wrapper's argument and result types to be representationally
-- equal to the corresponding 'ListX' instantiation; those types are defined
-- elsewhere.

ixxPermutePrefix :: forall i is sh. Perm is -> IxX sh i -> IxX (PermutePrefix is sh) i
ixxPermutePrefix = coerce (listxPermutePrefix @(Const i))

ssxTakeLen :: forall is sh. Perm is -> StaticShX sh -> StaticShX (TakeLen is sh)
ssxTakeLen = coerce (listxTakeLen @(SMayNat () SNat))

ssxDropLen :: Perm is -> StaticShX sh -> StaticShX (DropLen is sh)
ssxDropLen = coerce (listxDropLen @(SMayNat () SNat))

ssxPermute :: Perm is -> StaticShX sh -> StaticShX (Permute is sh)
ssxPermute = coerce (listxPermute @(SMayNat () SNat))

ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> SMayNat () SNat (Index i sh)
ssxIndex p1 p2 i = coerce (listxIndex @(SMayNat () SNat) p1 p2 i)

ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh)
ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat () SNat))

shxPermutePrefix :: Perm is -> IShX sh -> IShX (PermutePrefix is sh)
shxPermutePrefix = coerce (listxPermutePrefix @(SMayNat Int SNat))


-- * Operations on permutations

-- | Compute the inverse of a permutation and pass it to the continuation
-- together with:
--
-- * 'IsPermutation' evidence for the inverse, which is checked at runtime
--   before the continuation is entered, and
-- * a function that, for any static shape of matching rank, yields a proof
--   that permuting by @is@ and then by the inverse is the identity.
--
-- The inverse is computed at value level by sorting @(entry, position)@
-- pairs. Both proofs are then checked dynamically. A failed check calls
-- 'error', because it indicates that the input was not a valid permutation
-- or that there is an internal bug. The second check only runs when the
-- proof function is actually applied.
permInverse :: Perm is
            -> (forall is'.
                     IsPermutation is'
                  => Perm is'
                  -> (forall sh. Rank sh ~ Rank is => StaticShX sh -> Permute is' (Permute is sh) :~: sh)
                  -> r)
            -> r
permInverse = \perm k ->
  genPerm perm $ \(invperm :: Perm is') ->
    fromMaybe
      (error $ "permInverse: did not generate permutation? perm = " ++ show perm
                 ++ " ; invperm = " ++ show invperm)
      (permCheckPermutation invperm
        (k invperm
           (\ssh -> case provePermInverse perm invperm ssh of
                      Just eq -> eq
                      Nothing -> error $ "permInverse: did not generate inverse? perm = " ++ show perm
                                           ++ " ; invperm = " ++ show invperm)))
  where
    -- Sorting (entry, position) pairs by entry lists, for each value v, the
    -- position holding v. That list is the inverse permutation.
    genPerm :: Perm is -> (forall is'. Perm is' -> r) -> r
    genPerm perm =
      let permList = permToList' perm
      in toHList $ map snd (sort (zip permList [0..]))
      where
        toHList :: [Natural] -> (forall is'. Perm is' -> r) -> r
        toHList [] k = k PNil
        toHList (n : ns) k = toHList ns $ \l -> withSomeSNat n $ \sn -> k (PCons sn l)

    -- Decide the round-trip equation by actually permuting the static shape
    -- twice and comparing it with the original.
    provePermInverse :: Perm is -> Perm is' -> StaticShX sh
                     -> Maybe (Permute is' (Permute is sh) :~: sh)
    provePermInverse perm perminv ssh = testEquality (ssxPermute perminv (ssxPermute perm ssh)) ssh

-- | Increment every element of a type-level list of Peano naturals.
type family MapSucc is where
  MapSucc '[] = '[]
  MapSucc (i : is) = S i : MapSucc is

-- | Extend a permutation by one leading dimension that stays in place: prepend
-- entry 0 and shift every existing entry up by one.
permShift1 :: Perm l -> Perm (Z : MapSucc l)
permShift1 = (SZ `PCons`) . permMapSucc
  where
    permMapSucc :: Perm l -> Perm (MapSucc l)
    permMapSucc PNil = PNil
    permMapSucc (i `PCons` ns) = SS i `PCons` permMapSucc ns


-- * Lemmas

-- | Permuting yields one element per permutation entry.
lemRankPermute :: Proxy sh -> Perm is -> Rank (Permute is sh) :~: Rank is
lemRankPermute _ PNil = Refl
lemRankPermute p (_ `PCons` is) | Refl <- lemRankPermute p is = Refl

-- | Dropping @length is@ elements from @sh@ leaves @Rank sh - Rank is@ of
-- them, provided @is@ is not longer than @sh@.
lemRankDropLen :: forall is sh. (Rank is <= Rank sh)
               => StaticShX sh -> Perm is -> Rank (DropLen is sh) :~: Rank sh - Rank is
lemRankDropLen ZKX PNil = Refl
lemRankDropLen (_ :!% sh) (_ `PCons` is) | Refl <- lemRankDropLen sh is = Refl
lemRankDropLen (_ :!% _) PNil = Refl