1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
|
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Nested.Shaped.Shape where
import Control.DeepSeq (NFData(..))
import Data.Array.Mixed.Types
import Data.Array.Shape qualified as O
import Data.Coerce (coerce)
import Data.Foldable qualified as Foldable
import Data.Functor.Const
import Data.Functor.Product qualified as Fun
import Data.Kind (Constraint, Type)
import Data.Monoid (Sum(..))
import Data.Proxy
import Data.Type.Equality
import GHC.Exts (withDict)
import GHC.Generics (Generic)
import GHC.IsList (IsList)
import GHC.IsList qualified as IsList
import GHC.TypeLits
import Data.Array.Mixed.Lemmas
import Data.Array.Mixed.Permutation
import Data.Array.Nested.Mixed.Shape
-- | A list indexed by a type-level list of naturals @sh@: the element at
-- position @k@ has type @f n@, where @n@ is the @k@-th entry of @sh@.
-- The role of @sh@ is nominal, so 'coerce' cannot change the type-level
-- shape; @f@ is representational.
type role ListS nominal representational
type ListS :: [Nat] -> (Nat -> Type) -> Type
data ListS sh f where
  -- The empty list.
  ZS :: ListS '[] f
  -- TODO: when the KnownNat constraint is removed, restore listsIndex to sanity
  -- Cons. Because of the 'KnownNat' constraint, every cons cell carries the
  -- 'KnownNat' dictionary for its own type-level index.
  (::$) :: forall n sh {f}. KnownNat n => f n -> ListS sh f -> ListS (n : sh) f
deriving instance (forall n. Eq (f n)) => Eq (ListS sh f)
deriving instance (forall n. Ord (f n)) => Ord (ListS sh f)
infixr 3 ::$
-- | Renders as @[x1,x2,...]@, showing each element at precedence 0. The
-- surrounding precedence is irrelevant because the output is bracketed.
instance (forall n. Show (f n)) => Show (ListS sh f) where
  showsPrec _prec xs = listsShow (\x -> showsPrec 0 x) xs
-- | Deep evaluation: each element is forced to normal form, from the front
-- of the list to the back.
instance (forall m. NFData (f m)) => NFData (ListS sh f) where
  rnf xs = case xs of
    ZS -> ()
    x ::$ rest -> rnf x `seq` rnf rest
-- | Result of 'listsUncons' on a non-empty list: the tail and the head,
-- packaged with evidence that the original shape @sh1@ is @n : sh@ and with
-- the 'KnownNat' dictionary for the head's index @n@.
data UnconsListSRes f sh1 =
  forall n sh. (KnownNat n, n : sh ~ sh1) => UnconsListSRes (ListS sh f) (f n)
-- | Split a list into tail and head. Returns 'Nothing' for the empty list.
listsUncons :: ListS sh1 f -> Maybe (UnconsListSRes f sh1)
listsUncons xs = case xs of
  ZS -> Nothing
  hd ::$ tl -> Just (UnconsListSRes tl hd)
-- | Compare two lists on their type-level shapes only, using 'testEquality'
-- on each pair of elements. If the elements are not singletons, lists whose
-- types agree can still hold different values. This is the analogue of
-- 'testEquality' for the shape (penultimate) type parameter.
listsEqType :: TestEquality f => ListS sh f -> ListS sh' f -> Maybe (sh :~: sh')
listsEqType ZS ZS = Just Refl
listsEqType (n ::$ sh) (m ::$ sh') =
  case testEquality n m of
    Nothing -> Nothing
    Just Refl ->
      case listsEqType sh sh' of
        Nothing -> Nothing
        Just Refl -> Just Refl
listsEqType _ _ = Nothing
-- | Compare two lists on both their types and their values. Each pair of
-- elements must be type-equal (via 'testEquality') and value-equal (via
-- '=='). This is stronger than 'testEquality' and corresponds to @geq@ from
-- @Data.GADT.Compare@ in the @some@ package, applied to the shape
-- (penultimate) type parameter.
listsEqual :: (TestEquality f, forall n. Eq (f n)) => ListS sh f -> ListS sh' f -> Maybe (sh :~: sh')
listsEqual ZS ZS = Just Refl
listsEqual (n ::$ sh) (m ::$ sh') =
  case testEquality n m of
    Just Refl
      | n == m
      , Just Refl <- listsEqual sh sh'
      -> Just Refl
    _ -> Nothing
listsEqual _ _ = Nothing
-- | Apply an index-preserving function to every element. The shape is
-- unchanged.
listsFmap :: (forall n. f n -> g n) -> ListS sh f -> ListS sh g
listsFmap fn xs = case xs of
  ZS -> ZS
  x ::$ rest -> fn x ::$ listsFmap fn rest
-- | Map each element into a monoid and combine the results from left to
-- right. The result is right-nested, and the empty list gives 'mempty'.
listsFold :: Monoid m => (forall n. f n -> m) -> ListS sh f -> m
listsFold fn xs = case xs of
  ZS -> mempty
  x ::$ rest -> fn x <> listsFold fn rest
-- | Show a list as @[x1,x2,...]@, using the given renderer for elements.
listsShow :: forall sh f. (forall n. f n -> ShowS) -> ListS sh f -> ShowS
listsShow f l = showString "[" . leading l . showString "]"
  where
    -- The first element (if any) is written without a separator.
    leading :: ListS sh' f -> ShowS
    leading ZS = id
    leading (x ::$ xs) = f x . trailing xs

    -- Every subsequent element is preceded by a comma.
    trailing :: ListS sh' f -> ShowS
    trailing ZS = id
    trailing (x ::$ xs) = showString "," . f x . trailing xs
-- | Number of elements in the list, computed at runtime.
listsLength :: ListS sh f -> Int
listsLength ZS = 0
listsLength (_ ::$ rest) = 1 + listsLength rest
-- | The length of the list as a singleton for @Rank sh@. One 'snatSucc' is
-- applied per cons cell.
listsRank :: ListS sh f -> SNat (Rank sh)
listsRank xs = case xs of
  ZS -> SNat
  _ ::$ rest -> snatSucc (listsRank rest)
-- | Unwrap a list of 'Const' values into an ordinary list, preserving order.
listsToList :: ListS sh (Const i) -> [i]
listsToList = listsFold (\(Const x) -> [x])
-- | First element of a non-empty list. Non-emptiness is enforced by the type.
listsHead :: ListS (n : sh) f -> f n
listsHead xs = case xs of
  hd ::$ _ -> hd
-- | All elements except the first, for a non-empty list.
listsTail :: ListS (n : sh) f -> ListS sh f
listsTail xs = case xs of
  _ ::$ tl -> tl
-- | All elements except the last, for a non-empty list.
listsInit :: ListS (n : sh) f -> ListS (Init (n : sh)) f
listsInit (x ::$ rest) = case rest of
  -- A singleton list: dropping its only (last) element leaves nothing.
  ZS -> ZS
  _ ::$ _ -> x ::$ listsInit rest
-- | Last element of a non-empty list.
listsLast :: ListS (n : sh) f -> f (Last (n : sh))
listsLast (x ::$ rest) = case rest of
  ZS -> x
  _ ::$ _ -> listsLast rest
-- | Concatenate two lists. The shapes are concatenated with '++'.
listsAppend :: ListS sh f -> ListS sh' f -> ListS (sh ++ sh') f
listsAppend front back = case front of
  ZS -> back
  x ::$ front' -> x ::$ listsAppend front' back
-- | Pair up corresponding elements of two lists of the same shape, using
-- 'Fun.Pair'.
listsZip :: ListS sh f -> ListS sh g -> ListS sh (Fun.Product f g)
listsZip = listsZipWith Fun.Pair
-- | Combine corresponding elements of two lists of the same shape with an
-- index-preserving function.
listsZipWith :: (forall a. f a -> g a -> h a) -> ListS sh f -> ListS sh g
             -> ListS sh h
listsZipWith fn xs ys = case xs of
  ZS -> case ys of
    ZS -> ZS
  x ::$ xs' -> case ys of
    y ::$ ys' -> fn x y ::$ listsZipWith fn xs' ys'
-- | Keep as many leading elements of the list as the permutation has
-- entries. Calls 'error' if the permutation is longer than the list.
listsTakeLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (TakeLen is sh) f
listsTakeLenPerm perm xs = case perm of
  PNil -> ZS
  _ `PCons` perm' -> case xs of
    x ::$ xs' -> x ::$ listsTakeLenPerm perm' xs'
    ZS -> error "Permutation longer than shape"
-- | Drop as many leading elements of the list as the permutation has
-- entries. Calls 'error' if the permutation is longer than the list.
listsDropLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (DropLen is sh) f
listsDropLenPerm perm xs = case perm of
  PNil -> xs
  _ `PCons` perm' -> case xs of
    _ ::$ xs' -> listsDropLenPerm perm' xs'
    ZS -> error "Permutation longer than shape"
-- | Reorder a list by the permutation @is@. The k-th element of the result
-- is the input element at the position named by the k-th entry of @is@,
-- looked up with 'listsIndex'. The result has one element per permutation
-- entry.
listsPermute :: forall f is sh. Perm is -> ListS sh f -> ListS (Permute is sh) f
listsPermute PNil _ = ZS
listsPermute (i `PCons` (is :: Perm is')) (sh :: ListS sh f) =
  -- The SNat returned by listsIndex is matched only to bring the KnownNat
  -- instance that '::$' requires into scope.
  case listsIndex (Proxy @is') (Proxy @sh) i sh of
    (item, SNat) -> item ::$ listsPermute is sh
-- | Look up the element at type-level position @i@. The element is returned
-- together with an 'SNat' for the dimension at that position; the 'SNat' is
-- built from the 'KnownNat' dictionary stored in the matching cons cell.
-- The two 'Proxy' arguments are passed through the recursion unchanged and
-- are never inspected. Calls 'error' if @i@ runs past the end of the list.
-- TODO: remove this SNat when the KnownNat constraint in ListS is removed
listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> (f (Index i sh), SNat (Index i sh))
listsIndex _ _ SZ (n ::$ _) = (n, SNat)
listsIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::$ (sh :: ListS sh' f))
  -- lemIndexSucc presumably equates Index (i' + 1) (n : sh') with
  -- Index i' sh', so the recursive result has the required type
  -- (see Data.Array.Mixed.Lemmas to confirm).
  | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
  = listsIndex p pT i sh
listsIndex _ _ _ ZS = error "Index into empty shape"
-- | Permute only the leading part of the list covered by the permutation
-- and keep the remaining elements in their original order.
listsPermutePrefix :: forall f is sh. Perm is -> ListS sh f -> ListS (PermutePrefix is sh) f
listsPermutePrefix perm sh =
  let prefix = listsTakeLenPerm perm sh
      suffix = listsDropLenPerm perm sh
  in listsAppend (listsPermute perm prefix) suffix
-- | An index into a shape-typed array.
--
-- The components have type @i@; 'IIxS' is the usual instantiation with
-- 'Int'. For convenience, components are plain values (such as regular
-- 'Int's) rather than bounded integers (traditionally called \"@Fin@\").
-- They are therefore not checked against the dimensions in @sh@.
type role IxS nominal representational
type IxS :: [Nat] -> Type -> Type
newtype IxS sh i = IxS (ListS sh (Const i))
  deriving (Eq, Ord, Generic)
-- | The empty index (rank 0). Matching on it refines @sh@ to @'[]@.
pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i
pattern ZIS = IxS ZS
-- | Cons pattern for 'IxS'.
--
-- Matching splits off the first component. It provides the 'KnownNat'
-- dictionary for that dimension and the equation @n : sh ~ sh1@. Used as an
-- expression, it prepends a component; this needs 'KnownNat n' because the
-- underlying '::$' constructor stores that dictionary.
pattern (:.$)
  :: forall {sh1} {i}.
     forall n sh. (KnownNat n, n : sh ~ sh1)
  => i -> IxS sh i -> IxS sh1 i
pattern i :.$ shl <- IxS (listsUncons -> Just (UnconsListSRes (IxS -> shl) (getConst -> i)))
  where i :.$ IxS shl = IxS (Const i ::$ shl)
infixr 3 :.$

-- ZIS and (:.$) together cover every IxS.
{-# COMPLETE ZIS, (:.$) #-}
-- | An index into a shape-typed array with 'Int' components.
type IIxS sh = IxS sh Int
-- | Shows the components as a bracketed, comma-separated list.
instance Show i => Show (IxS sh i) where
  showsPrec _ (IxS l) = listsShow (shows . getConst) l
-- | Maps over the index components. The shape is unchanged.
instance Functor (IxS sh) where
  fmap f (IxS l) = IxS (listsFmap (\(Const x) -> Const (f x)) l)
-- | Folds over the index components from first to last.
instance Foldable (IxS sh) where
  foldMap f (IxS l) = listsFold (\(Const x) -> f x) l
-- | Uses the default 'rnf' from "Control.DeepSeq", which is implemented in
-- terms of the 'Generic' instance that 'IxS' derives.
instance NFData i => NFData (IxS sh i)
-- | Number of components in the index, computed at runtime.
ixsLength :: IxS sh i -> Int
ixsLength = Foldable.length
-- | Number of components in the index, as a type-level singleton.
ixsRank :: IxS sh i -> SNat (Rank sh)
ixsRank ix = case ix of
  IxS l -> listsRank l
-- | The all-zeros index for the given shape: one @0@ per dimension.
ixsZero :: ShS sh -> IIxS sh
ixsZero (ShS l) = IxS (listsFmap (\_ -> Const 0) l)
-- | Convert a mixed-typed index over a fully static shape to a shape-typed
-- index. The 'ShS' argument drives the recursion.
ixCvtXS :: ShS sh -> IIxX (MapJust sh) -> IIxS sh
ixCvtXS sh idx = case sh of
  ZSS -> case idx of
    ZIX -> ZIS
  _ :$$ sh' -> case idx of
    n :.% idx' -> n :.$ ixCvtXS sh' idx'
-- | Convert a shape-typed index to a mixed-typed index whose dimensions are
-- all 'Just'.
ixCvtSX :: IIxS sh -> IIxX (MapJust sh)
ixCvtSX ix = case ix of
  ZIS -> ZIX
  n :.$ rest -> n :.% ixCvtSX rest
-- | First component of a non-empty index.
ixsHead :: IxS (n : sh) i -> i
ixsHead (i :.$ _) = i
-- | All components except the first.
ixsTail :: IxS (n : sh) i -> IxS sh i
ixsTail (_ :.$ rest) = rest
-- | All components except the last.
ixsInit :: IxS (n : sh) i -> IxS (Init (n : sh)) i
ixsInit ix = case ix of
  IxS l -> IxS (listsInit l)
-- | Last component of a non-empty index.
ixsLast :: IxS (n : sh) i -> i
ixsLast (IxS l) = case listsLast l of
  Const x -> x
-- | Concatenate two indices. The shapes are concatenated with '++'.
ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i
ixsAppend (IxS front) (IxS back) = IxS (listsAppend front back)
-- | Pair up corresponding components of two indices of the same shape.
ixsZip :: IxS n i -> IxS n j -> IxS n (i, j)
ixsZip = ixsZipWith (,)
-- | Combine corresponding components of two indices of the same shape.
ixsZipWith :: (i -> j -> k) -> IxS n i -> IxS n j -> IxS n k
ixsZipWith f (IxS xs) (IxS ys) =
  IxS (listsZipWith (\(Const a) (Const b) -> Const (f a b)) xs ys)
-- | Permute the leading components covered by the permutation and keep the
-- rest in place.
ixsPermutePrefix :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i
ixsPermutePrefix perm (IxS l) = IxS (listsPermutePrefix perm l)
-- | The shape of a shape-typed array given as a list of 'SNat' values.
--
-- Note that because the shape of a shape-typed array is known statically, you
-- can also retrieve the array shape from a 'KnownShS' dictionary.
--
-- The role of @sh@ is nominal, so 'coerce' cannot change the type-level
-- shape.
type role ShS nominal
type ShS :: [Nat] -> Type
newtype ShS sh = ShS (ListS sh SNat)
  deriving (Eq, Ord, Generic)
-- | The empty (rank-0) shape. Matching on it refines @sh@ to @'[]@.
pattern ZSS :: forall sh. () => sh ~ '[] => ShS sh
pattern ZSS = ShS ZS
-- | Cons pattern for 'ShS'.
--
-- Matching splits off the first dimension. It provides that dimension's
-- 'KnownNat' dictionary and the equation @n : sh ~ sh1@. Building with it
-- requires the same constraints.
pattern (:$$)
  :: forall {sh1}.
     forall n sh. (KnownNat n, n : sh ~ sh1)
  => SNat n -> ShS sh -> ShS sh1
pattern i :$$ shl <- ShS (listsUncons -> Just (UnconsListSRes (ShS -> shl) i))
  where i :$$ ShS shl = ShS (i ::$ shl)
infixr 3 :$$

-- ZSS and (:$$) together cover every ShS.
{-# COMPLETE ZSS, (:$$) #-}
-- | Shows the dimensions as a bracketed, comma-separated list of numbers.
instance Show (ShS sh) where
  showsPrec _ (ShS l) = listsShow (\n -> shows (fromSNat n)) l
-- | Forces the whole spine. Matching on the 'SNat' pattern also forces each
-- dimension.
instance NFData (ShS sh) where
  rnf (ShS l) = go l
    where
      go :: ListS sh' SNat -> ()
      go ZS = ()
      go (SNat ::$ rest) = go rest
-- | Compares the type-level shapes dimension by dimension.
instance TestEquality ShS where
  testEquality a b = case (a, b) of
    (ShS l1, ShS l2) -> listsEqType l1 l2
-- | Synonym for 'testEquality' on shapes. 'ShS' is a singleton, so two shapes
-- have equal types exactly when they have equal values.
shsEqual :: ShS sh -> ShS sh' -> Maybe (sh :~: sh')
shsEqual a b = testEquality a b
-- | Number of dimensions in the shape, computed at runtime.
shsLength :: ShS sh -> Int
shsLength sh = length (shsToList sh)
-- | Number of dimensions in the shape, as a type-level singleton.
shsRank :: ShS sh -> SNat (Rank sh)
shsRank sh = case sh of
  ShS l -> listsRank l
-- | Total number of elements of an array with this shape: the product of
-- all dimensions, which is 1 for the empty shape.
shsSize :: ShS sh -> Int
shsSize sh = product (shsToList sh)
-- | The dimensions of the shape as a plain list of 'Int's, in order.
shsToList :: ShS sh -> [Int]
shsToList (ShS l) = listsFold (\n -> [fromSNat' n]) l
-- | Convert a mixed shape whose type is @'IShX' ('MapJust' sh)@ (every
-- dimension statically known) to the corresponding 'ShS'.
--
-- The type equalities needed to recurse are asserted with
-- 'unsafeCoerceRefl' rather than proven. Presumably this is because GHC
-- cannot invert 'MapJust'. The assertions are sound only if the argument
-- really has type @IShX (MapJust sh)@.
shCvtXS' :: forall sh. IShX (MapJust sh) -> ShS sh
shCvtXS' ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS
shCvtXS' (SKnown n@SNat :$% (idx :: IShX mjshT)) =
  castWith (subst1 (lem Refl)) $
    n :$$ shCvtXS' @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh)))
                                        idx)
  where
    -- If @Just n : sh1@ equals @MapJust sh'@, then @sh'@ is @n : Tail sh'@.
    -- This is asserted via unsafeCoerceRefl, not proven.
    lem :: forall sh1 sh' n.
           Just n : sh1 :~: MapJust sh'
        -> n : Tail sh' :~: sh'
    lem Refl = unsafeCoerceRefl
-- MapJust presumably maps every dimension to 'Just', so an 'SUnknown' entry
-- cannot occur for well-typed input. The message names the function so a
-- violated invariant can be traced.
shCvtXS' (SUnknown _ :$% _) = error "shCvtXS': impossible: SUnknown dimension in an IShX (MapJust sh)"
-- | Convert a shape-typed shape to a mixed shape in which every dimension is
-- 'SKnown'.
shCvtSX :: ShS sh -> IShX (MapJust sh)
shCvtSX sh = case sh of
  ZSS -> ZSX
  n :$$ rest -> SKnown n :$% shCvtSX rest
-- | First dimension of a non-empty shape.
shsHead :: ShS (n : sh) -> SNat n
shsHead sh = case sh of
  ShS l -> listsHead l
-- | All dimensions except the first.
shsTail :: ShS (n : sh) -> ShS sh
shsTail (_ :$$ rest) = rest
-- | All dimensions except the last.
shsInit :: ShS (n : sh) -> ShS (Init (n : sh))
shsInit sh = case sh of
  ShS l -> ShS (listsInit l)
-- | Last dimension of a non-empty shape.
shsLast :: ShS (n : sh) -> SNat (Last (n : sh))
shsLast sh = case sh of
  ShS l -> listsLast l
-- | Concatenate two shapes.
shsAppend :: forall sh sh'. ShS sh -> ShS sh' -> ShS (sh ++ sh')
shsAppend (ShS front) (ShS back) = ShS (listsAppend front back)
-- | Keep as many leading dimensions as the permutation has entries.
shsTakeLen :: Perm is -> ShS sh -> ShS (TakeLen is sh)
shsTakeLen perm (ShS l) = ShS (listsTakeLenPerm perm l)
-- | Reorder the dimensions according to the permutation.
shsPermute :: Perm is -> ShS sh -> ShS (Permute is sh)
shsPermute perm (ShS l) = ShS (listsPermute perm l)
-- | The dimension at type-level position @i@ of the shape.
shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> SNat (Index i sh)
shsIndex pis pshT i (ShS l) = fst (listsIndex pis pshT i l)
-- | Permute the leading dimensions covered by the permutation and keep the
-- rest in place.
shsPermutePrefix :: forall is sh. Perm is -> ShS sh -> ShS (PermutePrefix is sh)
shsPermutePrefix perm (ShS l) = ShS (listsPermutePrefix perm l)
-- | Type-level product of the dimensions in a shape. The empty shape has
-- product 1. (@*@ here is multiplication because of NoStarIsType.)
type family Product sh where
  Product '[] = 1
  Product (n : ns) = n * Product ns
-- | Singleton for the product of all dimensions, built with 'snatMul'.
shsProduct :: ShS sh -> SNat (Product sh)
shsProduct sh = case sh of
  ZSS -> SNat
  n :$$ rest -> snatMul n (shsProduct rest)
-- | Evidence for the static part of a shape. This pops up only when you are
-- polymorphic in the element type of an array.
--
-- The instances rebuild the 'ShS' singleton from the type, taking each
-- dimension from its 'KnownNat' dictionary via 'natSing'.
type KnownShS :: [Nat] -> Constraint
class KnownShS sh where knownShS :: ShS sh
instance KnownShS '[] where knownShS = ZSS
instance (KnownNat n, KnownShS sh) => KnownShS (n : sh) where knownShS = natSing :$$ knownShS
-- | Run a computation that needs @KnownShS sh@, using the given 'ShS' as the
-- dictionary. This goes through 'withDict', which applies because 'KnownShS'
-- is a class with the single method 'knownShS'.
withKnownShS :: forall sh r. ShS sh -> (KnownShS sh => r) -> r
withKnownShS k = withDict @(KnownShS sh) k
-- | Recover a 'KnownShS' dictionary from a shape singleton by induction on
-- the shape. Each 'SNat' supplies the 'KnownNat' for its dimension.
shsKnownShS :: ShS sh -> Dict KnownShS sh
shsKnownShS sh = case sh of
  ZSS -> Dict
  SNat :$$ rest -> case shsKnownShS rest of
    Dict -> Dict
-- | Recover an 'O.Shape' dictionary (from orthotope's "Data.Array.Shape")
-- for the shape, by induction on the shape.
shsOrthotopeShape :: ShS sh -> Dict O.Shape sh
shsOrthotopeShape sh = case sh of
  ZSS -> Dict
  SNat :$$ rest -> case shsOrthotopeShape rest of
    Dict -> Dict
-- | Untyped: length is checked at runtime.
--
-- 'fromList' walks the list alongside the statically known shape @sh@ and
-- calls 'error' if the two lengths differ.
instance KnownShS sh => IsList (ListS sh (Const i)) where
  type Item (ListS sh (Const i)) = i
  fromList topl = build (knownShS @sh) topl
    where
      build :: ShS sh' -> [i] -> ListS sh' (Const i)
      build shape xs = case (shape, xs) of
        (ZSS, []) -> ZS
        (_ :$$ rest, x : xs') -> Const x ::$ build rest xs'
        _ -> error $ "IsList(ListS): Mismatched list length (type says "
                       ++ show (shsLength (knownShS @sh)) ++ ", list has length "
                       ++ show (length topl) ++ ")"
  toList = listsToList
-- | Very untyped: only length is checked (at runtime), index bounds are __not checked__.
instance KnownShS sh => IsList (IxS sh i) where
  type Item (IxS sh i) = i
  fromList xs = IxS (IsList.fromList xs)
  toList (IxS l) = listsToList l
-- | Untyped: length and values are checked at runtime.
--
-- 'fromList' walks the list alongside the statically known shape @sh@. It
-- calls 'error' if an entry differs from the corresponding type-level
-- dimension, or if the lengths differ.
instance KnownShS sh => IsList (ShS sh) where
  type Item (ShS sh) = Int
  fromList topl = ShS (build (knownShS @sh) topl)
    where
      build :: ShS sh' -> [Int] -> ListS sh' SNat
      build expected actual = case (expected, actual) of
        (ZSS, []) -> ZS
        (sn :$$ rest, i : is)
          | i /= fromSNat' sn ->
              error $ "IsList(ShS): Value does not match typing (type says "
                        ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")"
          | otherwise -> sn ::$ build rest is
        _ -> error $ "IsList(ShS): Mismatched list length (type says "
                       ++ show (shsLength (knownShS @sh)) ++ ", list has length "
                       ++ show (length topl) ++ ")"
  toList = shsToList
|