diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-05-30 11:58:40 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-05-30 11:58:40 +0200 |
commit | a65306ba5d80891b20ac86fa3a3242f9497751e6 (patch) | |
tree | 834af370556a46bbeca807a92c31bef098b47a89 /src/Data/Array/Mixed/Permutation.hs | |
parent | d8e2fcf4ea979fe272db48fc2889f4c2636c50d7 (diff) |
Refactor Mixed (modules, regular function names)
Diffstat (limited to 'src/Data/Array/Mixed/Permutation.hs')
-rw-r--r-- | src/Data/Array/Mixed/Permutation.hs | 252 |
1 files changed, 252 insertions, 0 deletions
diff --git a/src/Data/Array/Mixed/Permutation.hs b/src/Data/Array/Mixed/Permutation.hs new file mode 100644 index 0000000..2710018 --- /dev/null +++ b/src/Data/Array/Mixed/Permutation.hs @@ -0,0 +1,252 @@ +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Mixed.Permutation where + +import Data.Coerce (coerce) +import Data.Functor.Const +import Data.List (sort) +import Data.Proxy +import Data.Type.Bool +import Data.Type.Equality +import Data.Type.Ord +import GHC.TypeError +import GHC.TypeLits +import qualified GHC.TypeNats as TN + +import Data.Array.Mixed.Shape +import Data.Array.Mixed.Types + + +-- * Permutations + +-- | A "backward" permutation of a dimension list. The operation on the +-- dimension list is most similar to 'Data.Vector.backpermute'; see 'Permute' +-- for code that implements this. +data Perm list where + PNil :: Perm '[] + PCons :: SNat a -> Perm l -> Perm (a : l) +infixr 5 `PCons` +deriving instance Show (Perm list) +deriving instance Eq (Perm list) + +permLengthSNat :: Perm list -> SNat (Rank list) +permLengthSNat PNil = SNat +permLengthSNat (_ `PCons` l) | SNat <- permLengthSNat l = SNat + +permFromList :: [Int] -> (forall list. Perm list -> r) -> r +permFromList [] k = k PNil +permFromList (x : xs) k = withSomeSNat (fromIntegral x) $ \case + Just sn -> permFromList xs $ \list -> k (sn `PCons` list) + Nothing -> error $ "Data.Array.Mixed.permFromList: negative number in list: " ++ show x + +permToList :: Perm list -> [Natural] +permToList PNil = mempty +permToList (x `PCons` l) = TN.fromSNat x : permToList l + +permToList' :: Perm list -> [Int] +permToList' = map fromIntegral . permToList + + +-- ** Applying permutations + +type family Elem x l where + Elem x '[] = 'False + Elem x (x : _) = 'True + Elem x (_ : ys) = Elem x ys + +type family AllElem' as bs where + AllElem' '[] bs = 'True + AllElem' (a : as) bs = Elem a bs && AllElem' as bs + +type AllElem as bs = Assert (AllElem' as bs) + (TypeError (Text "The elements of " :<>: ShowType as :<>: Text " are not all in " :<>: ShowType bs)) + +type family Count i n where + Count n n = '[] + Count i n = i : Count (i + 1) n + +type IsPermutation as = (AllElem as (Count 0 (Rank as)), AllElem (Count 0 (Rank as)) as) + +type family Index i sh where + Index 0 (n : sh) = n + Index i (_ : sh) = Index (i - 1) sh + +type family Permute is sh where + Permute '[] sh = '[] + Permute (i : is) sh = Index i sh : Permute is sh + +type PermutePrefix is sh = Permute is (TakeLen is sh) ++ DropLen is sh + +type family TakeLen ref l where + TakeLen '[] l = '[] + TakeLen (_ : ref) (x : xs) = x : TakeLen ref xs + +type family DropLen ref l where + DropLen '[] l = l + DropLen (_ : ref) (_ : xs) = DropLen ref xs + +listxTakeLen :: forall f is sh. Perm is -> ListX sh f -> ListX (TakeLen is sh) f +listxTakeLen PNil _ = ZX +listxTakeLen (_ `PCons` is) (n ::% sh) = n ::% listxTakeLen is sh +listxTakeLen (_ `PCons` _) ZX = error "IsPermutation longer than shape" + +listxDropLen :: forall f is sh. Perm is -> ListX sh f -> ListX (DropLen is sh) f +listxDropLen PNil sh = sh +listxDropLen (_ `PCons` is) (_ ::% sh) = listxDropLen is sh +listxDropLen (_ `PCons` _) ZX = error "IsPermutation longer than shape" + +listxPermute :: forall f is sh. Perm is -> ListX sh f -> ListX (Permute is sh) f +listxPermute PNil _ = ZX +listxPermute (i `PCons` (is :: Perm is')) (sh :: ListX sh f) = + listxIndex (Proxy @is') (Proxy @sh) i sh (listxPermute is sh) + +listxIndex :: forall f is shT i sh. Proxy is -> Proxy shT -> SNat i -> ListX sh f -> ListX (Permute is shT) f -> ListX (Index i sh : Permute is shT) f +listxIndex _ _ SZ (n ::% _) rest = n ::% rest +listxIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::% (sh :: ListX sh' f)) rest + | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') + = listxIndex p pT i sh rest +listxIndex _ _ _ ZX _ = error "Index into empty shape" + +listxPermutePrefix :: forall f is sh. Perm is -> ListX sh f -> ListX (PermutePrefix is sh) f +listxPermutePrefix perm sh = listxAppend (listxPermute perm (listxTakeLen perm sh)) (listxDropLen perm sh) + +ixxPermutePrefix :: forall i is sh. Perm is -> IxX sh i -> IxX (PermutePrefix is sh) i +ixxPermutePrefix = coerce (listxPermutePrefix @(Const i)) + +ssxTakeLen :: forall is sh. Perm is -> StaticShX sh -> StaticShX (TakeLen is sh) +ssxTakeLen = coerce (listxTakeLen @(SMayNat () SNat)) + +ssxDropLen :: Perm is -> StaticShX sh -> StaticShX (DropLen is sh) +ssxDropLen = coerce (listxDropLen @(SMayNat () SNat)) + +ssxPermute :: Perm is -> StaticShX sh -> StaticShX (Permute is sh) +ssxPermute = coerce (listxPermute @(SMayNat () SNat)) + +ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> StaticShX (Permute is shT) -> StaticShX (Index i sh : Permute is shT) +ssxIndex p1 p2 = coerce (listxIndex @(SMayNat () SNat) p1 p2) + +ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh) +ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat () SNat)) + +shxPermutePrefix :: Perm is -> IShX sh -> IShX (PermutePrefix is sh) +shxPermutePrefix = coerce (listxPermutePrefix @(SMayNat Int SNat)) + + +-- * Operations on permutations + +-- TODO: test this thing more properly +permInverse :: Perm is + -> (forall is'. + IsPermutation is' + => Perm is' + -> (forall sh. Rank sh ~ Rank is => StaticShX sh -> Permute is' (Permute is sh) :~: sh) + -> r) + -> r +permInverse = \perm k -> + genPerm perm $ \(invperm :: Perm is') -> + let sn = permLengthSNat invperm + in case (provePerm1 (Proxy @is') sn invperm, provePerm2 (SNat @0) sn invperm) of + (Just Refl, Just Refl) -> + k invperm + (\ssh -> case provePermInverse perm invperm ssh of + Just eq -> eq + Nothing -> error $ "permInverse: did not generate inverse? perm = " ++ show perm + ++ " ; invperm = " ++ show invperm) + _ -> error $ "permInverse: did not generate permutation? perm = " ++ show perm + ++ " ; invperm = " ++ show invperm + where + genPerm :: Perm is -> (forall is'. Perm is' -> r) -> r + genPerm perm = + let permList = permToList' perm + in toHList $ map snd (sort (zip permList [0..])) + where + toHList :: [Natural] -> (forall is'. Perm is' -> r) -> r + toHList [] k = k PNil + toHList (n : ns) k = toHList ns $ \l -> TN.withSomeSNat n $ \sn -> k (PCons sn l) + + lemElemCount :: (0 <= n, Compare n m ~ LT) => proxy n -> proxy m -> Elem n (Count 0 m) :~: True + lemElemCount _ _ = unsafeCoerceRefl + + lemCount :: (OrdCond (Compare i n) True False True ~ True) => proxy i -> proxy n -> Count i n :~: i : Count (i + 1) n + lemCount _ _ = unsafeCoerceRefl + + lemElem :: Elem x ys ~ True => proxy x -> proxy' (y : ys) -> Elem x (y : ys) :~: True + lemElem _ _ = unsafeCoerceRefl + + provePerm1 :: Proxy isTop -> SNat (Rank isTop) -> Perm is' + -> Maybe (AllElem' is' (Count 0 (Rank isTop)) :~: True) + provePerm1 _ _ PNil = Just (Refl) + provePerm1 p rtop@SNat (PCons sn@SNat perm) + | Just Refl <- provePerm1 p rtop perm + = case (cmpNat (SNat @0) sn, cmpNat sn rtop) of + (LTI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl + (EQI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl + _ -> Nothing + | otherwise + = Nothing + + provePerm2 :: SNat i -> SNat n -> Perm is' + -> Maybe (AllElem' (Count i n) is' :~: True) + provePerm2 = \i@(SNat :: SNat i) n@SNat perm -> + case cmpNat i n of + EQI -> Just Refl + LTI | Refl <- lemCount i n + , Just Refl <- provePerm2 (SNat @(i + 1)) n perm + -> checkElem i perm + | otherwise -> Nothing + GTI -> error "unreachable" + where + checkElem :: SNat i -> Perm is' -> Maybe (Elem i is' :~: True) + checkElem _ PNil = Nothing + checkElem i@SNat (PCons k@SNat perm :: Perm is') = + case sameNat i k of + Just Refl -> Just Refl + Nothing | Just Refl <- checkElem i perm, Refl <- lemElem i (Proxy @is') -> Just Refl + | otherwise -> Nothing + + provePermInverse :: Perm is -> Perm is' -> StaticShX sh + -> Maybe (Permute is' (Permute is sh) :~: sh) + provePermInverse perm perminv ssh = + ssxGeq (ssxPermute perminv (ssxPermute perm ssh)) ssh + +type family MapSucc is where + MapSucc '[] = '[] + MapSucc (i : is) = i + 1 : MapSucc is + +permShift1 :: Perm l -> Perm (0 : MapSucc l) +permShift1 = (SNat @0 `PCons`) . permMapSucc + where + permMapSucc :: Perm l -> Perm (MapSucc l) + permMapSucc PNil = PNil + permMapSucc ((SNat :: SNat i) `PCons` ns) = SNat @(i + 1) `PCons` permMapSucc ns + + +-- * Lemmas + +lemRankPermute :: Proxy sh -> Perm is -> Rank (Permute is sh) :~: Rank is +lemRankPermute _ PNil = Refl +lemRankPermute p (_ `PCons` is) | Refl <- lemRankPermute p is = Refl + +lemRankDropLen :: forall is sh. (Rank is <= Rank sh) + => StaticShX sh -> Perm is -> Rank (DropLen is sh) :~: Rank sh - Rank is +lemRankDropLen ZKX PNil = Refl +lemRankDropLen (_ :!% sh) (_ `PCons` is) | Refl <- lemRankDropLen sh is = Refl +lemRankDropLen (_ :!% _) PNil = Refl +lemRankDropLen ZKX (_ `PCons` _) = error "1 <= 0" + +lemIndexSucc :: Proxy i -> Proxy a -> Proxy l + -> Index (i + 1) (a : l) :~: Index i l +lemIndexSucc _ _ _ = unsafeCoerceRefl |