aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Internal/Shaped.hs
blob: d8c80d13822571f3ecd70d5c7a1871e7eaa3c568 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Nested.Internal.Shaped where

import Prelude hiding (mappend, mconcat)

import Control.DeepSeq (NFData)
import Control.Monad.ST
import Data.Array.Internal.ShapedS qualified as SS
import Data.Array.Internal.ShapedG qualified as SG
import Data.Array.Internal.RankedS qualified as RS
import Data.Array.Internal.RankedG qualified as RG
import Data.Bifunctor (first)
import Data.Coerce (coerce)
import Data.Kind (Type)
import Data.List.NonEmpty (NonEmpty)
import Data.Proxy
import Data.Type.Equality
import Data.Vector.Storable qualified as VS
import Foreign.Storable (Storable)
import GHC.Float qualified (log1p, expm1, log1pexp, log1mexp)
import GHC.TypeLits

import Data.Array.Mixed.XArray (XArray)
import Data.Array.Mixed.XArray qualified as X
import Data.Array.Mixed.Internal.Arith
import Data.Array.Mixed.Lemmas
import Data.Array.Mixed.Permutation
import Data.Array.Mixed.Shape
import Data.Array.Mixed.Types
import Data.Array.Nested.Internal.Lemmas
import Data.Array.Nested.Internal.Mixed
import Data.Array.Nested.Internal.Shape


-- | A shape-typed array: the full shape of the array (the sizes of its
-- dimensions) is represented on the type level as a list of 'Nat's. Note that
-- these are "GHC.TypeLits" naturals, because we do not need induction over
-- them and we want very large arrays to be possible.
--
-- Like for 'Ranked', the valid elements are described by the 'Elt' type class,
-- and 'Shaped' itself is again an instance of 'Elt' as well.
--
-- 'Shaped' is a newtype around a 'Mixed' of 'Just's: a @'Shaped' sh a@ wraps a
-- @'Mixed' ('MapJust' sh) a@, i.e. a mixed array in which every dimension is
-- marked as statically known.
type Shaped :: [Nat] -> Type -> Type
newtype Shaped sh a = Shaped (Mixed (MapJust sh) a)
-- The following instances exist whenever the corresponding instance exists for
-- the wrapped @'Mixed' ('MapJust' sh) a@, and act on that wrapped array.
deriving instance Eq (Mixed (MapJust sh) a) => Eq (Shaped sh a)
deriving instance Ord (Mixed (MapJust sh) a) => Ord (Shaped sh a)
deriving instance NFData (Mixed (MapJust sh) a) => NFData (Shaped sh a)

-- | Shows a shaped array as a call to 'sfromListLinear': first the shape,
-- rendered as a plain list of dimension sizes, then the element list produced
-- by 'stoListLinear'. Parenthesised when used as a function argument.
instance (Show a, Elt a) => Show (Shaped sh a) where
  showsPrec d arr =
    showParen (d > 10) $
      showString "sfromListLinear "
        . shows dims
        . showString " "
        . shows elems
    where
      dims = shsToList (sshape arr)
      elems = stoListLinear arr

-- A mixed array of shaped arrays is represented by unwrapping the 'Shaped'
-- newtype and deferring to the general instance for nested arrays
-- (@'Mixed' sh ('Mixed' ('MapJust' sh') a)@).
newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh') a))
-- 'Show' is borrowed from 'ShowViaToListLinear' via DerivingVia.
deriving via (ShowViaToListLinear sh (Shaped sh' a)) instance (Show a, Elt a) => Show (Mixed sh (Shaped sh' a))

-- The mutable vectors used while building such an array are likewise those of
-- the unwrapped nested 'Mixed' representation.
newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh') a))

-- | 'Shaped' arrays can themselves be elements of a 'Mixed' array. Every method
-- unwraps the 'M_Shaped' / 'MV_Shaped' / 'Shaped' newtypes (mostly via
-- 'coerce'), delegates to the 'Elt' instance of @'Mixed' ('MapJust' sh) a@,
-- and wraps the result back up.
instance Elt a => Elt (Shaped sh a) where
  mshape (M_Shaped arr) = mshape arr
  mindex (M_Shaped arr) i = Shaped (mindex arr i)

  -- The explicit signature scopes sh1/sh2 so the coercion target can be named.
  mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IIxX sh1 -> Mixed sh2 (Shaped sh a)
  mindexPartial (M_Shaped arr) i =
    coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
      mindexPartial arr i

  -- A scalar shaped element is stored as a rank-0 nest ('M_Nest' with the
  -- empty shape 'ZSX') around the wrapped 'Mixed' array.
  mscalar (Shaped x) = M_Shaped (M_Nest ZSX x)

  mfromListOuter :: forall sh'. NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Nothing : sh') (Shaped sh a)
  mfromListOuter l = M_Shaped (mfromListOuter (coerce l))

  mtoListOuter :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)]
  mtoListOuter (M_Shaped arr)
    = coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoListOuter arr)

  -- The lifted XArray function is passed through unchanged; only the
  -- element-type wrapper around the result differs.
  mlift :: forall sh1 sh2.
           StaticShX sh2
        -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
        -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a)
  mlift ssh2 f (M_Shaped arr) =
    coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
      mlift ssh2 f arr

  mlift2 :: forall sh1 sh2 sh3.
            StaticShX sh3
         -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
         -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) -> Mixed sh3 (Shaped sh a)
  mlift2 ssh3 f (M_Shaped arr1) (M_Shaped arr2) =
    coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $
      mlift2 ssh3 f arr1 arr2

  mliftL :: forall sh1 sh2.
            StaticShX sh2
         -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b))
         -> NonEmpty (Mixed sh1 (Shaped sh a)) -> NonEmpty (Mixed sh2 (Shaped sh a))
  mliftL ssh2 f l =
    coerce @(NonEmpty (Mixed sh2 (Mixed (MapJust sh) a)))
           @(NonEmpty (Mixed sh2 (Shaped sh a))) $
      mliftL ssh2 f (coerce l)

  mcast ssh1 sh2 psh' (M_Shaped arr) = M_Shaped (mcast ssh1 sh2 psh' arr)

  mtranspose perm (M_Shaped arr) = M_Shaped (mtranspose perm arr)

  mconcat l = M_Shaped (mconcat (coerce l))

  -- The shape tree of a shaped element pairs its static shape with the shape
  -- tree of its own elements.
  type ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a)

  mshapeTree (Shaped arr) = first shCvtXS' (mshapeTree arr)

  -- Equal when both the shapes and the element shape trees agree.
  mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2

  -- Empty when the shape has zero elements and the element shape tree is
  -- empty as well.
  mshapeTreeEmpty _ (sh, t) = shsSize sh == 0 && mshapeTreeEmpty (Proxy @a) t

  -- Rendered as "(shape, subtree)".
  mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"

  -- Note: in the vector methods below, the term variable @sh@ is the outer
  -- index shape, while the type variable @sh@ is the element's static shape.
  mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s ()
  mvecsWrite sh idx (Shaped arr) vecs =
    mvecsWrite sh idx arr
      (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a))
         vecs)

  mvecsWritePartial :: forall sh1 sh2 s.
                       IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a)
                    -> MixedVecs s (sh1 ++ sh2) (Shaped sh a)
                    -> ST s ()
  mvecsWritePartial sh idx arr vecs =
    mvecsWritePartial sh idx
      (coerce @(Mixed sh2 (Shaped sh a))
              @(Mixed sh2 (Mixed (MapJust sh) a))
         arr)
      (coerce @(MixedVecs s (sh1 ++ sh2) (Shaped sh a))
              @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a))
         vecs)

  mvecsFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a))
  mvecsFreeze sh vecs =
    coerce @(Mixed sh' (Mixed (MapJust sh) a))
           @(Mixed sh' (Shaped sh a))
      <$> mvecsFreeze sh
            (coerce @(MixedVecs s sh' (Shaped sh a))
                    @(MixedVecs s sh' (Mixed (MapJust sh) a))
                    vecs)

-- | Allocation methods for shaped elements, delegating to the 'KnownElt'
-- instance of @'Mixed' ('MapJust' sh) a@.
--
-- Each method first matches on @'lemKnownMapJust' ('Proxy' \@sh)@. The
-- resulting 'Dict' presumably supplies the "known shape" evidence for
-- @'MapJust' sh@ that the inner 'Mixed' instance requires. NOTE(review):
-- confirm against the lemma's definition.
instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where
  memptyArray :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a)
  memptyArray i
    | Dict <- lemKnownMapJust (Proxy @sh)
    = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $
        memptyArray i

  mvecsUnsafeNew idx (Shaped arr)
    | Dict <- lemKnownMapJust (Proxy @sh)
    = MV_Shaped <$> mvecsUnsafeNew idx arr

  mvecsNewEmpty _
    | Dict <- lemKnownMapJust (Proxy @sh)
    = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a))


-- | Lift a shape-polymorphic unary operation on 'Mixed' arrays to 'Shaped'
-- arrays: unwrap, apply the operation at shape @'MapJust' sh@, re-wrap.
arithPromoteShaped :: forall sh a b.
                      (forall shx. Mixed shx a -> Mixed shx b)
                   -> Shaped sh a -> Shaped sh b
arithPromoteShaped op (Shaped inner) = Shaped (op inner)

-- | Binary counterpart of 'arithPromoteShaped': unwrap both operands, apply
-- the shape-polymorphic 'Mixed' operation, and re-wrap the result.
arithPromoteShaped2 :: forall sh a b c.
                       (forall shx. Mixed shx a -> Mixed shx b -> Mixed shx c)
                    -> Shaped sh a -> Shaped sh b -> Shaped sh c
arithPromoteShaped2 op (Shaped lhs) (Shaped rhs) = Shaped (op lhs rhs)

-- | Element-wise arithmetic on shaped arrays, lifted from the 'Mixed'
-- instance.
--
-- 'fromInteger' only succeeds at the empty shape @'[]@ (a scalar). At any
-- other shape it throws; use 'sreplicateScal' or 'sscalar' explicitly instead.
--
-- TODO: 'KnownShS' is only there for 'fromInteger'.
instance (NumElt a, PrimElt a, Num a, KnownShS sh) => Num (Shaped sh a) where
  fromInteger
    | ZSS <- knownShS @sh = sscalar . fromInteger
    | otherwise = error "Data.Array.Nested.fromInteger: No singletons available, use explicit sreplicateScal/sscalar"

  -- binary operators
  (+) = arithPromoteShaped2 (+)
  (-) = arithPromoteShaped2 (-)
  (*) = arithPromoteShaped2 (*)

  -- unary operators
  negate = arithPromoteShaped negate
  abs = arithPromoteShaped abs
  signum = arithPromoteShaped signum

-- | Element-wise division and reciprocal, lifted from the 'Mixed' instance.
--
-- 'fromRational' only succeeds at the empty shape @'[]@; at any other shape it
-- throws, pointing the caller at 'sreplicateScal' / 'sscalar'.
--
-- TODO: 'KnownShS' is only there for 'fromRational'.
instance (FloatElt a, NumElt a, PrimElt a, Fractional a, KnownShS sh) => Fractional (Shaped sh a) where
  (/) = arithPromoteShaped2 (/)
  recip = arithPromoteShaped recip

  fromRational
    | ZSS <- knownShS @sh = sscalar . fromRational
    | otherwise = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicateScal/sscalar"

-- | Element-wise floating-point functions, lifted from the 'Mixed' instance.
--
-- 'pi' only succeeds at the empty shape @'[]@; at any other shape it throws,
-- pointing the caller at 'sreplicateScal' / 'sscalar'.
--
-- TODO: 'KnownShS' is only there for 'pi'.
instance (FloatElt a, NumElt a, PrimElt a, Floating a, KnownShS sh) => Floating (Shaped sh a) where
  pi
    | ZSS <- knownShS @sh = sscalar pi
    | otherwise = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicateScal/sscalar"

  -- exponentials, logarithms and powers
  exp = arithPromoteShaped exp
  log = arithPromoteShaped log
  sqrt = arithPromoteShaped sqrt
  (**) = arithPromoteShaped2 (**)
  logBase = arithPromoteShaped2 logBase
  log1p = arithPromoteShaped GHC.Float.log1p
  expm1 = arithPromoteShaped GHC.Float.expm1
  log1pexp = arithPromoteShaped GHC.Float.log1pexp
  log1mexp = arithPromoteShaped GHC.Float.log1mexp

  -- trigonometric functions and their inverses
  sin = arithPromoteShaped sin
  cos = arithPromoteShaped cos
  tan = arithPromoteShaped tan
  asin = arithPromoteShaped asin
  acos = arithPromoteShaped acos
  atan = arithPromoteShaped atan

  -- hyperbolic functions and their inverses
  sinh = arithPromoteShaped sinh
  cosh = arithPromoteShaped cosh
  tanh = arithPromoteShaped tanh
  asinh = arithPromoteShaped asinh
  acosh = arithPromoteShaped acosh
  atanh = arithPromoteShaped atanh


-- | The shape of a shaped array as a value-level 'ShS' singleton, converted
-- from the shape of the wrapped 'Mixed' array.
sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh
sshape sarr = case sarr of
  Shaped inner -> shCvtXS' (mshape inner)

-- | The rank (number of dimensions) of the array as an 'SNat' singleton.
srank :: Elt a => Shaped sh a -> SNat (Rank sh)
srank arr = shsRank (sshape arr)

-- | The total number of elements in the array, computed from its shape.
ssize :: Elt a => Shaped sh a -> Int
ssize arr = shsSize (sshape arr)

-- | Look up a single element by a full index.
sindex :: Elt a => Shaped sh a -> IIxS sh -> a
sindex (Shaped inner) = mindex inner . ixCvtSX

-- | Keep the first @'Rank' sh@ dimensions of a shape @sh ++ sh'@. The index is
-- only walked to decide how many dimensions to keep; its entries are ignored.
-- The proxy pins down @sh'@.
shsTakeIx :: Proxy sh' -> ShS (sh ++ sh') -> IIxS sh -> ShS sh
shsTakeIx p fullsh ix = case ix of
  ZIS -> ZSS
  _ :.$ rest -> case fullsh of
    n :$$ fullrest -> n :$$ shsTakeIx p fullrest rest

-- | Index into the outer dimensions @sh1@ of an array, returning the remaining
-- sub-array of shape @sh2@.
--
-- 'mindexPartial' expects its argument at type
-- @'Mixed' ('MapJust' sh1 ++ 'MapJust' sh2) a@, while the wrapped array has
-- type @'Mixed' ('MapJust' (sh1 ++ sh2)) a@. The cast between the two uses
-- 'lemMapJustApp' (through 'subst2' / 'castWith'). 'shsTakeIx' recovers the
-- @sh1@ prefix of the array's shape that the lemma needs.
sindexPartial :: forall sh1 sh2 a. Elt a => Shaped (sh1 ++ sh2) a -> IIxS sh1 -> Shaped sh2 a
sindexPartial sarr@(Shaped arr) idx =
  Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2)
            (castWith (subst2 (lemMapJustApp (shsTakeIx (Proxy @sh2) (sshape sarr) idx) (Proxy @sh2))) arr)
            (ixCvtSX idx))

-- | Build an array of the given shape by calling the function on indices of
-- that shape.
--
-- __WARNING__: All values returned from the function must have equal shape.
-- See the documentation of 'mgenerate' for more details.
sgenerate :: forall sh a. KnownElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a
sgenerate sh f = Shaped $ mgenerate (shCvtSX sh) $ \ix -> f (ixCvtXS sh ix)

-- | Apply an 'XArray'-level transformation to the underlying storage, with the
-- result shape given by @sh2@. See the documentation of 'mlift'.
slift :: forall sh1 sh2 a. Elt a
      => ShS sh2
      -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b)
      -> Shaped sh1 a -> Shaped sh2 a
slift sh2 f (Shaped arr) = Shaped (mlift target f arr)
  where
    target = ssxFromShape (shCvtSX sh2)

-- | Binary version of 'slift': combine the underlying storage of two arrays,
-- with the result shape given by @sh3@. See the documentation of 'mlift'.
slift2 :: forall sh1 sh2 sh3 a. Elt a
       => ShS sh3
       -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b -> XArray (MapJust sh3 ++ sh') b)
       -> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a
slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 target f arr1 arr2)
  where
    target = ssxFromShape (shCvtSX sh3)

-- | Sum away the outermost dimension of a primitive array (delegates to
-- 'msumOuter1P').
ssumOuter1P :: forall sh n a. (Storable a, NumElt a)
            => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a)
ssumOuter1P sarr = case sarr of
  Shaped arr -> Shaped (msumOuter1P arr)

-- | Sum away the outermost dimension, going through the 'Primitive'
-- representation and back.
ssumOuter1 :: forall sh n a. (NumElt a, PrimElt a)
           => Shaped (n : sh) a -> Shaped sh a
ssumOuter1 arr = sfromPrimitive (ssumOuter1P (stoPrimitive arr))

-- | Sum all elements of the array into a single value (delegates to
-- 'msumAllPrim' on the wrapped 'Mixed' array).
--
-- The shape type variable is named @sh@, as elsewhere in this module. The
-- implicit quantification order (@a@, then the shape) is unchanged, so
-- explicit type applications keep working.
ssumAllPrim :: (PrimElt a, NumElt a) => Shaped sh a -> a
ssumAllPrim (Shaped arr) = msumAllPrim arr

-- | Permute the outer @'Rank' is@ dimensions of the array according to @is@.
-- The result shape is the permuted prefix followed by the remaining
-- @'DropLen' is sh@ dimensions (see the last guard).
--
-- The guards establish the type-level equalities between 'MapJust' and the
-- rank / take / drop / permute type families. Those equalities let the
-- 'mtranspose' result on the wrapped 'Mixed' array be viewed as
-- @'Mixed' ('MapJust' ('PermutePrefix' is sh))@.
stranspose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh, Elt a)
           => Perm is -> Shaped sh a -> Shaped (PermutePrefix is sh) a
stranspose perm sarr@(Shaped arr)
  | Refl <- lemRankMapJust (sshape sarr)
  , Refl <- lemTakeLenMapJust perm (sshape sarr)
  , Refl <- lemDropLenMapJust perm (sshape sarr)
  , Refl <- lemPermuteMapJust perm (shsTakeLen perm (sshape sarr))
  , Refl <- lemMapJustApp (shsPermute perm (shsTakeLen perm (sshape sarr))) (Proxy @(DropLen is sh))
  = Shaped (mtranspose perm arr)

-- | Concatenate two arrays along their outermost dimension; all inner
-- dimensions must coincide.
sappend :: Elt a => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a
sappend (Shaped lhs) (Shaped rhs) = Shaped (mappend lhs rhs)

-- | Wrap a single value as a rank-0 shaped array.
sscalar :: Elt a => a -> Shaped '[] a
sscalar = Shaped . mscalar

-- | View a storable vector as a primitive shaped array of the given shape.
sfromVectorP :: Storable a => ShS sh -> VS.Vector a -> Shaped sh (Primitive a)
sfromVectorP sh = Shaped . mfromVectorP (shCvtSX sh)

-- | Like 'sfromVectorP', but converts the result out of the 'Primitive'
-- wrapper.
sfromVector :: PrimElt a => ShS sh -> VS.Vector a -> Shaped sh a
sfromVector sh = sfromPrimitive . sfromVectorP sh

-- | The elements of a primitive shaped array as a storable vector (delegates
-- to 'mtoVectorP').
stoVectorP :: Storable a => Shaped sh (Primitive a) -> VS.Vector a
stoVectorP (Shaped arr) = mtoVectorP arr

-- | The elements of a shaped array as a storable vector (delegates to
-- 'mtoVector').
stoVector :: PrimElt a => Shaped sh a -> VS.Vector a
stoVector (Shaped arr) = mtoVector arr

-- | Stack a non-empty list of equally-shaped arrays along a new outer
-- dimension of statically known size @n@.
--
-- The list is first stacked with an unknown outer dimension, which 'mcast'
-- then refines to @n@.
sfromListOuter :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a
sfromListOuter sn l = Shaped (mcast unknownOuter knownOuter Proxy stacked)
  where
    unknownOuter = SUnknown () :!% ZKX
    knownOuter = SKnown sn :$% ZSX
    stacked = mfromListOuter (coerce l)

-- | Build a one-dimensional array of static length @n@ from a non-empty list,
-- refining the list's unknown length to @n@ with 'mcast'.
sfromList1 :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a
sfromList1 sn l = Shaped (mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy (mfromList1 l))

-- | Primitive-element variant of 'sfromList1', taking an ordinary list.
sfromList1Prim :: (PrimElt a, Elt a) => SNat n -> [a] -> Shaped '[n] a
sfromList1Prim sn l = Shaped (mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy (mfromList1Prim l))

-- | Split an array along its outermost dimension into a list of sub-arrays.
-- The re-wrapping of each 'Mixed' sub-array is a zero-cost 'coerce'.
stoListOuter :: Elt a => Shaped (n : sh) a -> [Shaped sh a]
stoListOuter sarr = case sarr of
  Shaped arr -> coerce (mtoListOuter arr)

-- | The elements of a one-dimensional array, in order.
stoList1 :: Elt a => Shaped '[n] a -> [a]
stoList1 arr = [sunScalar x | x <- stoListOuter arr]

-- | Build a one-dimensional array of static length @n@ from a list of
-- primitive elements.
--
-- The list is first turned into an 'XArray' with an unknown ('SUnknown') outer
-- dimension. 'X.cast' then refines that dimension to the known size @n@.
-- NOTE(review): presumably 'X.cast' rejects a list whose length is not @n@;
-- confirm in "Data.Array.Mixed.XArray".
sfromListPrim :: forall n a. PrimElt a => SNat n -> [a] -> Shaped '[n] a
sfromListPrim sn l
  -- lemAppNil supplies the equality needed to use the empty trailing
  -- shape ('ZKX') in the X.cast call below.
  | Refl <- lemAppNil @'[Just n]
  = let ssh = SUnknown () :!% ZKX
        xarr = X.cast ssh (SKnown sn :$% ZSX) ZKX (X.fromList1 ssh l)
    in Shaped $ fromPrimitive $ M_Primitive (X.shape (SKnown sn :!% ZKX) xarr) xarr

-- | Build an array of the given shape from a flat list of primitive elements.
-- The list is first made into a one-dimensional array and then reshaped with
-- 'X.reshape'.
sfromListPrimLinear :: PrimElt a => ShS sh -> [a] -> Shaped sh a
sfromListPrimLinear sh l = Shaped (fromPrimitive (M_Primitive shx reshaped))
  where
    shx = shCvtSX sh
    M_Primitive _ flat = toPrimitive (mfromListPrim l)
    reshaped = X.reshape (SUnknown () :!% ZKX) shx flat

-- | Build an array of the given shape from a flat non-empty list of elements
-- (delegates to 'mfromListLinear').
sfromListLinear :: forall sh a. Elt a => ShS sh -> NonEmpty a -> Shaped sh a
sfromListLinear sh = Shaped . mfromListLinear (shCvtSX sh)

-- | All elements of the array as one flat list (delegates to
-- 'mtoListLinear').
stoListLinear :: Elt a => Shaped sh a -> [a]
stoListLinear sarr = case sarr of
  Shaped arr -> mtoListLinear arr

-- | Convert an orthotope shaped array into a 'Shaped' array. The orthotope
-- storage is re-wrapped as a ranked orthotope array whose shape list comes
-- from the given singleton.
sfromOrthotope :: PrimElt a => ShS sh -> SS.Array sh a -> Shaped sh a
sfromOrthotope sh (SS.A (SG.A arr)) = Shaped (fromPrimitive (M_Primitive shx xarr))
  where
    shx = shCvtSX sh
    xarr = X.XArray (RS.A (RG.A (shsToList sh) arr))

-- | Convert a 'Shaped' array back to an orthotope shaped array. The
-- runtime shape list is dropped because orthotope's shaped arrays carry the
-- shape in their type.
stoOrthotope :: PrimElt a => Shaped sh a -> SS.Array sh a
stoOrthotope sarr = case stoPrimitive sarr of
  Shaped (M_Primitive _ (X.XArray (RS.A (RG.A _ storage)))) -> SS.A (SG.A storage)

-- | Extract the only element of a rank-0 array.
sunScalar :: Elt a => Shaped '[] a -> a
sunScalar = flip sindex ZIS

-- | View an array of shape @sh ++ sh'@ as an outer array of shape @sh@ whose
-- elements are arrays of shape @sh'@.
--
-- 'lemMapJustApp' relates @'MapJust' (sh ++ sh')@ to
-- @'MapJust' sh ++ 'MapJust' sh'@ so that 'mnest' can be used on the wrapped
-- 'Mixed' array. Both conversions are 'coerce's.
snest :: forall sh sh' a. Elt a => ShS sh -> Shaped (sh ++ sh') a -> Shaped sh (Shaped sh' a)
snest sh arr
  | Refl <- lemMapJustApp sh (Proxy @sh')
  = coerce (mnest (ssxFromShape (shCvtSX sh)) (coerce arr))

-- | Inverse of 'snest': merge an array of arrays into a single array of
-- shape @sh ++ sh'@.
--
-- The 'M_Nest' constructor already holds the combined array, so this only
-- unwraps it. 'lemMapJustApp' converts its type to
-- @'MapJust' (sh ++ sh')@.
sunNest :: forall sh sh' a. Elt a => Shaped sh (Shaped sh' a) -> Shaped (sh ++ sh') a
sunNest sarr@(Shaped (M_Shaped (M_Nest _ arr)))
  | Refl <- lemMapJustApp (sshape sarr) (Proxy @sh')
  = Shaped arr

-- | Apply @f@ to each sub-array of shape @sh1@ under the outer dimensions
-- @sh@. The results, each of shape @sh2@, form an array of shape @sh ++ sh2@.
-- Works on 'Primitive' element types and delegates to 'mrerankP'.
--
-- The first argument to 'mrerankP' is the static shape of the outer @sh@
-- prefix. It is taken from the array's own shape with 'shxTakeSSX'.
srerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b)
         => ShS sh -> ShS sh2
         -> (Shaped sh1 (Primitive a) -> Shaped sh2 (Primitive b))
         -> Shaped (sh ++ sh1) (Primitive a) -> Shaped (sh ++ sh2) (Primitive b)
srerankP sh sh2 f sarr@(Shaped arr)
  | Refl <- lemMapJustApp sh (Proxy @sh1)
  , Refl <- lemMapJustApp sh (Proxy @sh2)
  = Shaped (mrerankP (ssxFromShape (shxTakeSSX (Proxy @(MapJust sh1)) (shCvtSX (sshape sarr)) (ssxFromShape (shCvtSX sh))))
                     (shCvtSX sh2)
                     -- unwrap/re-wrap so that f can be used on Mixed sub-arrays
                     (\a -> let Shaped r = f (Shaped a) in r)
                     arr)

-- | Like 'srerankP', but for arbitrary primitive element types: converts into
-- the 'Primitive' representation, reranks, and converts back.
srerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b)
        => ShS sh -> ShS sh2
        -> (Shaped sh1 a -> Shaped sh2 b)
        -> Shaped (sh ++ sh1) a -> Shaped (sh ++ sh2) b
srerank sh sh2 f arr =
  sfromPrimitive (srerankP sh sh2 (\sub -> stoPrimitive (f (sfromPrimitive sub))) (stoPrimitive arr))

-- | Replicate an array of shape @sh'@ along new outer dimensions @sh@. The
-- result has shape @sh ++ sh'@.
--
-- 'lemMapJustApp' lets the 'mreplicate' result, of type
-- @'MapJust' sh ++ 'MapJust' sh'@, be wrapped as @'MapJust' (sh ++ sh')@.
sreplicate :: forall sh sh' a. Elt a => ShS sh -> Shaped sh' a -> Shaped (sh ++ sh') a
sreplicate sh (Shaped arr)
  | Refl <- lemMapJustApp sh (Proxy @sh')
  = Shaped (mreplicate (shCvtSX sh) arr)

-- | A primitive array of the given shape with every element equal to @x@.
sreplicateScalP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a)
sreplicateScalP sh = Shaped . mreplicateScalP (shCvtSX sh)

-- | An array of the given shape with every element equal to the given value.
sreplicateScal :: PrimElt a => ShS sh -> a -> Shaped sh a
sreplicateScal sh = sfromPrimitive . sreplicateScalP sh

-- | Take @n@ consecutive entries of the outermost dimension, starting at
-- offset @i@ (delegates to 'X.slice').
sslice :: Elt a => SNat i -> SNat n -> Shaped (i + n + k : sh) a -> Shaped (n : sh) a
sslice i n@SNat arr = slift (n :$$ rest) (\_ -> X.slice i n) arr
  where
    _ :$$ rest = sshape arr

-- | Reverse the order of entries along the outermost dimension (delegates to
-- 'X.rev1'). The shape is unchanged.
srev1 :: Elt a => Shaped (n : sh) a -> Shaped (n : sh) a
srev1 arr = slift sh (\_ -> X.rev1) arr
  where
    sh = sshape arr

-- | Reinterpret an array under a new shape with the same total number of
-- elements (the 'Product' equality is checked at the type level).
sreshape :: (Elt a, Product sh ~ Product sh') => ShS sh' -> Shaped sh a -> Shaped sh' a
sreshape sh' sarr = case sarr of
  Shaped arr -> Shaped (mreshape (shCvtSX sh') arr)

-- | Flatten an array into one dimension whose length is the 'Product' of the
-- original dimensions.
--
-- TODO: simplify when removing the KnownNat stuff
sflatten :: Elt a => Shaped sh a -> Shaped '[Product sh] a
sflatten arr
  | n@SNat <- shsProduct (sshape arr) = sreshape (n :$$ ZSS) arr

-- | A one-dimensional array of length @n@ built by 'miota'.
siota :: (Enum a, PrimElt a) => SNat n -> Shaped '[n] a
siota = Shaped . miota

-- | The index of a minimal element, converted to a shaped index.
--
-- Throws if the array is empty.
sminIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh
sminIndexPrim sarr@(Shaped arr) = ixCvtXS sh (mminIndexPrim arr)
  where
    -- the shape is read from the 'Primitive' view of the array
    sh = sshape (stoPrimitive sarr)

-- | The index of a maximal element, converted to a shaped index.
--
-- Throws if the array is empty.
smaxIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh
smaxIndexPrim sarr@(Shaped arr) = ixCvtXS sh (mmaxIndexPrim arr)
  where
    -- the shape is read from the 'Primitive' view of the array
    sh = sshape (stoPrimitive sarr)

-- | Dot product along the innermost dimension @n@ of two arrays of shape
-- @sh ++ '[n]@, producing an array of shape @sh@ (delegates to 'mdot1Inner').
--
-- The 'lemInitApp' / 'lemLastApp' guards relate @'Init' (sh ++ '[n])@ and
-- @'Last' (sh ++ '[n])@ back to @sh@ and @n@. The inner 'lemMapJustApp' guard
-- needs the runtime prefix shape, obtained with 'shsInit'.
sdot1Inner :: forall sh n a. (PrimElt a, NumElt a)
           => Proxy n -> Shaped (sh ++ '[n]) a -> Shaped (sh ++ '[n]) a -> Shaped sh a
sdot1Inner Proxy sarr1@(Shaped arr1) (Shaped arr2)
  | Refl <- lemInitApp (Proxy @sh) (Proxy @n)
  , Refl <- lemLastApp (Proxy @sh) (Proxy @n)
  = case sshape sarr1 of
      _ :$$ _
        | Refl <- lemMapJustApp (shsInit (sshape sarr1)) (Proxy @'[n])
        -> Shaped (mdot1Inner (Proxy @(Just n)) arr1 arr2)
      -- A shape of the form sh ++ '[n] always has at least one dimension,
      -- so the empty-shape case cannot occur.
      _ -> error "unreachable"

-- | Dot product of two arrays of the same shape, yielding a single value.
--
-- The underlying implementation is currently a temporary, suboptimal one in
-- terms of 'mflatten'; prefer 'sdot1Inner' if applicable.
sdot :: (PrimElt a, NumElt a) => Shaped sh a -> Shaped sh a -> a
sdot (Shaped lhs) (Shaped rhs) = mdot lhs rhs

-- | Expose the shape and raw 'XArray' storage of a primitive shaped array.
stoXArrayPrimP :: Shaped sh (Primitive a) -> (ShS sh, XArray (MapJust sh) a)
stoXArrayPrimP (Shaped arr) =
  let (shx, xarr) = mtoXArrayPrimP arr
  in (shCvtXS' shx, xarr)

-- | Expose the shape and raw 'XArray' storage of a shaped array with
-- primitive elements.
stoXArrayPrim :: PrimElt a => Shaped sh a -> (ShS sh, XArray (MapJust sh) a)
stoXArrayPrim (Shaped arr) =
  let (shx, xarr) = mtoXArrayPrim arr
  in (shCvtXS' shx, xarr)

-- | Wrap raw 'XArray' storage as a primitive shaped array of the given shape.
sfromXArrayPrimP :: ShS sh -> XArray (MapJust sh) a -> Shaped sh (Primitive a)
sfromXArrayPrimP sh = Shaped . mfromXArrayPrimP (ssxFromShape (shCvtSX sh))

-- | Wrap raw 'XArray' storage as a shaped array with primitive elements.
sfromXArrayPrim :: PrimElt a => ShS sh -> XArray (MapJust sh) a -> Shaped sh a
sfromXArrayPrim sh = Shaped . mfromXArrayPrim (ssxFromShape (shCvtSX sh))

-- | Convert out of the 'Primitive' element wrapper.
sfromPrimitive :: PrimElt a => Shaped sh (Primitive a) -> Shaped sh a
sfromPrimitive sarr = case sarr of
  Shaped arr -> Shaped (fromPrimitive arr)

-- | Convert into the 'Primitive' element wrapper.
stoPrimitive :: PrimElt a => Shaped sh a -> Shaped sh (Primitive a)
stoPrimitive sarr = case sarr of
  Shaped arr -> Shaped (toPrimitive arr)

-- | Cast a 'Mixed' array to a 'Shaped' array with the given fully static
-- shape. The two shapes must have equal rank (@Rank sh ~ Rank sh'@).
--
-- NOTE(review): presumably 'mcast' checks at runtime that every dimension
-- agrees with the target shape; confirm in "Data.Array.Nested.Internal.Mixed".
--
-- The guards make @sh ++ '[]@ and @'MapJust' sh' ++ '[]@ reduce, and relate
-- the ranks of @sh'@ and @'MapJust' sh'@, so that 'mcast' can be called with
-- an empty trailing shape (@'Proxy' \@'[]@).
mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh')
              => Mixed sh a -> ShS sh' -> Shaped sh' a
mcastToShaped arr targetsh
  | Refl <- lemAppNil @sh
  , Refl <- lemAppNil @(MapJust sh')
  , Refl <- lemRankMapJust targetsh
  = Shaped (mcast (ssxFromShape (mshape arr)) (shCvtSX targetsh) (Proxy @'[]) arr)

-- | Unwrap a shaped array to the underlying 'Mixed' array, in which every
-- dimension is marked as statically known.
stoMixed :: forall sh a. Shaped sh a -> Mixed (MapJust sh) a
stoMixed = \(Shaped arr) -> arr