aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Mixed/Permutation.hs
blob: 1df0ec752cb273920869b7d6c38ffb5d9d675bd2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Mixed.Permutation where

import Data.Coerce (coerce)
import Data.Functor.Const
import Data.List (sort)
import Data.Proxy
import Data.Type.Bool
import Data.Type.Equality
import Data.Type.Ord
import GHC.TypeError
import GHC.TypeLits
import qualified GHC.TypeNats as TN

import Data.Array.Mixed.Shape
import Data.Array.Mixed.Types


-- * Permutations

-- | A "backward" permutation of a dimension list: position @k@ of the
-- permuted list is filled with the element found at position @is !! k@ of the
-- original list. The operation on the dimension list is most similar to
-- 'Data.Vector.backpermute'; see 'Permute' for the type-level implementation
-- and 'listxPermute' for the value-level one.
data Perm ixs where
  -- | The empty permutation.
  PNil :: Perm '[]
  -- | Prepend one source index (as a singleton) to a permutation.
  PCons :: SNat n -> Perm ns -> Perm (n : ns)
infixr 5 `PCons`

deriving instance Eq (Perm ixs)
deriving instance Show (Perm ixs)

-- | The number of entries of a permutation, as a type-level singleton.
permLengthSNat :: Perm list -> SNat (Rank list)
permLengthSNat = \case
  PNil -> SNat
  -- Bring the tail's length into scope as a KnownNat, then rebuild.
  _ `PCons` rest -> case permLengthSNat rest of SNat -> SNat

-- | Convert a list of indices into a 'Perm', passing it (with its
-- existentially quantified type-level index list) to the continuation.
--
-- Calls 'error' if any index is negative. The list is NOT checked to actually
-- be a permutation; only the sign of each entry is validated.
permFromList :: [Int] -> (forall list. Perm list -> r) -> r
permFromList [] k = k PNil
permFromList (x : xs) k = withSomeSNat (fromIntegral x) $ \case
  Just sn -> permFromList xs $ \list -> k (sn `PCons` list)
  -- The prefix names the module this function is defined in.
  Nothing -> error $ "Data.Array.Mixed.Permutation.permFromList: negative number in list: " ++ show x

-- | The source indices of a permutation, in order, as term-level naturals.
permToList :: Perm list -> [Natural]
permToList = go
  where
    go :: Perm l -> [Natural]
    go PNil = []
    go (sn `PCons` rest) = TN.fromSNat sn : go rest

-- | Like 'permToList', but with every index converted to 'Int'.
permToList' :: Perm list -> [Int]
permToList' perm = fromIntegral <$> permToList perm


-- ** Applying permutations

-- | Type-level membership test: reduces to 'True when @x@ occurs in @xs@ and
-- to 'False when the list is exhausted without finding it.
type family Elem x xs where
  Elem _ '[] = 'False
  Elem y (y : _) = 'True
  Elem y (_ : rest) = Elem y rest

-- | Type-level subset test: 'True when every element of @xs@ occurs in @ys@.
type family AllElem' xs ys where
  AllElem' '[] _ = 'True
  AllElem' (x : rest) ys = Elem x ys && AllElem' rest ys

-- | Constraint form of 'AllElem'': satisfied when every element of @xs@ is in
-- @ys@, and reported with a custom type error otherwise.
type AllElem xs ys =
  Assert (AllElem' xs ys)
    (TypeError (Text "The elements of " :<>: ShowType xs :<>: Text " are not all in " :<>: ShowType ys))

-- | The type-level list @[lo, lo + 1, .., hi - 1]@. Reduction never reaches
-- the base case when @lo > hi@, so only use it with @lo <= hi@.
type family Count lo hi where
  Count hi hi = '[]
  Count lo hi = lo : Count (lo + 1) hi

-- | @is@ is a permutation of @[0 .. Rank is - 1]@: every entry lies in that
-- range, and every index in that range occurs as an entry.
type IsPermutation is =
  ( AllElem is (Count 0 (Rank is))
  , AllElem (Count 0 (Rank is)) is
  )

-- | The element at (zero-based) position @i@ of the type-level list @sh@.
-- Stuck when the index is out of bounds.
type family Index i sh where
  Index 0 (x : _) = x
  Index k (_ : xs) = Index (k - 1) xs

-- | Backward permutation at the type level: for each index @k@ of @is@, in
-- order, the result contains the element at position @k@ of @sh@.
type family Permute is sh where
  Permute '[] _ = '[]
  Permute (k : ks) sh = Index k sh : Permute ks sh

-- | Permute only the first @Rank is@ elements of @sh@ and keep the rest of the
-- list unchanged after them.
type PermutePrefix is sh = (Permute is (TakeLen is sh)) ++ (DropLen is sh)

-- | The first @length ref@ elements of @l@. Stuck when @l@ is shorter than
-- @ref@.
type family TakeLen ref l where
  TakeLen '[] _ = '[]
  TakeLen (_ : rs) (x : xs) = x : TakeLen rs xs

-- | @l@ with its first @length ref@ elements removed. Stuck when @l@ is
-- shorter than @ref@.
type family DropLen ref l where
  DropLen '[] xs = xs
  DropLen (_ : rs) (_ : xs) = DropLen rs xs

-- | Keep as many leading elements of the list as the permutation has entries.
-- Calls 'error' if the list is shorter than the permutation.
listxTakeLen :: forall f is sh. Perm is -> ListX sh f -> ListX (TakeLen is sh) f
listxTakeLen perm shape = case perm of
  PNil -> ZX
  _ `PCons` rest -> case shape of
    x ::% xs -> x ::% listxTakeLen rest xs
    ZX -> error "IsPermutation longer than shape"

-- | Drop as many leading elements of the list as the permutation has entries.
-- Calls 'error' if the list is shorter than the permutation.
listxDropLen :: forall f is sh. Perm is -> ListX sh f -> ListX (DropLen is sh) f
listxDropLen perm shape = case perm of
  PNil -> shape
  _ `PCons` rest -> case shape of
    _ ::% xs -> listxDropLen rest xs
    ZX -> error "IsPermutation longer than shape"

-- | Value-level 'Permute': for each entry @k@ of the permutation, in order,
-- pick element @k@ of the input list.
listxPermute :: forall f is sh. Perm is -> ListX sh f -> ListX (Permute is sh) f
listxPermute perm shape = case perm of
  PNil -> ZX
  idx `PCons` (remaining :: Perm is') ->
    -- Permute the tail of the permutation first, then prepend the picked element.
    listxIndex (Proxy @is') (Proxy @sh) idx shape (listxPermute remaining shape)

-- | One step of 'listxPermute': prepend element number @i@ of @sh@ to @rest@,
-- the already-permuted remainder. The two proxies are never inspected; they
-- only pin down the type-level @is@ and @shT@ that describe @rest@.
-- Calls 'error' when @i@ is not a valid index into @sh@.
listxIndex :: forall f is shT i sh. Proxy is -> Proxy shT -> SNat i -> ListX sh f -> ListX (Permute is shT) f -> ListX (Index i sh : Permute is shT) f
listxIndex _ _ SZ (n ::% _) rest = n ::% rest
-- Successor index: drop the head of the list and recurse with the predecessor
-- index. The type equality Index (i' + 1) (n : sh') ~ Index i' sh' needed here
-- is asserted (unchecked) by 'lemIndexSucc'.
listxIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::% (sh :: ListX sh' f)) rest
  | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
  = listxIndex p pT i sh rest
listxIndex _ _ _ ZX _ = error "Index into empty shape"

-- | Value-level 'PermutePrefix': permute the first @Rank is@ elements and
-- append the untouched remainder.
listxPermutePrefix :: forall f is sh. Perm is -> ListX sh f -> ListX (PermutePrefix is sh) f
listxPermutePrefix perm sh =
  let permutedPrefix = listxPermute perm (listxTakeLen perm sh)
      untouchedSuffix = listxDropLen perm sh
  in listxAppend permutedPrefix untouchedSuffix

-- | 'listxPermutePrefix' specialised to indices; a zero-cost coercion of the
-- list version at element functor @'Const' i@.
ixxPermutePrefix :: forall i is sh. Perm is -> IxX sh i -> IxX (PermutePrefix is sh) i
ixxPermutePrefix = coerce (listxPermutePrefix @(Const i) @is @sh)

-- | 'listxTakeLen' on static shapes, via a zero-cost coercion.
ssxTakeLen :: forall is sh. Perm is -> StaticShX sh -> StaticShX (TakeLen is sh)
ssxTakeLen = coerce (listxTakeLen @(SMayNat () SNat) @is @sh)

-- | 'listxDropLen' on static shapes, via a zero-cost coercion.
ssxDropLen :: forall is sh. Perm is -> StaticShX sh -> StaticShX (DropLen is sh)
ssxDropLen = coerce (listxDropLen @(SMayNat () SNat) @is @sh)

-- | 'listxPermute' on static shapes, via a zero-cost coercion.
ssxPermute :: forall is sh. Perm is -> StaticShX sh -> StaticShX (Permute is sh)
ssxPermute = coerce (listxPermute @(SMayNat () SNat) @is @sh)

-- | 'listxIndex' on static shapes, via a zero-cost coercion.
ssxIndex :: forall is shT i sh. Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> StaticShX (Permute is shT) -> StaticShX (Index i sh : Permute is shT)
ssxIndex pIs pShT = coerce (listxIndex @(SMayNat () SNat) @is @shT @i @sh pIs pShT)

-- | 'listxPermutePrefix' on static shapes, via a zero-cost coercion.
ssxPermutePrefix :: forall is sh. Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh)
ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat () SNat) @is @sh)

-- | 'listxPermutePrefix' on shapes with 'Int' sizes, via a zero-cost coercion.
shxPermutePrefix :: forall is sh. Perm is -> IShX sh -> IShX (PermutePrefix is sh)
shxPermutePrefix = coerce (listxPermutePrefix @(SMayNat Int SNat) @is @sh)


-- * Operations on permutations

-- TODO: test 'permInverse' more thoroughly
-- | Compute the inverse of a permutation and pass it to the continuation,
-- together with 'IsPermutation' evidence for it and a function that, for any
-- static shape @sh@ of rank @Rank is@, proves that applying @perm@ and then the
-- inverse gives back @sh@.
--
-- The inverse is computed at the term level (by sorting the entries of @perm@)
-- and then checked; if the result fails the permutation check, this calls
-- 'error'. The inverse-law proof is checked with 'ssxGeq' each time the proof
-- function is applied and its result demanded, and calls 'error' on failure.
--
-- NOTE(review): @perm@ itself is not checked to be a permutation. With
-- duplicate entries (e.g. [0,0]) the generated inverse ([0,1]) still passes
-- the permutation check; confirm callers only pass valid permutations.
permInverse :: Perm is
            -> (forall is'.
                     IsPermutation is'
                  => Perm is'
                  -> (forall sh. Rank sh ~ Rank is => StaticShX sh -> Permute is' (Permute is sh) :~: sh)
                  -> r)
            -> r
permInverse = \perm k ->
  genPerm perm $ \(invperm :: Perm is') ->
    let sn = permLengthSNat invperm
    in case (provePerm1 (Proxy @is') sn invperm, provePerm2 (SNat @0) sn invperm) of
         (Just Refl, Just Refl) ->
           k invperm
             (\ssh -> case provePermInverse perm invperm ssh of
                        Just eq -> eq
                        Nothing -> error $ "permInverse: did not generate inverse? perm = " ++ show perm
                                           ++ " ; invperm = " ++ show invperm)
         _ -> error $ "permInverse: did not generate permutation? perm = " ++ show perm
                      ++ " ; invperm = " ++ show invperm
  where
    -- Build the candidate inverse: pair each entry of perm with its position,
    -- sort by entry, and keep the positions. Entry j of the result is thus the
    -- position at which perm contains the j-th smallest entry.
    genPerm :: Perm is -> (forall is'. Perm is' -> r) -> r
    genPerm perm =
      let permList = permToList' perm
      in toHList $ map snd (sort (zip permList [0..]))
      where
        -- Turn a term-level list of naturals into an existentially typed Perm.
        toHList :: [Natural] -> (forall is'. Perm is' -> r) -> r
        toHList [] k = k PNil
        toHList (n : ns) k = toHList ns $ \l -> TN.withSomeSNat n $ \sn -> k (PCons sn l)

    -- Unchecked lemma (unsafeCoerceRefl): if 0 <= n < m then n occurs in [0 .. m-1].
    lemElemCount :: (0 <= n, Compare n m ~ LT) => proxy n -> proxy m -> Elem n (Count 0 m) :~: True
    lemElemCount _ _ = unsafeCoerceRefl

    -- Unchecked lemma (unsafeCoerceRefl): if i /= n then Count i n unfolds one step.
    lemCount :: (OrdCond (Compare i n) True False True ~ True) => proxy i -> proxy n -> Count i n :~: i : Count (i + 1) n
    lemCount _ _ = unsafeCoerceRefl

    -- Unchecked lemma (unsafeCoerceRefl): membership in a tail implies
    -- membership in the list extended with any head.
    lemElem :: Elem x ys ~ True => proxy x -> proxy' (y : ys) -> Elem x (y : ys) :~: True
    lemElem _ _ = unsafeCoerceRefl

    -- Check that every entry of the given Perm lies in [0, Rank isTop);
    -- Nothing if some entry is out of range.
    provePerm1 :: Proxy isTop -> SNat (Rank isTop) -> Perm is'
               -> Maybe (AllElem' is' (Count 0 (Rank isTop)) :~: True)
    provePerm1 _ _ PNil = Just (Refl)
    provePerm1 p rtop@SNat (PCons sn@SNat perm)
      | Just Refl <- provePerm1 p rtop perm
      = case (cmpNat (SNat @0) sn, cmpNat sn rtop) of
          (LTI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl
          (EQI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl
          _ -> Nothing
      | otherwise
      = Nothing

    -- Check that each of i, i+1, .., n-1 occurs in the given Perm; Nothing if
    -- one is missing. Errors if called with i > n (not reachable from the
    -- call above, which starts at 0 and counts up to n).
    provePerm2 :: SNat i -> SNat n -> Perm is'
               -> Maybe (AllElem' (Count i n) is' :~: True)
    provePerm2 = \i@(SNat :: SNat i) n@SNat perm ->
      case cmpNat i n of
        EQI -> Just Refl
        LTI | Refl <- lemCount i n
            , Just Refl <- provePerm2 (SNat @(i + 1)) n perm
            -> checkElem i perm
            | otherwise -> Nothing
        GTI -> error "unreachable"
      where
        -- Linear search for i in the Perm, producing a membership proof.
        checkElem :: SNat i -> Perm is' -> Maybe (Elem i is' :~: True)
        checkElem _ PNil = Nothing
        checkElem i@SNat (PCons k@SNat perm :: Perm is') =
          case sameNat i k of
            Just Refl -> Just Refl
            Nothing | Just Refl <- checkElem i perm, Refl <- lemElem i (Proxy @is') -> Just Refl
                    | otherwise -> Nothing

    -- Check the inverse law on a concrete static shape by comparing
    -- the round-tripped shape with the original one.
    provePermInverse :: Perm is -> Perm is' -> StaticShX sh
                     -> Maybe (Permute is' (Permute is sh) :~: sh)
    provePermInverse perm perminv ssh =
      ssxGeq (ssxPermute perminv (ssxPermute perm ssh)) ssh

-- | Add one to every element of a type-level list of naturals.
type family MapSucc is where
  MapSucc '[] = '[]
  MapSucc (n : ns) = (n + 1) : MapSucc ns

-- | Extend a permutation with a new leading dimension that stays in place:
-- index 0 is prepended and every existing index is incremented by one.
permShift1 :: Perm l -> Perm (0 : MapSucc l)
permShift1 perm = SNat @0 `PCons` bump perm
  where
    -- Increment every index of the permutation.
    bump :: Perm l' -> Perm (MapSucc l')
    bump = \case
      PNil -> PNil
      (SNat :: SNat i) `PCons` rest -> SNat @(i + 1) `PCons` bump rest


-- * Lemmas

-- | Permuting a shape yields one element per permutation entry, so the rank
-- of the result equals the length of the permutation.
lemRankPermute :: Proxy sh -> Perm is -> Rank (Permute is sh) :~: Rank is
lemRankPermute proxy = \case
  PNil -> Refl
  _ `PCons` rest -> case lemRankPermute proxy rest of Refl -> Refl

-- | Dropping @Rank is@ elements from @sh@ leaves @Rank sh - Rank is@ of them.
-- The branch where the permutation is longer than the shape is excluded by
-- the constraint and calls 'error'.
lemRankDropLen :: forall is sh. (Rank is <= Rank sh)
               => StaticShX sh -> Perm is -> Rank (DropLen is sh) :~: Rank sh - Rank is
lemRankDropLen ssh perm = case ssh of
  ZKX -> case perm of
    PNil -> Refl
    _ `PCons` _ -> error "1 <= 0"
  _ :!% ssh' -> case perm of
    PNil -> Refl
    _ `PCons` perm' -> case lemRankDropLen ssh' perm' of Refl -> Refl

-- | Unchecked lemma (asserted with 'unsafeCoerceRefl'): indexing a cons list
-- at a successor index equals indexing its tail at the predecessor. The closed
-- 'Index' family presumably cannot reduce @Index (i + 1) (a : l)@ by itself,
-- since it cannot show that @i + 1@ is apart from @0@.
lemIndexSucc :: Proxy i -> Proxy a -> Proxy l
             -> Index (i + 1) (a : l) :~: Index i l
lemIndexSucc = \_ _ _ -> unsafeCoerceRefl