blob: 6ff3bdc9980121fe22121c111b0d8d1aa662a2a5 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
|
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Mixed.Permutation where
import Data.Coerce (coerce)
import Data.Functor.Const
import Data.List (sort)
import Data.Proxy
import Data.Type.Bool
import Data.Type.Equality
import Data.Type.Ord
import GHC.TypeError
import GHC.TypeLits
import GHC.TypeNats qualified as TN
import Data.Array.Mixed.Shape
import Data.Array.Mixed.Types
-- * Permutations
-- | A "backward" permutation of a dimension list. The operation on the
-- dimension list is most similar to 'Data.Vector.backpermute'; see 'Permute'
-- (type level) and 'listxPermute' (term level) for code that implements this.
--
-- A @Perm list@ is a singleton for the type-level list of naturals @list@:
-- every 'PCons' cell carries the 'SNat' of one entry. The type itself does
-- not guarantee that @list@ is a permutation; that property is expressed
-- separately by the 'IsPermutation' constraint.
--
-- For example, @Perm '[1, 0]@ swaps the first two dimensions:
-- @Permute '[1, 0] '[a, b]@ reduces to @'[b, a]@.
data Perm list where
  PNil :: Perm '[]
  PCons :: SNat a -> Perm l -> Perm (a : l)
infixr 5 `PCons`
deriving instance Show (Perm list)
deriving instance Eq (Perm list)
-- | The number of entries in a permutation, as a singleton for
-- @'Rank' list@. Walks the whole 'Perm', rebuilding the 'KnownNat'
-- evidence one cell at a time.
permLengthSNat :: Perm list -> SNat (Rank list)
permLengthSNat = \case
  PNil -> SNat
  _ `PCons` rest -> case permLengthSNat rest of
    SNat -> SNat
-- | Lift a term-level list of indices to a 'Perm' whose type-level index
-- list is existential. The result is passed to the continuation @k@.
--
-- No check is made that the list is actually a permutation (see
-- 'IsPermutation' for that property). Calls 'error' on a negative entry,
-- because a negative number has no 'SNat' representation.
permFromList :: [Int] -> (forall list. Perm list -> r) -> r
permFromList [] k = k PNil
permFromList (x : xs) k = withSomeSNat (fromIntegral x) $ \case
  Just sn -> permFromList xs $ \list -> k (sn `PCons` list)
  -- withSomeSNat yields Nothing exactly when the Integer is negative.
  Nothing -> error $ "Data.Array.Mixed.Permutation.permFromList: negative number in list: " ++ show x
-- | The entries of a permutation as term-level naturals, in order.
permToList :: Perm list -> [Natural]
permToList = \case
  PNil -> []
  n `PCons` rest -> TN.fromSNat n : permToList rest
-- | Like 'permToList', but with each entry converted to 'Int' via
-- 'fromIntegral'.
permToList' :: Perm list -> [Int]
permToList' perm = [fromIntegral n | n <- permToList perm]
-- | Type-level index lists whose 'Perm' singleton can be assembled from
-- 'KnownNat' evidence for each element.
class KnownPerm l where
  makePerm :: Perm l

instance KnownPerm '[] where
  makePerm = PNil

instance (KnownNat n, KnownPerm l) => KnownPerm (n : l) where
  makePerm = PCons natSing makePerm
-- ** Applying permutations
-- | Type-level list membership: reduces to @'True@ when @x@ occurs in @l@
-- and to @'False@ when it does not. The equations are tried in order, so the
-- last one only fires once GHC can see that the list head is apart from
-- @x@. For non-literal arguments the family may therefore stay stuck.
type family Elem x l where
  Elem x '[] = 'False
  Elem x (x : _) = 'True
  Elem x (_ : ys) = Elem x ys
-- | Reduces to @'True@ when every element of @as@ occurs in @bs@ (each
-- checked with 'Elem'). This is the Boolean underlying 'AllElem'.
type family AllElem' as bs where
  AllElem' '[] bs = 'True
  AllElem' (a : as) bs = Elem a bs && AllElem' as bs
-- | Constraint that every element of @as@ occurs in @bs@. It is built with
-- 'Assert' from "GHC.TypeError": when 'AllElem'' reduces to @'False@, the
-- custom 'TypeError' below is reported instead of a bare mismatch.
type AllElem as bs = Assert (AllElem' as bs)
  (TypeError (Text "The elements of " :<>: ShowType as :<>: Text " are not all in " :<>: ShowType bs))
-- | @Count i n@ is the type-level list @[i, i + 1, ..., n - 1]@, which is
-- empty when @i = n@. Reduction only reaches the base case when @i <= n@.
-- The second equation needs GHC to see that @i@ and @n@ are apart.
type family Count i n where
  Count n n = '[]
  Count i n = i : Count (i + 1) n
-- | Constraint that @as@ is a permutation of the indices @Count 0 (Rank as)@.
-- Two things must hold: every element of @as@ is one of those indices, and
-- every one of those indices occurs in @as@. ('Rank' presumably being the
-- list length, this is @[0 .. length as - 1]@.)
type IsPermutation as = (AllElem as (Count 0 (Rank as)), AllElem (Count 0 (Rank as)) as)
-- | @Index i sh@ is the element at zero-based position @i@ of @sh@. No
-- equation matches an empty list, so the family is stuck (does not reduce)
-- when @i@ is out of range.
type family Index i sh where
  Index 0 (n : sh) = n
  Index i (_ : sh) = Index (i - 1) sh
-- | Backward permutation of a type-level list: @Permute is sh@ is the list
-- of @Index i sh@ for each @i@ in @is@, in order. For example,
-- @Permute '[1, 0] '[a, b]@ reduces to @'[b, a]@. The result always has as
-- many elements as @is@.
type family Permute is sh where
  Permute '[] sh = '[]
  Permute (i : is) sh = Index i sh : Permute is sh
-- | Apply the permutation @is@ to the first @length is@ elements of @sh@
-- only, and append the remaining suffix of @sh@ unchanged.
type PermutePrefix is sh = Permute is (TakeLen is sh) ++ DropLen is sh
-- | The first @length ref@ elements of @l@. Only the length of @ref@
-- matters, not its contents. Stuck when @l@ is shorter than @ref@.
type family TakeLen ref l where
  TakeLen '[] l = '[]
  TakeLen (_ : ref) (x : xs) = x : TakeLen ref xs
-- | @l@ with its first @length ref@ elements removed. Only the length of
-- @ref@ matters, not its contents. Stuck when @l@ is shorter than @ref@.
type family DropLen ref l where
  DropLen '[] l = l
  DropLen (_ : ref) (_ : xs) = DropLen ref xs
-- | Keep as many leading elements of a shape list as the permutation @is@
-- has entries. The permutation's values are ignored; only its length
-- matters. Calls 'error' if the shape list is shorter than the permutation.
listxTakeLen :: forall f is sh. Perm is -> ListX sh f -> ListX (TakeLen is sh) f
listxTakeLen perm list = case perm of
  PNil -> ZX
  _ `PCons` perm' -> case list of
    x ::% list' -> x ::% listxTakeLen perm' list'
    ZX -> error "IsPermutation longer than shape"
-- | Drop as many leading elements of a shape list as the permutation @is@
-- has entries. The permutation's values are ignored; only its length
-- matters. Calls 'error' if the shape list is shorter than the permutation.
listxDropLen :: forall f is sh. Perm is -> ListX sh f -> ListX (DropLen is sh) f
listxDropLen perm list = case perm of
  PNil -> list
  _ `PCons` perm' -> case list of
    _ ::% list' -> listxDropLen perm' list'
    ZX -> error "IsPermutation longer than shape"
-- | Apply a backward permutation to a shape list. Entry @k@ of the result is
-- the element of the input at the position given by entry @k@ of the
-- permutation (see 'Permute').
listxPermute :: forall f is sh. Perm is -> ListX sh f -> ListX (Permute is sh) f
listxPermute perm sh = case perm of
  PNil -> ZX
  idx `PCons` (rest :: Perm is') ->
    listxIndex (Proxy @is') (Proxy @sh) idx sh ::% listxPermute rest sh
-- | The element at zero-based position @i@ of a shape list, typed as
-- @f (Index i sh)@. The two 'Proxy' arguments are never inspected. Calls
-- 'error' when the index runs past the end of the list.
listxIndex :: forall f is shT i sh. Proxy is -> Proxy shT -> SNat i -> ListX sh f -> f (Index i sh)
listxIndex pIs pShT idx list = case idx of
  SZ -> case list of
    n ::% _ -> n
    ZX -> error "Index into empty shape"
  SS (idx' :: SNat i') -> case list of
    (_ :: f n) ::% (rest :: ListX sh' f)
      -- GHC cannot reduce Index (i' + 1) (n : sh') by itself; the lemma bridges it.
      | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') ->
          listxIndex pIs pShT idx' rest
    ZX -> error "Index into empty shape"
-- | Permute the leading @length is@ entries of a shape list with
-- 'listxPermute', then re-attach the untouched remainder.
listxPermutePrefix :: forall f is sh. Perm is -> ListX sh f -> ListX (PermutePrefix is sh) f
listxPermutePrefix perm sh = listxAppend front back
  where
    front = listxPermute perm (listxTakeLen perm sh)
    back = listxDropLen perm sh
-- | 'listxPermutePrefix' for index lists. It is implemented by coercing the
-- @'Const' i@-element version, which only typechecks because 'IxX' is
-- representationally a 'ListX' of 'Const' elements.
ixxPermutePrefix :: forall i is sh. Perm is -> IxX sh i -> IxX (PermutePrefix is sh) i
ixxPermutePrefix = coerce (listxPermutePrefix @(Const i))
-- | 'listxTakeLen' for static shapes, obtained by 'coerce' from the
-- @'SMayNat' () 'SNat'@-element version. 'StaticShX' must be
-- representationally that 'ListX' for the coercion to typecheck.
ssxTakeLen :: forall is sh. Perm is -> StaticShX sh -> StaticShX (TakeLen is sh)
ssxTakeLen = coerce (listxTakeLen @(SMayNat () SNat))
-- | 'listxDropLen' for static shapes, obtained by 'coerce' from the
-- @'SMayNat' () 'SNat'@-element version.
ssxDropLen :: Perm is -> StaticShX sh -> StaticShX (DropLen is sh)
ssxDropLen = coerce (listxDropLen @(SMayNat () SNat))
-- | 'listxPermute' for static shapes, obtained by 'coerce' from the
-- @'SMayNat' () 'SNat'@-element version.
ssxPermute :: Perm is -> StaticShX sh -> StaticShX (Permute is sh)
ssxPermute = coerce (listxPermute @(SMayNat () SNat))
-- | 'listxIndex' for static shapes. The proxies are passed straight through,
-- and the remaining function is converted with 'coerce' from the
-- @'SMayNat' () 'SNat'@-element version.
ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> SMayNat () SNat (Index i sh)
ssxIndex p1 p2 = coerce (listxIndex @(SMayNat () SNat) p1 p2)
-- | 'listxPermutePrefix' for static shapes, obtained by 'coerce' from the
-- @'SMayNat' () 'SNat'@-element version.
ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh)
ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat () SNat))
-- | 'listxPermutePrefix' for shapes with 'Int' sizes for the dynamic
-- dimensions. It is obtained by 'coerce' from the
-- @'SMayNat' 'Int' 'SNat'@-element version, so 'IShX' must be
-- representationally that 'ListX'.
shxPermutePrefix :: Perm is -> IShX sh -> IShX (PermutePrefix is sh)
shxPermutePrefix = coerce (listxPermutePrefix @(SMayNat Int SNat))
-- * Operations on permutations
-- TODO: test this thing more properly
-- | Invert a permutation, in continuation-passing style.
--
-- The continuation @k@ receives two things:
--
-- * the inverse permutation @is'@, with an 'IsPermutation' constraint for it.
--   The constraint is established at runtime by @provePerm1@ and
--   @provePerm2@ below.
-- * a function that, for a static shape @sh@ of the same rank as @is@,
--   proves that applying @is@ and then @is'@ gives back @sh@. This is
--   checked on every call by permuting the shape twice and comparing with
--   'ssxGeq'.
--
-- Calls 'error' if either check fails. The constructed inverse always
-- contains each position @0 .. n-1@ exactly once (see @genPerm@), so the
-- first check is not expected to fail. The second can fail if @perm@ is not
-- itself a permutation.
permInverse :: Perm is
            -> (forall is'.
                     IsPermutation is'
                  => Perm is'
                  -> (forall sh. Rank sh ~ Rank is => StaticShX sh -> Permute is' (Permute is sh) :~: sh)
                  -> r)
            -> r
permInverse = \perm k ->
  genPerm perm $ \(invperm :: Perm is') ->
    let sn = permLengthSNat invperm
    in case (provePerm1 (Proxy @is') sn invperm, provePerm2 (SNat @0) sn invperm) of
         (Just Refl, Just Refl) ->
           k invperm
             -- The inverse property is checked lazily, once per shape it is used at.
             (\ssh -> case provePermInverse perm invperm ssh of
                Just eq -> eq
                Nothing -> error $ "permInverse: did not generate inverse? perm = " ++ show perm
                             ++ " ; invperm = " ++ show invperm)
         _ -> error $ "permInverse: did not generate permutation? perm = " ++ show perm
                ++ " ; invperm = " ++ show invperm
  where
    -- Build the candidate inverse. Pair every entry of perm with its
    -- position, sort by entry, and read off the positions. For a valid
    -- permutation p this gives q with q !! (p !! j) == j.
    genPerm :: Perm is -> (forall is'. Perm is' -> r) -> r
    genPerm perm =
      let permList = permToList' perm
      in toHList $ map snd (sort (zip permList [0..]))
      where
        -- Turn a list of naturals into a Perm with an existential index
        -- list, passing the result to the continuation.
        toHList :: [Natural] -> (forall is'. Perm is' -> r) -> r
        toHList [] k = k PNil
        toHList (n : ns) k = toHList ns $ \l -> TN.withSomeSNat n $ \sn -> k (PCons sn l)

    -- The next three are unchecked axioms (via unsafeCoerceRefl) for facts
    -- that GHC cannot derive for abstract naturals.
    -- If 0 <= n < m, then n is an element of Count 0 m.
    lemElemCount :: (0 <= n, Compare n m ~ LT) => proxy n -> proxy m -> Elem n (Count 0 m) :~: True
    lemElemCount _ _ = unsafeCoerceRefl
    -- The OrdCond is True exactly when i /= n. Count then takes its second
    -- equation.
    lemCount :: (OrdCond (Compare i n) True False True ~ True) => proxy i -> proxy n -> Count i n :~: i : Count (i + 1) n
    lemCount _ _ = unsafeCoerceRefl
    -- Membership in ys implies membership in (y : ys).
    lemElem :: Elem x ys ~ True => proxy x -> proxy' (y : ys) -> Elem x (y : ys) :~: True
    lemElem _ _ = unsafeCoerceRefl

    -- Check that every entry of the Perm is below Rank isTop. Entries are
    -- naturals, so they are also >= 0. Returns Nothing if some entry is out
    -- of range.
    provePerm1 :: Proxy isTop -> SNat (Rank isTop) -> Perm is'
               -> Maybe (AllElem' is' (Count 0 (Rank isTop)) :~: True)
    provePerm1 _ _ PNil = Just (Refl)
    provePerm1 p rtop@SNat (PCons sn@SNat perm)
      | Just Refl <- provePerm1 p rtop perm
      = case (cmpNat (SNat @0) sn, cmpNat sn rtop) of
          (LTI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl
          (EQI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl
          _ -> Nothing
      | otherwise
      = Nothing

    -- Check that every index in [i, n) occurs in the Perm. GTI cannot
    -- happen: the only outside call starts at i = 0, and the recursion stops
    -- at EQI.
    provePerm2 :: SNat i -> SNat n -> Perm is'
               -> Maybe (AllElem' (Count i n) is' :~: True)
    provePerm2 = \i@(SNat :: SNat i) n@SNat perm ->
      case cmpNat i n of
        EQI -> Just Refl
        LTI | Refl <- lemCount i n
            , Just Refl <- provePerm2 (SNat @(i + 1)) n perm
            -> checkElem i perm
            | otherwise -> Nothing
        GTI -> error "unreachable"
      where
        -- Linear search for i among the entries of the Perm.
        checkElem :: SNat i -> Perm is' -> Maybe (Elem i is' :~: True)
        checkElem _ PNil = Nothing
        checkElem i@SNat (PCons k@SNat perm :: Perm is') =
          case sameNat i k of
            Just Refl -> Just Refl
            Nothing | Just Refl <- checkElem i perm, Refl <- lemElem i (Proxy @is') -> Just Refl
                    | otherwise -> Nothing

    -- Permute ssh by perm and then by perminv, and compare with ssh.
    -- NOTE(review): ssxGeq is assumed to return a type-equality proof when
    -- the two static shapes match; confirm in Data.Array.Mixed.Shape.
    provePermInverse :: Perm is -> Perm is' -> StaticShX sh
                     -> Maybe (Permute is' (Permute is sh) :~: sh)
    provePermInverse perm perminv ssh =
      ssxGeq (ssxPermute perminv (ssxPermute perm ssh)) ssh
-- | Add one to every element of a type-level list of naturals.
type family MapSucc is where
  MapSucc '[] = '[]
  MapSucc (i : is) = i + 1 : MapSucc is
-- | Extend a permutation with one extra leading dimension that stays in
-- place. The result starts with index 0, followed by every original index
-- incremented by one.
permShift1 :: Perm l -> Perm (0 : MapSucc l)
permShift1 perm = SNat @0 `PCons` bumpAll perm
  where
    -- Increment every entry. The KnownNat (i + 1) evidence is presumably
    -- supplied by the KnownNat solver plugin enabled for this module.
    bumpAll :: Perm l' -> Perm (MapSucc l')
    bumpAll PNil = PNil
    bumpAll ((SNat :: SNat i) `PCons` rest) = SNat @(i + 1) `PCons` bumpAll rest
-- * Lemmas
-- | Permuting any list @sh@ by @is@ yields a list with the same rank as
-- @is@. Proven by induction on the permutation; @sh@ serves only as a type
-- witness.
lemRankPermute :: Proxy sh -> Perm is -> Rank (Permute is sh) :~: Rank is
lemRankPermute p = \case
  PNil -> Refl
  _ `PCons` rest -> case lemRankPermute p rest of
    Refl -> Refl
-- | Dropping @'Rank' is@ leading elements from @sh@ leaves
-- @'Rank' sh - 'Rank' is@ elements, provided @is@ is no longer than @sh@.
-- Proven by simultaneous induction on the shape and the permutation.
lemRankDropLen :: forall is sh. (Rank is <= Rank sh)
               => StaticShX sh -> Perm is -> Rank (DropLen is sh) :~: Rank sh - Rank is
lemRankDropLen ssh perm = case ssh of
  ZKX -> case perm of
    PNil -> Refl
    -- Ruled out by the constraint: a non-empty is cannot have rank <= 0.
    _ `PCons` _ -> error "1 <= 0"
  _ :!% ssh' -> case perm of
    PNil -> Refl
    _ `PCons` perm' | Refl <- lemRankDropLen ssh' perm' -> Refl
-- | Unchecked axiom: @Index (i + 1) (a : l)@ equals @Index i l@.
--
-- This follows from the second equation of 'Index', since @i + 1@ is never 0.
-- For an abstract @i@, however, GHC cannot show that @i + 1@ is apart from
-- 0, so the closed family does not reduce and the equality is asserted via
-- 'unsafeCoerceRefl' instead.
lemIndexSucc :: Proxy i -> Proxy a -> Proxy l
             -> Index (i + 1) (a : l) :~: Index i l
lemIndexSucc = \_ _ _ -> unsafeCoerceRefl
|