{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-- | Type-level permutations of dimension lists ('Perm'), and functions that
-- apply them to shapes.
module Data.Array.Mixed.Permutation where

import Data.Coerce (coerce)
import Data.Functor.Const
import Data.List (sort)
import Data.Maybe (fromMaybe)
import Data.Proxy
import Data.Type.Bool
import Data.Type.Equality
import Data.Type.Ord
import GHC.TypeError
import GHC.TypeLits
import GHC.TypeNats qualified as TN

import Data.Array.Mixed.Shape
import Data.Array.Mixed.Types


-- * Permutations

-- | A "backward" permutation of a dimension list. The operation on the
-- dimension list is most similar to 'Data.Vector.backpermute'; see 'Permute'
-- for code that implements this.
data Perm list where
  -- | The empty permutation.
  PNil :: Perm '[]
  -- | Prepend one (type-level) index to a permutation.
  PCons :: SNat a -> Perm l -> Perm (a : l)
infixr 5 `PCons`
deriving instance Show (Perm list)
deriving instance Eq (Perm list)

-- | The number of entries in a permutation, as a singleton for the rank of
-- its type-level index list.
permLengthSNat :: Perm list -> SNat (Rank list)
permLengthSNat PNil = SNat
-- Matching 'SNat' on the recursive result brings the tail's 'KnownNat'
-- evidence into scope; the result 'SNat' is built from that evidence.
permLengthSNat (_ `PCons` l) | SNat <- permLengthSNat l = SNat

-- | Convert an untyped list of indices to a 'Perm' and pass it to the
-- continuation. The type-level index list is existentially quantified.
--
-- The list is /not/ checked to be a permutation (see 'permCheckPermutation').
-- Every entry must be non-negative; a negative entry results in an 'error'.
permFromList :: [Int] -> (forall list. Perm list -> r) -> r
permFromList [] k = k PNil
permFromList (x : xs) k = withSomeSNat (fromIntegral x) $ \case
  Just sn -> permFromList xs $ \list -> k (sn `PCons` list)
  -- 'withSomeSNat' from "GHC.TypeLits" yields 'Nothing' for negative input.
  Nothing -> error $ "Data.Array.Mixed.Permutation.permFromList: negative number in list: " ++ show x

-- | The entries of a permutation, as term-level naturals.
permToList :: Perm list -> [Natural]
permToList PNil = []
permToList (x `PCons` l) = TN.fromSNat x : permToList l

-- | Like 'permToList', but with 'Int' entries.
permToList' :: Perm list -> [Int]
permToList' = map fromIntegral . permToList

-- | When called as @permCheckPermutation p k@, if @p@ is a permutation of
-- @[0 .. 'length' ('permToList' p) - 1]@, @Just k@ is returned. If it isn't,
-- then @Nothing@ is returned.
--
-- The check proves both inclusions between the entries of @p@ and
-- @[0 .. n - 1]@, where @n@ is the length of @p@. Because @p@ has exactly @n@
-- entries, this means @p@ is a permutation.
permCheckPermutation :: forall r list. Perm list -> (IsPermutation list => r) -> Maybe r
permCheckPermutation = \p k ->
  let n = permLengthSNat p
  in case (provePerm1 (Proxy @list) n p, provePerm2 (SNat @0) n p) of
       (Just Refl, Just Refl) -> Just k
       _ -> Nothing
  where
    -- The three lemmas below are trusted: they are "proved" with
    -- 'unsafeCoerceRefl' because GHC cannot reduce the closed type families
    -- 'Elem' and 'Count' on abstract arguments. Their preconditions are
    -- constraints, and the call sites discharge them with evidence obtained
    -- from runtime comparisons ('cmpNat', 'checkElem').

    -- If @0 <= n < m@, then @n@ occurs in @[0 .. m - 1]@.
    lemElemCount :: (0 <= n, Compare n m ~ LT)
                 => proxy n -> proxy m -> Elem n (Count 0 m) :~: True
    lemElemCount _ _ = unsafeCoerceRefl

    -- If @i /= n@, 'Count' unfolds by one step. The 'OrdCond' is 'True' for
    -- 'LT' and 'GT' and 'False' for 'EQ'.
    lemCount :: (OrdCond (Compare i n) True False True ~ True)
             => proxy i -> proxy n -> Count i n :~: i : Count (i + 1) n
    lemCount _ _ = unsafeCoerceRefl

    -- An element of the tail is an element of the whole list.
    lemElem :: Elem x ys ~ True => proxy x -> proxy' (y : ys) -> Elem x (y : ys) :~: True
    lemElem _ _ = unsafeCoerceRefl

    -- Every entry of the permutation lies in @[0 .. Rank isTop - 1]@.
    provePerm1 :: Proxy isTop -> SNat (Rank isTop) -> Perm is'
               -> Maybe (AllElem' is' (Count 0 (Rank isTop)) :~: True)
    provePerm1 _ _ PNil = Just Refl
    provePerm1 p rtop@SNat (PCons sn@SNat perm)
      | Just Refl <- provePerm1 p rtop perm
      = case (cmpNat (SNat @0) sn, cmpNat sn rtop) of
          -- 0 <= sn < rtop, split into the 0 < sn and 0 == sn cases.
          (LTI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl
          (EQI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl
          _ -> Nothing
      | otherwise
      = Nothing

    -- Every index in @[i .. n - 1]@ occurs in the permutation.
    provePerm2 :: SNat i -> SNat n -> Perm is' -> Maybe (AllElem' (Count i n) is' :~: True)
    provePerm2 = \i@(SNat :: SNat i) n@SNat perm ->
      case cmpNat i n of
        EQI -> Just Refl
        LTI | Refl <- lemCount i n
            , Just Refl <- provePerm2 (SNat @(i + 1)) n perm
            -> checkElem i perm
            | otherwise -> Nothing
        -- The search starts at i = 0 and only recurses while i < n, so i
        -- never exceeds n.
        GTI -> error "unreachable"
      where
        -- Linear search for @i@ among the entries of the permutation.
        checkElem :: SNat i -> Perm is' -> Maybe (Elem i is' :~: True)
        checkElem _ PNil = Nothing
        checkElem i@SNat (PCons k@SNat perm :: Perm is') =
          case sameNat i k of
            Just Refl -> Just Refl
            Nothing | Just Refl <- checkElem i perm, Refl <- lemElem i (Proxy @is') -> Just Refl
                    | otherwise -> Nothing

-- | Utility class for generating permutations from type class information.
class KnownPerm l where
  -- | The permutation whose entries are the type-level naturals in @l@.
  makePerm :: Perm l
-- The empty type-level list gives the empty permutation.
instance KnownPerm '[] where makePerm = PNil
-- Each entry's singleton is obtained from its 'KnownNat' instance.
instance (KnownNat n, KnownPerm l) => KnownPerm (n : l) where makePerm = natSing `PCons` makePerm

-- | Untyped permutations for ranked arrays.
type PermR = [Int]


-- ** Applying permutations

-- | Type-level membership test: @'True@ if @x@ occurs in @l@.
type family Elem x l where
  Elem x '[] = 'False
  Elem x (x : _) = 'True
  Elem x (_ : ys) = Elem x ys

-- | @'True@ if every element of @as@ occurs in @bs@.
type family AllElem' as bs where
  AllElem' '[] bs = 'True
  AllElem' (a : as) bs = Elem a bs && AllElem' as bs

-- | Constraint form of 'AllElem''. When it does not hold, it reports a
-- custom type error.
type AllElem as bs = Assert (AllElem' as bs)
  (TypeError (Text "The elements of " :<>: ShowType as :<>: Text " are not all in " :<>: ShowType bs))

-- | The list @[i, i + 1, ..., n - 1]@. Intended for @i <= n@.
type family Count i n where
  Count n n = '[]
  Count i n = i : Count (i + 1) n

-- | @as@ and @[0 .. Rank as - 1]@ each contain all elements of the other.
-- Assuming 'Rank' is the list length (cf. 'permLengthSNat'), this says @as@
-- is a permutation of @[0 .. Rank as - 1]@.
type IsPermutation as = (AllElem as (Count 0 (Rank as)), AllElem (Count 0 (Rank as)) as)

-- | The element at (zero-based) position @i@ of @sh@.
type family Index i sh where
  Index 0 (n : sh) = n
  Index i (_ : sh) = Index (i - 1) sh

-- | Backpermute @sh@ by @is@. Position @k@ of the result holds the element of
-- @sh@ at position @is !! k@.
type family Permute is sh where
  Permute '[] sh = '[]
  Permute (i : is) sh = Index i sh : Permute is sh

-- | Permute the first @length is@ elements of @sh@ by @is@, and append the
-- remaining elements unchanged.
type PermutePrefix is sh = Permute is (TakeLen is sh) ++ DropLen is sh

-- | The prefix of @l@ that has as many elements as @ref@.
type family TakeLen ref l where
  TakeLen '[] l = '[]
  TakeLen (_ : ref) (x : xs) = x : TakeLen ref xs

-- | @l@ with as many leading elements dropped as @ref@ has.
type family DropLen ref l where
  DropLen '[] l = l
  DropLen (_ : ref) (_ : xs) = DropLen ref xs

-- | Term-level 'TakeLen': the first @length is@ elements of the list.
-- Calls 'error' if the list is shorter than the permutation.
listxTakeLen :: forall f is sh. Perm is -> ListX sh f -> ListX (TakeLen is sh) f
listxTakeLen PNil _ = ZX
listxTakeLen (_ `PCons` is) (n ::% sh) = n ::% listxTakeLen is sh
listxTakeLen (_ `PCons` _) ZX = error "IsPermutation longer than shape"

-- | Term-level 'DropLen': the list without its first @length is@ elements.
-- Calls 'error' if the list is shorter than the permutation.
listxDropLen :: forall f is sh. Perm is -> ListX sh f -> ListX (DropLen is sh) f
listxDropLen PNil sh = sh
listxDropLen (_ `PCons` is) (_ ::% sh) = listxDropLen is sh
listxDropLen (_ `PCons` _) ZX = error "IsPermutation longer than shape"

-- | Term-level 'Permute': backpermute the list by the permutation, one
-- 'listxIndex' lookup per entry.
listxPermute :: forall f is sh. Perm is -> ListX sh f -> ListX (Permute is sh) f
listxPermute PNil _ = ZX
listxPermute (i `PCons` (is :: Perm is')) (sh :: ListX sh f) =
  listxIndex (Proxy @is') (Proxy @sh) i sh ::% listxPermute is sh

-- | Term-level 'Index': the element at position @i@ of the list.
-- The two 'Proxy' arguments are only passed along in the recursion and do
-- not affect the result. Calls 'error' if the list is too short.
listxIndex :: forall f is shT i sh. Proxy is -> Proxy shT -> SNat i -> ListX sh f -> f (Index i sh)
listxIndex _ _ SZ (n ::% _) = n
-- Step case: 'lemIndexSucc' rewrites @Index (i' + 1) (n : sh')@ to
-- @Index i' sh'@ so that the recursive call has the right type.
listxIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::% (sh :: ListX sh' f))
  | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
  = listxIndex p pT i sh
listxIndex _ _ _ ZX = error "Index into empty shape"

-- | Term-level 'PermutePrefix': permute the first @length is@ elements and
-- append the rest unchanged.
listxPermutePrefix :: forall f is sh. Perm is -> ListX sh f -> ListX (PermutePrefix is sh) f
listxPermutePrefix perm sh = listxAppend (listxPermute perm (listxTakeLen perm sh)) (listxDropLen perm sh)

-- The functions below reuse the 'ListX' implementations through 'coerce',
-- instantiating the element functor so the types are representationally
-- equal to the 'IxX' / 'StaticShX' / 'IShX' versions.

-- | 'listxPermutePrefix' for indices.
ixxPermutePrefix :: forall i is sh. Perm is -> IxX sh i -> IxX (PermutePrefix is sh) i
ixxPermutePrefix = coerce (listxPermutePrefix @(Const i))

-- | 'listxTakeLen' for static shapes.
ssxTakeLen :: forall is sh. Perm is -> StaticShX sh -> StaticShX (TakeLen is sh)
ssxTakeLen = coerce (listxTakeLen @(SMayNat () SNat))

-- | 'listxDropLen' for static shapes.
ssxDropLen :: Perm is -> StaticShX sh -> StaticShX (DropLen is sh)
ssxDropLen = coerce (listxDropLen @(SMayNat () SNat))

-- | 'listxPermute' for static shapes.
ssxPermute :: Perm is -> StaticShX sh -> StaticShX (Permute is sh)
ssxPermute = coerce (listxPermute @(SMayNat () SNat))

-- | 'listxIndex' for static shapes.
ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> SMayNat () SNat (Index i sh)
ssxIndex p1 p2 = coerce (listxIndex @(SMayNat () SNat) p1 p2)

-- | 'listxPermutePrefix' for static shapes.
ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh)
ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat () SNat))

-- | 'listxPermutePrefix' for shapes with 'Int' sizes.
shxPermutePrefix :: Perm is -> IShX sh -> IShX (PermutePrefix is sh)
shxPermutePrefix = coerce (listxPermutePrefix @(SMayNat Int SNat))


-- * Operations on permutations

-- | Compute the inverse of a permutation and pass it to the continuation,
-- together with:
--
-- * 'IsPermutation' evidence for the inverse, and
-- * a function that, for a static shape @sh@ of matching rank, proves that
--   permuting by @is@ and then by the inverse gives back @sh@.
--
-- Calls 'error' if the generated inverse fails 'permCheckPermutation'. The
-- round-trip proof is checked at runtime when that function is applied, and
-- also calls 'error' on failure (e.g. if the input was not a permutation).
permInverse :: Perm is
            -> (forall is'.
                     IsPermutation is'
                  => Perm is'
                  -> (forall sh. Rank sh ~ Rank is => StaticShX sh -> Permute is' (Permute is sh) :~: sh)
                  -> r)
            -> r
permInverse = \perm k ->
  genPerm perm $ \(invperm :: Perm is') ->
    fromMaybe
      (error $ "permInverse: did not generate permutation? perm = " ++ show perm
               ++ " ; invperm = " ++ show invperm)
      (permCheckPermutation invperm
        (k invperm
           (\ssh -> case provePermInverse perm invperm ssh of
                      Just eq -> eq
                      Nothing -> error $ "permInverse: did not generate inverse? perm = " ++ show perm
                                         ++ " ; invperm = " ++ show invperm)))
  where
    -- Build the candidate inverse. Sort the (entry, position) pairs of
    -- @perm@ and keep only the positions. For a permutation @perm@, index @j@
    -- of the result then holds the position at which @perm@ contains @j@.
    genPerm :: Perm is -> (forall is'. Perm is' -> r) -> r
    genPerm perm =
      let permList = permToList' perm
      in toHList $ map snd (sort (zip permList [0..]))
      where
        -- Turn a list of naturals into a 'Perm' with existential indices.
        toHList :: [Natural] -> (forall is'. Perm is' -> r) -> r
        toHList [] k = k PNil
        toHList (n : ns) k = toHList ns $ \l -> TN.withSomeSNat n $ \sn -> k (PCons sn l)

    -- Runtime check that permuting @ssh@ by @perm@ and then by @perminv@
    -- returns @ssh@. NOTE(review): 'ssxGeq' (defined elsewhere) is assumed
    -- to compare two static shapes and return equality evidence when they
    -- agree.
    provePermInverse :: Perm is -> Perm is' -> StaticShX sh
                     -> Maybe (Permute is' (Permute is sh) :~: sh)
    provePermInverse perm perminv ssh =
      ssxGeq (ssxPermute perminv (ssxPermute perm ssh)) ssh

-- | Add 1 to every element of a type-level list.
type family MapSucc is where
  MapSucc '[] = '[]
  MapSucc (i : is) = i + 1 : MapSucc is

-- | Extend a permutation with a new leading dimension that stays in place:
-- prepend 0 and increment every existing entry.
permShift1 :: Perm l -> Perm (0 : MapSucc l)
permShift1 = (SNat @0 `PCons`) . permMapSucc
  where
    -- Increment every entry, at both the type and the term level.
    permMapSucc :: Perm l -> Perm (MapSucc l)
    permMapSucc PNil = PNil
    permMapSucc ((SNat :: SNat i) `PCons` ns) = SNat @(i + 1) `PCons` permMapSucc ns


-- * Lemmas

-- | Permuting by @is@ yields a list with one element per entry of @is@.
-- Proved by induction on the permutation.
lemRankPermute :: Proxy sh -> Perm is -> Rank (Permute is sh) :~: Rank is
lemRankPermute _ PNil = Refl
lemRankPermute p (_ `PCons` is) | Refl <- lemRankPermute p is = Refl

-- | Dropping as many elements as @is@ has leaves @Rank sh - Rank is@
-- elements, given @Rank is <= Rank sh@. Proved by induction on both lists.
-- The last equation (empty shape, non-empty permutation) contradicts the
-- constraint, which would require @1 <= 0@.
lemRankDropLen :: forall is sh. (Rank is <= Rank sh)
               => StaticShX sh -> Perm is -> Rank (DropLen is sh) :~: Rank sh - Rank is
lemRankDropLen ZKX PNil = Refl
lemRankDropLen (_ :!% sh) (_ `PCons` is) | Refl <- lemRankDropLen sh is = Refl
lemRankDropLen (_ :!% _) PNil = Refl
lemRankDropLen ZKX (_ `PCons` _) = error "1 <= 0"

-- | Trusted lemma, proved with 'unsafeCoerceRefl': this mirrors the second
-- equation of 'Index'. GHC presumably cannot reduce the left-hand side
-- itself because it cannot show that @i + 1@ is apart from @0@ in the first
-- equation.
lemIndexSucc :: Proxy i -> Proxy a -> Proxy l
             -> Index (i + 1) (a : l) :~: Index i l
lemIndexSucc _ _ _ = unsafeCoerceRefl