aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Internal/Shape.hs
blob: 59f2c9ad4ee4f12c55f92dbffe8d545010a8fe8b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Nested.Internal.Shape where

import Data.Array.Shape qualified as O
import Data.Array.Mixed.Types
import Data.Coerce (coerce)
import Data.Foldable qualified as Foldable
import Data.Functor.Const
import Data.Kind (Type, Constraint)
import Data.Monoid (Sum(..))
import Data.Proxy
import Data.Type.Equality
import GHC.Exts (withDict)
import GHC.IsList (IsList)
import GHC.IsList qualified as IsList
import GHC.TypeLits
import GHC.TypeNats qualified as TN

import Data.Array.Mixed.Lemmas
import Data.Array.Mixed.Permutation
import Data.Array.Mixed.Shape


-- | A list whose length @n@ is tracked at the type level. It is the shared
-- representation of rank-typed indices ('IxR') and shapes ('ShR').
--
-- The role of @n@ is nominal, so 'coerce' can never change the length; the
-- element type @i@ is representational.
type role ListR nominal representational
type ListR :: Nat -> Type -> Type
data ListR n i where
  -- | The empty list, of length 0.
  ZR :: ListR 0 i
  -- | Prepend an element; the length grows by one. The element type @i@ is
  -- an inferred variable (braces), so a visible type application on
  -- @(:::)@ instantiates @n@.
  (:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i
deriving instance Eq i => Eq (ListR n i)
deriving instance Ord i => Ord (ListR n i)
deriving instance Functor (ListR n)
deriving instance Foldable (ListR n)
infixr 3 :::

-- | Renders as a bracketed, comma-separated list such as @[1,2,3]@. The
-- precedence argument is ignored because the output is always bracketed.
instance Show i => Show (ListR n i) where
  showsPrec _ xs = listrShow (showsPrec 0) xs

-- | Result of 'listrUncons': the tail and the head of a non-empty list,
-- packaged with evidence that the original length is the tail's length + 1.
data UnconsListRRes i n1 =
  forall n. (n + 1 ~ n1) => UnconsListRRes (ListR n i) i

-- | Split a list into its tail and head, or return 'Nothing' for 'ZR'.
listrUncons :: ListR n1 i -> Maybe (UnconsListRRes i n1)
listrUncons list = case list of
  ZR -> Nothing
  x ::: rest -> Just (UnconsListRRes rest x)

-- | Render a list as @[x1,x2,...]@, using the given function for each
-- element. The empty list renders as @[]@.
listrShow :: forall n i. (i -> ShowS) -> ListR n i -> ShowS
listrShow showElem list = showChar '[' . elems list . showChar ']'
  where
    -- The first element is printed without a leading separator.
    elems :: ListR n' i -> ShowS
    elems ZR = id
    elems (y ::: ys) = showElem y . rest ys

    -- Every later element is preceded by a comma.
    rest :: ListR n' i -> ShowS
    rest ZR = id
    rest (y ::: ys) = showChar ',' . showElem y . rest ys

-- | Concatenate two lists; the result length is the sum of the input lengths.
listrAppend :: ListR n i -> ListR m i -> ListR (n + m) i
listrAppend front back = case front of
  ZR -> back
  y ::: ys -> y ::: listrAppend ys back

-- | Convert an ordinary list into a 'ListR' whose length is not statically
-- known. The result is handed to a continuation that must work for any
-- length. Element order is preserved.
listrFromList :: [i] -> (forall n. ListR n i -> r) -> r
listrFromList elems k = case elems of
  [] -> k ZR
  y : ys -> listrFromList ys (\rest -> k (y ::: rest))

-- | The first element of a non-empty list.
listrHead :: ListR (n + 1) i -> i
listrHead list = case list of
  x ::: _ -> x
  ZR -> error "unreachable"

-- | All elements of a non-empty list except the first.
listrTail :: ListR (n + 1) i -> ListR n i
listrTail list = case list of
  _ ::: rest -> rest
  ZR -> error "unreachable"

-- | All elements of a non-empty list except the last.
listrInit :: ListR (n + 1) i -> ListR n i
listrInit list = case list of
  x ::: rest -> case rest of
    -- @x@ was the last element: drop it.
    ZR -> ZR
    _ ::: _ -> x ::: listrInit rest
  ZR -> error "unreachable"

-- | The last element of a non-empty list.
listrLast :: ListR (n + 1) i -> i
listrLast list = case list of
  x ::: rest -> case rest of
    ZR -> x
    _ ::: _ -> listrLast rest
  ZR -> error "unreachable"

-- | The element at 0-based position @k@. The constraint @k + 1 <= n@
-- guarantees statically that the position is within the list.
listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i
listrIndex SZ (x ::: _) = x
-- Recursing on the tail needs the bound for the smaller index and list.
-- The 'Refl' from 'lemLeqSuccSucc' (Data.Array.Mixed.Lemmas) supplies the
-- equality used for that — presumably turning @k' + 2 <= n' + 1@ into
-- @k' + 1 <= n'@; confirm against the lemma's definition.
listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs
listrIndex _ ZR = error "k + 1 <= 0"

-- | The length of the list as a type-level singleton. Computing it walks
-- the whole list.
listrRank :: ListR n i -> SNat n
listrRank list = case list of
  ZR -> SNat
  _ ::: rest -> snatSucc (listrRank rest)

-- | Permute a prefix of the list.
--
-- @listrPermutePrefix perm l@ reorders the first @length perm@ elements of
-- @l@ and leaves the remaining elements in place. Element @j@ of the
-- permuted prefix is element @perm !! j@ of the original prefix.
--
-- Errors at runtime in two cases:
--
-- * the permutation is longer than the list;
-- * some index in @perm@ is negative or not smaller than @length perm@.
--
-- Duplicate indices are not rejected.
listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i
listrPermutePrefix = \perm sh ->
  listrFromList perm $ \sperm ->
    case (listrRank sperm, listrRank sh) of
      (permlen@SNat, shlen@SNat) -> case cmpNat permlen shlen of
        -- LTI and EQI are kept separate: each provides the @permlen <= shlen@
        -- evidence needed by 'listrSplitAt'. A wildcard branch would not.
        LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
        EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
        GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")"
                       ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")"
  where
    -- Split the list into its first @m@ elements and the rest.
    listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i)
    listrSplitAt SZ sh = (ZR, sh)
    listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh)
    listrSplitAt SS{} ZR = error "m' + 1 <= 0"

    -- Build the list whose j-th element is element @perm !! j@ of @l@,
    -- checking that each index lies in @[0, m)@.
    applyPermRFull :: SNat m -> ListR k Int -> ListR m i -> ListR k i
    applyPermRFull _ ZR _ = ZR
    applyPermRFull sm@SNat (i ::: perm) l
      -- Reject negative indices before converting to 'Natural' for
      -- 'TN.withSomeSNat'. Otherwise the conversion fails with an
      -- uninformative arithmetic underflow instead of the message below.
      | i < 0 = error "listrPermutePrefix: Index in permutation out of range"
      | otherwise =
          TN.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) ->
            case cmpNat (SNat @(idx + 1)) sm of
              LTI -> listrIndex si l ::: applyPermRFull sm perm l
              EQI -> listrIndex si l ::: applyPermRFull sm perm l
              GTI -> error "listrPermutePrefix: Index in permutation out of range"


-- | An index into a rank-typed array.
--
-- @n@ is the rank (the number of components) and @i@ is the component type;
-- 'IIxR' fixes @i@ to 'Int'. The representation is a 'ListR'. Build and
-- match indices with the 'ZIR' and @(:.:)@ pattern synonyms.
type role IxR nominal representational
type IxR :: Nat -> Type -> Type
newtype IxR n i = IxR (ListR n i)
  deriving (Eq, Ord)
  -- Functor/Foldable are inherited directly from 'ListR'.
  deriving newtype (Functor, Foldable)

-- | The empty (rank-0) index. Matching on it provides @n ~ 0@.
pattern ZIR :: forall n i. () => n ~ 0 => IxR n i
pattern ZIR = IxR ZR

-- | Prepend a component to an index. As a pattern, it matches a non-empty
-- index via 'listrUncons', binding the first component and the remaining
-- index, and it provides @n + 1 ~ n1@.
pattern (:.:)
  :: forall {n1} {i}.
     forall n. (n + 1 ~ n1)
  => i -> IxR n i -> IxR n1 i
pattern i :.: sh <- IxR (listrUncons -> Just (UnconsListRRes (IxR -> sh) i))
  where i :.: IxR sh = IxR (i ::: sh)
infixr 3 :.:

-- Tell the pattern-match coverage checker that 'ZIR' and @(:.:)@ together
-- cover every 'IxR'.
{-# COMPLETE ZIR, (:.:) #-}

-- | An index with 'Int' components.
type IIxR n = IxR n Int

-- | Shown exactly like the underlying 'ListR', e.g. @[0,2,1]@.
instance Show i => Show (IxR n i) where
  showsPrec d (IxR l) = showsPrec d l

-- | The index of the given rank whose components are all 0.
ixrZero :: SNat n -> IIxR n
ixrZero sn = case sn of
  SZ -> ZIR
  SS sn' -> 0 :.: ixrZero sn'

-- | Forget the per-dimension static information of a mixed index and keep
-- only its rank.
ixCvtXR :: IIxX sh -> IIxR (Rank sh)
ixCvtXR ix = case ix of
  ZIX -> ZIR
  i :.% rest -> i :.: ixCvtXR rest

-- | Convert a rank-typed index to a mixed index whose type marks every
-- dimension as unknown (@Nothing@).
ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing)
ixCvtRX ZIR = ZIX
-- 'castWith' rewrites the type index of the rebuilt index. It uses the
-- equality supplied by 'lemReplicateSucc' at @m@, presumably
-- @Replicate (m + 1) Nothing@ vs. @Nothing : Replicate m Nothing@
-- (NOTE(review): confirm against Data.Array.Mixed.Lemmas).
ixCvtRX (n :.: (idx :: IxR m Int)) =
  castWith (subst2 @IxX @Int (lemReplicateSucc @(Nothing @Nat) @m))
    (n :.% ixCvtRX idx)

-- | The first component of a non-empty index.
ixrHead :: IxR (n + 1) i -> i
ixrHead ix = case ix of IxR l -> listrHead l

-- | All components of a non-empty index except the first.
ixrTail :: IxR (n + 1) i -> IxR n i
ixrTail ix = case ix of IxR l -> IxR (listrTail l)

-- | All components of a non-empty index except the last.
ixrInit :: IxR (n + 1) i -> IxR n i
ixrInit ix = case ix of IxR l -> IxR (listrInit l)

-- | The last component of a non-empty index.
ixrLast :: IxR (n + 1) i -> i
ixrLast ix = case ix of IxR l -> listrLast l

-- | Concatenate two indices.
ixrAppend :: forall n m i. IxR n i -> IxR m i -> IxR (n + m) i
ixrAppend (IxR xs) (IxR ys) = IxR (listrAppend xs ys)

-- | The rank of the index as a type-level singleton.
ixrRank :: IxR n i -> SNat n
ixrRank ix = case ix of IxR l -> listrRank l

-- | Permute a prefix of the index; see 'listrPermutePrefix'.
ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i
ixrPermutePrefix perm (IxR l) = IxR (listrPermutePrefix perm l)


-- | The shape of a rank-typed array: @n@ dimension sizes of type @i@
-- ('IShR' fixes @i@ to 'Int'). The representation is a 'ListR'. Build and
-- match shapes with the 'ZSR' and @(:$:)@ pattern synonyms.
type role ShR nominal representational
type ShR :: Nat -> Type -> Type
newtype ShR n i = ShR (ListR n i)
  deriving (Eq, Ord)
  -- Functor/Foldable are inherited directly from 'ListR'.
  deriving newtype (Functor, Foldable)

-- | The empty (rank-0) shape. Matching on it provides @n ~ 0@.
pattern ZSR :: forall n i. () => n ~ 0 => ShR n i
pattern ZSR = ShR ZR

-- | Prepend a dimension to a shape. As a pattern, it matches a non-empty
-- shape via 'listrUncons', binding the first dimension and the remaining
-- shape, and it provides @n + 1 ~ n1@.
pattern (:$:)
  :: forall {n1} {i}.
     forall n. (n + 1 ~ n1)
  => i -> ShR n i -> ShR n1 i
pattern i :$: sh <- ShR (listrUncons -> Just (UnconsListRRes (ShR -> sh) i))
  where i :$: (ShR sh) = ShR (i ::: sh)
infixr 3 :$:

-- Tell the pattern-match coverage checker that 'ZSR' and @(:$:)@ together
-- cover every 'ShR'.
{-# COMPLETE ZSR, (:$:) #-}

-- | A shape with 'Int' dimension sizes.
type IShR n = ShR n Int

-- | Shown exactly like the underlying 'ListR', e.g. @[2,3]@.
instance Show i => Show (ShR n i) where
  showsPrec d (ShR l) = showsPrec d l

-- | Convert a mixed shape whose type marks every dimension as unknown
-- (@Replicate n Nothing@) to a rank-typed shape of rank @n@.
--
-- Several type equalities are asserted with 'unsafeCoerceRefl' rather than
-- proven. Their soundness relies on the input type being exactly
-- @Replicate n Nothing@.
shCvtXR' :: forall n. IShX (Replicate n Nothing) -> IShR n
-- An empty mixed shape means @Replicate n Nothing@ is empty, so @n@ is 0
-- (asserted).
shCvtXR' ZSX =
  castWith (subst2 (unsafeCoerceRefl :: 0 :~: n))
    ZSR
-- Non-empty case: convert the first dimension with 'fromSMayNat'' and recurse
-- on the tail. The tail is first cast to @Replicate (Rank sh) Nothing@ via
-- 'lem2'. The result is cast back to rank @n@ via 'lem1'.
shCvtXR' (n :$% (idx :: IShX sh))
  | Refl <- lemReplicateSucc @(Nothing @Nat) @(n - 1) =
  castWith (subst2 (lem1 @sh Refl))
    (fromSMayNat' n :$: shCvtXR' (castWith (subst2 (lem2 Refl)) idx))
  where
    -- Asserted: if @k : sh'@ is @n'@ copies of @Nothing@, then @n'@ is one
    -- more than the rank of @sh'@.
    lem1 :: forall sh' n' k.
            k : sh' :~: Replicate n' Nothing
         -> Rank sh' + 1 :~: n'
    lem1 Refl = unsafeCoerceRefl

    -- Asserted: the tail of a list of @Nothing@s is itself @Rank sh@ copies
    -- of @Nothing@.
    lem2 :: k : sh :~: Replicate n Nothing
         -> sh :~: Replicate (Rank sh) Nothing
    lem2 Refl = unsafeCoerceRefl

-- | Convert a rank-typed shape to a mixed shape in which every dimension is
-- wrapped in 'SUnknown'. Its type therefore records every dimension as
-- @Nothing@.
shCvtRX :: IShR n -> IShX (Replicate n Nothing)
shCvtRX ZSR = ZSX
-- 'castWith' rewrites the type index of the rebuilt shape. It uses the
-- equality supplied by 'lemReplicateSucc' at @m@ (NOTE(review): presumably
-- relating @Replicate (m + 1) Nothing@ and @Nothing : Replicate m Nothing@).
shCvtRX (n :$: (idx :: ShR m Int)) =
  castWith (subst2 @ShX @Int (lemReplicateSucc @(Nothing @Nat) @m))
    (SUnknown n :$% shCvtRX idx)

-- | The number of elements in an array described by this shape: the product
-- of all dimension sizes, which is 1 for the rank-0 shape.
shrSize :: IShR n -> Int
shrSize = product

-- | The first dimension of a non-empty shape.
shrHead :: ShR (n + 1) i -> i
shrHead sh = case sh of ShR l -> listrHead l

-- | All dimensions of a non-empty shape except the first.
shrTail :: ShR (n + 1) i -> ShR n i
shrTail sh = case sh of ShR l -> ShR (listrTail l)

-- | All dimensions of a non-empty shape except the last.
shrInit :: ShR (n + 1) i -> ShR n i
shrInit sh = case sh of ShR l -> ShR (listrInit l)

-- | The last dimension of a non-empty shape.
shrLast :: ShR (n + 1) i -> i
shrLast sh = case sh of ShR l -> listrLast l

-- | Concatenate two shapes.
shrAppend :: forall n m i. ShR n i -> ShR m i -> ShR (n + m) i
shrAppend (ShR xs) (ShR ys) = ShR (listrAppend xs ys)

-- | The rank of the shape as a type-level singleton.
--
-- This function can also be used to conjure up a 'KnownNat' dictionary:
-- matching the returned 'SNat' with the 'pattern SNat' pattern synonym
-- yields 'KnownNat' evidence.
shrRank :: ShR n i -> SNat n
shrRank sh = case sh of ShR l -> listrRank l

-- | Permute a prefix of the shape; see 'listrPermutePrefix'.
shrPermutePrefix :: forall n i. [Int] -> ShR n i -> ShR n i
shrPermutePrefix perm (ShR l) = ShR (listrPermutePrefix perm l)


-- | Untyped: length is checked at runtime.
--
-- 'fromList' raises an error if the list length differs from the type-level
-- length @n@. 'toList' returns the elements in order.
instance KnownNat n => IsList (ListR n i) where
  type Item (ListR n i) = i
  fromList input = build (SNat @n) input
    where
      -- Consume one list element per unit of the expected length.
      build :: SNat n' -> [i] -> ListR n' i
      build SZ [] = ZR
      build (SS remaining) (x : xs) = x ::: build remaining xs
      build _ _ = error $ "IsList(ListR): Mismatched list length (type says "
                            ++ show (fromSNat (SNat @n)) ++ ", list has length "
                            ++ show (length input) ++ ")"
  toList list = Foldable.toList list

-- | Untyped: length is checked at runtime (via the 'ListR' instance).
instance KnownNat n => IsList (IxR n i) where
  type Item (IxR n i) = i
  fromList xs = IxR (IsList.fromList xs)
  toList (IxR l) = Foldable.toList l

-- | Untyped: length is checked at runtime (via the 'ListR' instance).
instance KnownNat n => IsList (ShR n i) where
  type Item (ShR n i) = i
  fromList xs = ShR (IsList.fromList xs)
  toList (ShR l) = Foldable.toList l


-- | A list indexed by a type-level list of naturals @sh@. The element at
-- position j has type @f n@, where @n@ is the j-th entry of @sh@.
--
-- This is the shared representation of shape-typed indices ('IxS', with
-- @f = Const i@) and shapes ('ShS', with @f = SNat@).
type role ListS nominal representational
type ListS :: [Nat] -> (Nat -> Type) -> Type
data ListS sh f where
  -- | The empty list, indexed by @'[]@.
  ZS :: ListS '[] f
  -- TODO: when the KnownNat constraint is removed, restore listsIndex to sanity
  -- | Prepend an element. Each cons cell stores 'KnownNat' evidence for its
  -- index @n@.
  (::$) :: forall n sh {f}. KnownNat n => f n -> ListS sh f -> ListS (n : sh) f
deriving instance (forall n. Eq (f n)) => Eq (ListS sh f)
deriving instance (forall n. Ord (f n)) => Ord (ListS sh f)
infixr 3 ::$

-- | Renders as a bracketed, comma-separated list. The precedence argument is
-- ignored because the output is always bracketed.
instance (forall n. Show (f n)) => Show (ListS sh f) where
  showsPrec _ xs = listsShow (showsPrec 0) xs

-- | Result of 'listsUncons': the tail and head of a non-empty list, plus
-- 'KnownNat' evidence for the head's index and the shape equation
-- @n : sh ~ sh1@.
data UnconsListSRes f sh1 =
  forall n sh. (KnownNat n, n : sh ~ sh1) => UnconsListSRes (ListS sh f) (f n)

-- | Split a list into its tail and head, or return 'Nothing' for 'ZS'.
listsUncons :: ListS sh1 f -> Maybe (UnconsListSRes f sh1)
listsUncons list = case list of
  ZS -> Nothing
  x ::$ rest -> Just (UnconsListSRes rest x)

-- | Decide whether two lists have the same type-level index list.
--
-- Succeeds only if both lists have the same length and 'testEquality'
-- succeeds on every pair of corresponding elements. Stops at the first
-- mismatch.
listsEq :: TestEquality f => ListS sh f -> ListS sh' f -> Maybe (sh :~: sh')
listsEq ZS ZS = Just Refl
listsEq (x ::$ xs) (y ::$ ys) = do
  Refl <- testEquality x y
  Refl <- listsEq xs ys
  pure Refl
listsEq _ _ = Nothing

-- | Apply an index-preserving transformation to every element.
listsFmap :: (forall n. f n -> g n) -> ListS sh f -> ListS sh g
listsFmap nat list = case list of
  ZS -> ZS
  x ::$ rest -> nat x ::$ listsFmap nat rest

-- | Map every element to a monoid and combine the results from the right:
-- @f x1 <> (f x2 <> (... <> mempty))@.
listsFold :: Monoid m => (forall n. f n -> m) -> ListS sh f -> m
listsFold toM list = case list of
  ZS -> mempty
  x ::$ rest -> toM x <> listsFold toM rest

-- | Render a list as @[x1,x2,...]@, using the given function for each
-- element. The empty list renders as @[]@.
listsShow :: forall sh f. (forall n. f n -> ShowS) -> ListS sh f -> ShowS
listsShow showElem list = showChar '[' . elems list . showChar ']'
  where
    -- The first element is printed without a leading separator.
    elems :: ListS sh' f -> ShowS
    elems ZS = id
    elems (y ::$ ys) = showElem y . rest ys

    -- Every later element is preceded by a comma.
    rest :: ListS sh' f -> ShowS
    rest ZS = id
    rest (y ::$ ys) = showChar ',' . showElem y . rest ys

-- | The elements of a 'Const'-valued list as an ordinary list, in order.
listsToList :: ListS sh (Const i) -> [i]
listsToList = listsFold (\(Const x) -> [x])

-- | The first element of a non-empty list.
listsHead :: ListS (n : sh) f -> f n
listsHead list = case list of x ::$ _ -> x

-- | All elements of a non-empty list except the first.
listsTail :: ListS (n : sh) f -> ListS sh f
listsTail list = case list of _ ::$ rest -> rest

-- | All elements of a non-empty list except the last. The index list shrinks
-- accordingly via 'Init'.
listsInit :: ListS (n : sh) f -> ListS (Init (n : sh)) f
listsInit (x ::$ rest) = case rest of
  ZS -> ZS
  _ ::$ _ -> x ::$ listsInit rest

-- | The last element of a non-empty list; its index is given by 'Last'.
listsLast :: ListS (n : sh) f -> f (Last (n : sh))
listsLast (x ::$ rest) = case rest of
  ZS -> x
  _ ::$ _ -> listsLast rest

-- | Concatenate two lists; the index lists are concatenated as well.
listsAppend :: ListS sh f -> ListS sh' f -> ListS (sh ++ sh') f
listsAppend front back = case front of
  ZS -> back
  x ::$ rest -> x ::$ listsAppend rest back

-- | The length of the list as a type-level singleton. Computing it walks
-- the whole list.
listsRank :: ListS sh f -> SNat (Rank sh)
listsRank list = case list of
  ZS -> SNat
  _ ::$ rest -> snatSucc (listsRank rest)

-- | Keep as many leading elements as the permutation is long. Errors at
-- runtime if the permutation is longer than the list.
listsTakeLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (TakeLen is sh) f
listsTakeLenPerm perm list = case perm of
  PNil -> ZS
  PCons _ perm' -> case list of
    x ::$ rest -> x ::$ listsTakeLenPerm perm' rest
    ZS -> error "Permutation longer than shape"

-- | Drop as many leading elements as the permutation is long. Errors at
-- runtime if the permutation is longer than the list.
listsDropLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (DropLen is sh) f
listsDropLenPerm perm list = case perm of
  PNil -> list
  PCons _ perm' -> case list of
    _ ::$ rest -> listsDropLenPerm perm' rest
    ZS -> error "Permutation longer than shape"

-- | Reorder a list by a permutation. Element @j@ of the result is the
-- element of the input at the position given by the j-th entry of @is@.
-- Each step looks up one element with 'listsIndex'.
listsPermute :: forall f is sh. Perm is -> ListS sh f -> ListS (Permute is sh) f
listsPermute PNil _ = ZS
listsPermute (i `PCons` (is :: Perm is')) (sh :: ListS sh f) =
  -- The 'SNat' returned by 'listsIndex' is matched only to bring the
  -- 'KnownNat' evidence that @(::$)@ requires into scope.
  case listsIndex (Proxy @is') (Proxy @sh) i sh of
    (item, SNat) -> item ::$ listsPermute is sh

-- | The element at 0-based position @i@, together with a singleton for its
-- type-level index. The singleton carries the 'KnownNat' evidence callers
-- such as 'listsPermute' need.
--
-- The two 'Proxy' arguments are never inspected; they are only passed on to
-- the recursive call. The position is not bounds-checked statically: running
-- off the end hits the 'ZS' case and errors.
--
-- TODO: remove this SNat when the KnownNat constraint in ListS is removed
listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> (f (Index i sh), SNat (Index i sh))
listsIndex _ _ SZ (n ::$ _) = (n, SNat)
-- 'lemIndexSucc' supplies the equality relating the index at @i' + 1@ in
-- @n : sh'@ to the index at @i'@ in @sh'@ (NOTE(review): presumably
-- @Index (i' + 1) (n : sh') :~: Index i' sh'@).
listsIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::$ (sh :: ListS sh' f))
  | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
  = listsIndex p pT i sh
listsIndex _ _ _ ZS = error "Index into empty shape"

-- | Permute as many leading elements as the permutation is long, and keep
-- the remaining elements unchanged.
listsPermutePrefix :: forall f is sh. Perm is -> ListS sh f -> ListS (PermutePrefix is sh) f
listsPermutePrefix perm sh = listsAppend permutedPrefix suffix
  where
    permutedPrefix = listsPermute perm (listsTakeLenPerm perm sh)
    suffix = listsDropLenPerm perm sh


-- | An index into a shape-typed array.
--
-- For convenience, the components are plain values of type @i@ (usually
-- 'Int', see 'IIxS') instead of bounded integers (traditionally called
-- \"@Fin@\"), so they are not checked against the dimension sizes in @sh@.
-- The representation is a 'ListS' of 'Const' values. Build and match
-- indices with the 'ZIS' and @(:.$)@ pattern synonyms.
type role IxS nominal representational
type IxS :: [Nat] -> Type -> Type
newtype IxS sh i = IxS (ListS sh (Const i))
  deriving (Eq, Ord)

-- | The empty index. Matching on it provides @sh ~ '[]@.
pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i
pattern ZIS = IxS ZS

-- | Prepend a component to an index. As a pattern, it matches a non-empty
-- index via 'listsUncons', providing 'KnownNat' evidence for the first
-- dimension and @n : sh ~ sh1@. When used to build an index, the
-- 'KnownNat' constraint must be satisfied by the caller.
pattern (:.$)
  :: forall {sh1} {i}.
     forall n sh. (KnownNat n, n : sh ~ sh1)
  => i -> IxS sh i -> IxS sh1 i
pattern i :.$ shl <- IxS (listsUncons -> Just (UnconsListSRes (IxS -> shl) (getConst -> i)))
  where i :.$ IxS shl = IxS (Const i ::$ shl)
infixr 3 :.$

-- Tell the pattern-match coverage checker that 'ZIS' and @(:.$)@ together
-- cover every 'IxS'.
{-# COMPLETE ZIS, (:.$) #-}

-- | An index with 'Int' components.
type IIxS sh = IxS sh Int

-- | Renders as a bracketed, comma-separated list of components.
instance Show i => Show (IxS sh i) where
  showsPrec _ (IxS l) = listsShow (shows . getConst) l

-- | Maps over the components; the shape index is unchanged.
instance Functor (IxS sh) where
  fmap f (IxS l) = IxS (listsFmap (\(Const x) -> Const (f x)) l)

-- | Folds over the components in order.
instance Foldable (IxS sh) where
  foldMap f (IxS l) = listsFold (\(Const x) -> f x) l

-- | The index into an array of the given shape whose components are all 0.
ixsZero :: ShS sh -> IIxS sh
ixsZero sh = case sh of
  ZSS -> ZIS
  _ :$$ rest -> 0 :.$ ixsZero rest

-- | Convert a mixed index whose dimensions are all statically known to a
-- shape-typed index. The 'ShS' argument drives the recursion and supplies
-- the 'KnownNat' evidence for each dimension.
ixCvtXS :: ShS sh -> IIxX (MapJust sh) -> IIxS sh
ixCvtXS ZSS ZIX = ZIS
ixCvtXS (_ :$$ restSh) (i :.% restIx) = i :.$ ixCvtXS restSh restIx

-- | View a shape-typed index as a mixed index whose dimensions are all
-- statically known.
ixCvtSX :: IIxS sh -> IIxX (MapJust sh)
ixCvtSX ix = case ix of
  ZIS -> ZIX
  i :.$ rest -> i :.% ixCvtSX rest

-- | The first component of a non-empty index.
ixsHead :: IxS (n : sh) i -> i
ixsHead ix = case ix of IxS l -> getConst (listsHead l)

-- | All components of a non-empty index except the first.
ixsTail :: IxS (n : sh) i -> IxS sh i
ixsTail ix = case ix of IxS l -> IxS (listsTail l)

-- | All components of a non-empty index except the last.
ixsInit :: IxS (n : sh) i -> IxS (Init (n : sh)) i
ixsInit ix = case ix of IxS l -> IxS (listsInit l)

-- | The last component of a non-empty index.
ixsLast :: IxS (n : sh) i -> i
ixsLast ix = case ix of IxS l -> getConst (listsLast l)

-- | Concatenate two indices.
ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i
ixsAppend (IxS xs) (IxS ys) = IxS (listsAppend xs ys)

-- | Permute a prefix of the index; see 'listsPermutePrefix'.
ixsPermutePrefix :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i
ixsPermutePrefix perm (IxS l) = IxS (listsPermutePrefix perm l)


-- | The shape of a shape-typed array given as a list of 'SNat' values.
--
-- Note that because the shape of a shape-typed array is known statically, you
-- can also retrieve the array shape from a 'KnownShS' dictionary.
--
-- The role of @sh@ is nominal, so 'coerce' cannot change the shape.
type role ShS nominal
type ShS :: [Nat] -> Type
newtype ShS sh = ShS (ListS sh SNat)
  deriving (Eq, Ord)

-- | The empty shape. Matching on it provides @sh ~ '[]@.
pattern ZSS :: forall sh. () => sh ~ '[] => ShS sh
pattern ZSS = ShS ZS

-- | Prepend a dimension to a shape. As a pattern, it matches a non-empty
-- shape via 'listsUncons', providing 'KnownNat' evidence for the first
-- dimension and @n : sh ~ sh1@. When used to build a shape, the 'KnownNat'
-- constraint must be satisfied by the caller.
pattern (:$$)
  :: forall {sh1}.
     forall n sh. (KnownNat n, n : sh ~ sh1)
  => SNat n -> ShS sh -> ShS sh1
pattern i :$$ shl <- ShS (listsUncons -> Just (UnconsListSRes (ShS -> shl) i))
  where i :$$ ShS shl = ShS (i ::$ shl)

infixr 3 :$$

-- Tell the pattern-match coverage checker that 'ZSS' and @(:$$)@ together
-- cover every 'ShS'.
{-# COMPLETE ZSS, (:$$) #-}

-- | Renders the dimension sizes as a bracketed, comma-separated list.
instance Show (ShS sh) where
  showsPrec _ (ShS l) = listsShow (\sn -> shows (fromSNat sn)) l

-- | Two shapes are equal iff they have the same length and every pair of
-- corresponding dimensions is equal (see 'listsEq').
instance TestEquality ShS where
  testEquality a b = case (a, b) of
    (ShS xs, ShS ys) -> listsEq xs ys

-- | The number of dimensions, as an 'Int'.
shsLength :: ShS sh -> Int
shsLength sh = length (shsToList sh)

-- | The number of dimensions as a type-level singleton.
shsRank :: ShS sh -> SNat (Rank sh)
shsRank sh = case sh of ShS l -> listsRank l

-- | The dimension sizes as 'Int's, in list order.
shsToList :: ShS sh -> [Int]
shsToList (ShS l) = listsFold (\sn -> [fromSNat' sn]) l

-- | Convert a mixed shape whose type says every dimension is statically known
-- (@MapJust sh@) to a shape-typed 'ShS'.
--
-- Equalities relating @sh@ to the structure of @MapJust sh@ are asserted
-- with 'unsafeCoerceRefl'; presumably GHC cannot derive them by inverting
-- the type family 'MapJust'. They are only sound for inputs whose type
-- really is @MapJust sh@.
shCvtXS' :: forall sh. IShX (MapJust sh) -> ShS sh
-- An empty mixed shape means @MapJust sh@ is empty, so @sh@ is @'[]@
-- (asserted).
shCvtXS' ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS
-- Non-empty case: keep the known dimension @n@ and recurse on the tail at
-- @Tail sh@. The tail's type index is cast to @MapJust (Tail sh)@, and
-- 'lem' asserts that @n : Tail sh@ is @sh@.
shCvtXS' (SKnown n@SNat :$% (idx :: IShX mjshT)) =
  castWith (subst1 (lem Refl)) $
    n :$$ shCvtXS' @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh)))
                                 idx)
  where
    lem :: forall sh1 sh' n.
           Just n : sh1 :~: MapJust sh'
        -> n : Tail sh' :~: sh'
    lem Refl = unsafeCoerceRefl
-- Entries of @MapJust sh@ are expected to be @Just@, so an 'SUnknown'
-- dimension should not occur for well-typed input.
shCvtXS' (SUnknown _ :$% _) = error "impossible"

-- | View a shape-typed shape as a mixed shape in which every dimension is
-- wrapped in 'SKnown'.
shCvtSX :: ShS sh -> IShX (MapJust sh)
shCvtSX sh = case sh of
  ZSS -> ZSX
  sn :$$ rest -> SKnown sn :$% shCvtSX rest

-- | The first dimension of a non-empty shape.
shsHead :: ShS (n : sh) -> SNat n
shsHead sh = case sh of ShS l -> listsHead l

-- | All dimensions of a non-empty shape except the first.
shsTail :: ShS (n : sh) -> ShS sh
shsTail sh = case sh of ShS l -> ShS (listsTail l)

-- | All dimensions of a non-empty shape except the last.
shsInit :: ShS (n : sh) -> ShS (Init (n : sh))
shsInit sh = case sh of ShS l -> ShS (listsInit l)

-- | The last dimension of a non-empty shape.
shsLast :: ShS (n : sh) -> SNat (Last (n : sh))
shsLast sh = case sh of ShS l -> listsLast l

-- | Concatenate two shapes.
shsAppend :: forall sh sh'. ShS sh -> ShS sh' -> ShS (sh ++ sh')
shsAppend (ShS xs) (ShS ys) = ShS (listsAppend xs ys)

-- | The number of elements in an array of this shape: the product of all
-- dimension sizes, which is 1 for the empty shape.
shsSize :: ShS sh -> Int
shsSize sh = product (shsToList sh)

-- | Keep as many leading dimensions as the permutation is long.
shsTakeLen :: Perm is -> ShS sh -> ShS (TakeLen is sh)
shsTakeLen perm (ShS l) = ShS (listsTakeLenPerm perm l)

-- | Reorder the dimensions by a permutation; see 'listsPermute'.
shsPermute :: Perm is -> ShS sh -> ShS (Permute is sh)
shsPermute perm (ShS l) = ShS (listsPermute perm l)

-- | The dimension at 0-based position @i@; see 'listsIndex'.
shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> SNat (Index i sh)
shsIndex pis pshT i (ShS l) = fst (listsIndex pis pshT i l)

-- | Permute a prefix of the dimensions; see 'listsPermutePrefix'.
shsPermutePrefix :: forall is sh. Perm is -> ShS sh -> ShS (PermutePrefix is sh)
shsPermutePrefix perm (ShS l) = ShS (listsPermutePrefix perm l)

-- | The product of a type-level list of naturals; 1 for the empty list.
type family Product sh where
  Product '[] = 1
  Product (n : ns) = n * Product ns

-- | The number of elements of an array of this shape, as a type-level
-- singleton.
shsProduct :: ShS sh -> SNat (Product sh)
shsProduct sh = case sh of
  ZSS -> SNat
  sn :$$ rest -> snatMul sn (shsProduct rest)

-- | Evidence for the static part of a shape. This pops up only when you are
-- polymorphic in the element type of an array.
type KnownShS :: [Nat] -> Constraint
-- The single method 'knownShS' yields the shape @sh@ as a runtime 'ShS'.
class KnownShS sh where knownShS :: ShS sh
-- The empty shape.
instance KnownShS '[] where knownShS = ZSS
-- A non-empty shape is assembled from 'KnownNat' evidence for its first
-- dimension and 'KnownShS' evidence for the remaining dimensions.
instance (KnownNat n, KnownShS sh) => KnownShS (n : sh) where knownShS = natSing :$$ knownShS

-- | Run a computation that needs a 'KnownShS' dictionary, supplying it from a
-- runtime shape. 'KnownShS' has the single method 'knownShS', so 'withDict'
-- installs the given 'ShS' as that method.
withKnownShS :: forall sh r. ShS sh -> (KnownShS sh => r) -> r
withKnownShS sh k = withDict @(KnownShS sh) sh k

-- | Reify a runtime shape as 'KnownShS' evidence. It walks the shape and
-- rebuilds the instance one dimension at a time.
shsKnownShS :: ShS sh -> Dict KnownShS sh
shsKnownShS sh = case sh of
  ZSS -> Dict
  SNat :$$ rest -> case shsKnownShS rest of
    Dict -> Dict

-- | Produce the orthotope library's 'O.Shape' evidence for a runtime shape.
-- It walks the shape and collects 'KnownNat' evidence for each dimension.
shsOrthotopeShape :: ShS sh -> Dict O.Shape sh
shsOrthotopeShape sh = case sh of
  ZSS -> Dict
  SNat :$$ rest -> case shsOrthotopeShape rest of
    Dict -> Dict


-- | Untyped: length is checked at runtime.
--
-- 'fromList' raises an error if the list length differs from the length of
-- the statically known shape @sh@. Element values are not checked.
instance KnownShS sh => IsList (ListS sh (Const i)) where
  type Item (ListS sh (Const i)) = i
  fromList input = build (knownShS @sh) input
    where
      -- Walk the expected shape and the input list in lockstep.
      build :: ShS sh' -> [i] -> ListS sh' (Const i)
      build ZSS [] = ZS
      build (_ :$$ rest) (x : xs) = Const x ::$ build rest xs
      build _ _ = error $ "IsList(ListS): Mismatched list length (type says "
                            ++ show (shsLength (knownShS @sh)) ++ ", list has length "
                            ++ show (length input) ++ ")"
  toList list = listsToList list

-- | Very untyped: only length is checked (at runtime), index bounds are __not checked__.
instance KnownShS sh => IsList (IxS sh i) where
  type Item (IxS sh i) = i
  fromList xs = IxS (IsList.fromList xs)
  toList (IxS l) = listsToList l

-- | Untyped: length and values are checked at runtime.
--
-- 'fromList' raises an error if the list length differs from the rank of
-- @sh@, or if any entry differs from the corresponding static dimension size.
instance KnownShS sh => IsList (ShS sh) where
  type Item (ShS sh) = Int
  fromList input = ShS (check (knownShS @sh) input)
    where
      -- Walk the expected shape and the input list in lockstep, verifying
      -- each value against the static dimension.
      check :: ShS sh' -> [Int] -> ListS sh' SNat
      check ZSS [] = ZS
      check (sn :$$ rest) (x : xs)
        | x /= fromSNat' sn = error $ "IsList(ShS): Value does not match typing (type says "
                                        ++ show (fromSNat' sn) ++ ", list contains " ++ show x ++ ")"
        | otherwise = sn ::$ check rest xs
      check _ _ = error $ "IsList(ShS): Mismatched list length (type says "
                            ++ show (shsLength (knownShS @sh)) ++ ", list has length "
                            ++ show (length input) ++ ")"
  toList sh = shsToList sh