aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Ranked/Shape.hs
blob: 1c0b9eb5df2bc6bb6516226c9b1eda257b09e298 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Nested.Ranked.Shape where

import Control.DeepSeq (NFData(..))
import Data.Array.Mixed.Types
import Data.Coerce (coerce)
import Data.Foldable qualified as Foldable
import Data.Kind (Type)
import Data.Proxy
import Data.Type.Equality
import GHC.Generics (Generic)
import GHC.IsList (IsList)
import GHC.IsList qualified as IsList
import GHC.TypeLits
import GHC.TypeNats qualified as TN

import Data.Array.Mixed.Lemmas
import Data.Array.Nested.Mixed.Shape


-- | A list whose length @n@ is tracked at the type level. The length index is
-- nominal while the element type is representational, so 'coerce' can change
-- the element type between coercible types but never the length.
type role ListR nominal representational
type ListR :: Nat -> Type -> Type
data ListR n i where
  -- | The empty list, of length 0.
  ZR :: ListR 0 i
  -- | Prepend an element, increasing the length by one. The element type is
  -- an inferred (braced) variable, so a visible type application on this
  -- constructor instantiates only @n@.
  (:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i
deriving instance Eq i => Eq (ListR n i)
deriving instance Ord i => Ord (ListR n i)
deriving instance Functor (ListR n)
deriving instance Foldable (ListR n)
infixr 3 :::

-- | Rendered as a bracketed, comma-separated list such as @[1,2,3]@. Because
-- the output is always bracketed, the precedence argument is ignored.
instance Show i => Show (ListR n i) where
  showsPrec _prec xs = listrShow (showsPrec 0) xs

-- | Forces every element to normal form, front to back.
instance NFData i => NFData (ListR n i) where
  rnf list = case list of
    ZR -> ()
    x ::: rest -> rnf x `seq` rnf rest

-- | Result of 'listrUncons': the tail and the head of a non-empty list of
-- length @n1@, together with the equation @n + 1 ~ n1@ relating the lengths.
data UnconsListRRes i n1 =
  forall n. (n + 1 ~ n1) => UnconsListRRes (ListR n i) i

-- | Split off the head of a list, or return 'Nothing' for the empty list.
-- Note that the result stores the tail before the head.
listrUncons :: ListR n1 i -> Maybe (UnconsListRRes i n1)
listrUncons list = case list of
  ZR -> Nothing
  hd ::: tl -> Just (UnconsListRRes tl hd)

-- | This checks only whether the ranks are equal, not whether the actual
-- values are. On success the type-level lengths are proven equal.
listrEqRank :: ListR n i -> ListR n' i -> Maybe (n :~: n')
listrEqRank ZR ZR = Just Refl
listrEqRank (_ ::: xs) (_ ::: ys) = case listrEqRank xs ys of
  Just Refl -> Just Refl
  Nothing -> Nothing
listrEqRank _ _ = Nothing

-- | This compares the lists for value equality: the result is @'Just' 'Refl'@
-- exactly when both lists have the same length and pairwise-equal elements,
-- in which case the type-level lengths are also proven equal.
listrEqual :: Eq i => ListR n i -> ListR n' i -> Maybe (n :~: n')
listrEqual ZR ZR = Just Refl
listrEqual (i ::: sh) (j ::: sh')
  -- Compare the heads before recursing, so a mismatch stops the traversal
  -- immediately instead of first walking the entire remaining list.
  | i == j
  , Just Refl <- listrEqual sh sh'
  = Just Refl
listrEqual _ _ = Nothing

-- | Render a list as @[x1,x2,...]@, using the given printer for each element.
listrShow :: forall n i. (i -> ShowS) -> ListR n i -> ShowS
listrShow showElem list = showChar '[' . firstElem list . showChar ']'
  where
    -- The first element is printed without a leading separator.
    firstElem :: ListR n' i -> ShowS
    firstElem ZR = id
    firstElem (x ::: rest) = showElem x . laterElems rest

    -- Every subsequent element is preceded by a comma.
    laterElems :: ListR n' i -> ShowS
    laterElems ZR = id
    laterElems (x ::: rest) = showChar ',' . showElem x . laterElems rest

-- | The number of elements as a runtime 'Int' (see 'listrRank' for the
-- length as an 'SNat').
listrLength :: ListR n i -> Int
listrLength xs = length xs

-- | The length of the list as a singleton natural.
listrRank :: ListR n i -> SNat n
listrRank list = case list of
  ZR -> SNat
  _ ::: rest -> snatSucc (listrRank rest)

-- | Concatenate two lists; the lengths add.
listrAppend :: ListR n i -> ListR m i -> ListR (n + m) i
listrAppend front back = case front of
  ZR -> back
  x ::: rest -> x ::: listrAppend rest back

-- | Convert an ordinary list into a 'ListR' whose length is not statically
-- known. The result is passed to a continuation that must accept any length.
listrFromList :: [i] -> (forall n. ListR n i -> r) -> r
listrFromList xs k = case xs of
  [] -> k ZR
  y : ys -> listrFromList ys (\rest -> k (y ::: rest))

-- | The first element of a non-empty list.
listrHead :: ListR (n + 1) i -> i
listrHead list = case list of
  x ::: _ -> x
  ZR -> error "unreachable"

-- | Everything but the first element of a non-empty list.
listrTail :: ListR (n + 1) i -> ListR n i
listrTail list = case list of
  _ ::: rest -> rest
  ZR -> error "unreachable"

-- | Everything but the last element of a non-empty list.
listrInit :: ListR (n + 1) i -> ListR n i
listrInit list = case list of
  x ::: rest@(_ ::: _) -> x ::: listrInit rest
  _ ::: ZR -> ZR
  ZR -> error "unreachable"

-- | The last element of a non-empty list.
listrLast :: ListR (n + 1) i -> i
listrLast list = case list of
  _ ::: rest@(_ ::: _) -> listrLast rest
  x ::: ZR -> x
  ZR -> error "unreachable"

-- | Index into the list at the type-level position @k@. The constraint
-- @k + 1 <= n@ statically rules out out-of-bounds indices.
listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i
listrIndex SZ (x ::: _) = x
-- Step past one element of both the index and the list. The lemma from
-- "Data.Array.Mixed.Lemmas" presumably derives the bound for the tail from
-- @k + 1 <= n@. TODO: confirm against its definition.
listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs
-- No index can satisfy the constraint for an empty list.
listrIndex _ ZR = error "k + 1 <= 0"

-- | Pair up the elements of two lists of the same length.
listrZip :: ListR n i -> ListR n j -> ListR n (i, j)
listrZip xs ys = case (xs, ys) of
  (ZR, ZR) -> ZR
  (x ::: xs', y ::: ys') -> (x, y) ::: listrZip xs' ys'
  _ -> error "listrZip: impossible pattern needlessly required"

-- | Combine two lists of the same length element-wise with the given function.
listrZipWith :: (i -> j -> k) -> ListR n i -> ListR n j -> ListR n k
listrZipWith f xs ys = case (xs, ys) of
  (ZR, ZR) -> ZR
  (x ::: xs', y ::: ys') -> f x y ::: listrZipWith f xs' ys'
  _ -> error "listrZipWith: impossible pattern needlessly required"

-- | Permute the first @length perm@ elements of the list according to @perm@,
-- leaving the remaining elements in place. Position @k@ of the permuted prefix
-- receives the element at index @perm !! k@ of the original prefix.
--
-- Calls 'error' in two cases:
--
-- * @perm@ is longer than the list.
-- * @perm@ contains an index that is negative or not smaller than
--   @length perm@.
--
-- The indices are not checked to be distinct.
listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i
listrPermutePrefix = \perm sh ->
  listrFromList perm $ \sperm ->
    case (listrRank sperm, listrRank sh) of
      (permlen@SNat, shlen@SNat) -> case cmpNat permlen shlen of
        -- LTI and EQI are exactly the cases where permlen <= shlen, which
        -- 'listrSplitAt' requires; without or-patterns they repeat the body.
        LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
        EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
        GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")"
                       ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")"
  where
    -- Split the list into its first @m@ elements and the rest.
    listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i)
    listrSplitAt SZ sh = (ZR, sh)
    listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh)
    listrSplitAt SS{} ZR = error "m' + 1 <= 0"

    -- For every index in the permutation, look up that position in the
    -- @m@-element prefix, range-checking it against @m@ at runtime.
    applyPermRFull :: SNat m -> ListR k Int -> ListR m i -> ListR k i
    applyPermRFull _ ZR _ = ZR
    applyPermRFull sm@SNat (i ::: perm) l
      -- Reject negative indices explicitly. Converting a negative 'Int' to
      -- 'Natural' with 'fromIntegral' is an arithmetic underflow rather than
      -- a meaningful index, so the range error below would never be reached.
      | i < 0 = error "listrPermutePrefix: Index in permutation out of range"
      | otherwise =
          TN.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) ->
            case cmpNat (SNat @(idx + 1)) sm of
              LTI -> listrIndex si l ::: applyPermRFull sm perm l
              EQI -> listrIndex si l ::: applyPermRFull sm perm l
              GTI -> error "listrPermutePrefix: Index in permutation out of range"


-- | An index into a rank-typed array.
type role IxR nominal representational
type IxR :: Nat -> Type -> Type
-- | A thin wrapper around a 'ListR' holding one component per dimension.
newtype IxR n i = IxR (ListR n i)
  deriving (Eq, Ord, Generic)
  deriving newtype (Functor, Foldable)

-- | The empty index. Matching on it provides @n ~ 0@.
pattern ZIR :: forall n i. () => n ~ 0 => IxR n i
pattern ZIR = IxR ZR

-- | Prepend a component to an index. Matching goes through 'listrUncons', so a
-- successful match provides @n + 1 ~ n1@ for the tail's rank @n@.
pattern (:.:)
  :: forall {n1} {i}.
     forall n. (n + 1 ~ n1)
  => i -> IxR n i -> IxR n1 i
pattern i :.: sh <- IxR (listrUncons -> Just (UnconsListRRes (IxR -> sh) i))
  where i :.: IxR sh = IxR (i ::: sh)
infixr 3 :.:

{-# COMPLETE ZIR, (:.:) #-}

-- | An index with 'Int' components.
type IIxR n = IxR n Int

-- | Shown exactly like the underlying 'ListR', e.g. @[0,1,2]@.
instance Show i => Show (IxR n i) where
  showsPrec prec (IxR components) = showsPrec prec components

-- | Forces every component of the index.
instance NFData i => NFData (IxR n i) where
  rnf (IxR components) = rnf components

-- | The number of components, as a runtime 'Int'.
ixrLength :: IxR n i -> Int
ixrLength = length

-- | The rank of the index as a singleton natural.
ixrRank :: IxR n i -> SNat n
ixrRank idx = case idx of
  IxR components -> listrRank components

-- | The index of the given rank whose components are all 0.
ixrZero :: SNat n -> IIxR n
ixrZero rank = case rank of
  SZ -> ZIR
  SS rank' -> 0 :.: ixrZero rank'

-- | Convert a mixed index into a ranked index with the same components.
ixCvtXR :: IIxX sh -> IIxR (Rank sh)
ixCvtXR idx = case idx of
  ZIX -> ZIR
  component :.% rest -> component :.: ixCvtXR rest

-- | Convert a ranked index into a mixed index whose type-level shape is
-- @Replicate n Nothing@, keeping every component.
ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing)
ixCvtRX ZIR = ZIX
ixCvtRX (n :.: (idx :: IxR m Int)) =
  -- The recursive result has type-level shape @Nothing : Replicate m Nothing@.
  -- 'lemReplicateSucc' presumably equates this with @Replicate (m + 1) Nothing@
  -- so 'castWith' can retype it. TODO: confirm against its definition.
  castWith (subst2 @IxX @Int (lemReplicateSucc @(Nothing @Nat) @m))
    (n :.% ixCvtRX idx)

-- | The first component of a non-empty index.
ixrHead :: forall n i. IxR (n + 1) i -> i
ixrHead = coerce (listrHead @n @i)

-- | All components but the first.
ixrTail :: forall n i. IxR (n + 1) i -> IxR n i
ixrTail = coerce (listrTail @n @i)

-- | All components but the last.
ixrInit :: forall n i. IxR (n + 1) i -> IxR n i
ixrInit = coerce (listrInit @n @i)

-- | The last component of a non-empty index.
ixrLast :: forall n i. IxR (n + 1) i -> i
ixrLast = coerce (listrLast @n @i)

-- | Concatenate two indices; the ranks add.
ixrAppend :: forall n m i. IxR n i -> IxR m i -> IxR (n + m) i
ixrAppend (IxR front) (IxR back) = IxR (listrAppend front back)

-- | Pair up the components of two indices of the same rank.
ixrZip :: forall n i j. IxR n i -> IxR n j -> IxR n (i, j)
ixrZip = coerce (listrZip @n @i @j)

-- | Combine two indices of the same rank component-wise.
ixrZipWith :: forall i j k n. (i -> j -> k) -> IxR n i -> IxR n j -> IxR n k
ixrZipWith = coerce (listrZipWith @i @j @k @n)

-- | Permute a prefix of the index; see 'listrPermutePrefix'.
ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i
ixrPermutePrefix perm (IxR components) = IxR (listrPermutePrefix perm components)


type role ShR nominal representational
type ShR :: Nat -> Type -> Type
-- | The shape of a rank-typed array: one size per dimension, wrapped around a
-- 'ListR' so that the rank @n@ is tracked at the type level.
newtype ShR n i = ShR (ListR n i)
  deriving (Eq, Ord, Generic)
  deriving newtype (Functor, Foldable)

-- | The empty (rank-0) shape. Matching on it provides @n ~ 0@.
pattern ZSR :: forall n i. () => n ~ 0 => ShR n i
pattern ZSR = ShR ZR

-- | Prepend a dimension to a shape. Matching goes through 'listrUncons', so a
-- successful match provides @n + 1 ~ n1@ for the tail's rank @n@.
pattern (:$:)
  :: forall {n1} {i}.
     forall n. (n + 1 ~ n1)
  => i -> ShR n i -> ShR n1 i
pattern i :$: sh <- ShR (listrUncons -> Just (UnconsListRRes (ShR -> sh) i))
  where i :$: ShR sh = ShR (i ::: sh)
infixr 3 :$:

{-# COMPLETE ZSR, (:$:) #-}

-- | A shape with 'Int' dimension sizes.
type IShR n = ShR n Int

-- | Shown exactly like the underlying 'ListR', e.g. @[2,3]@.
instance Show i => Show (ShR n i) where
  showsPrec prec (ShR dims) = showsPrec prec dims

-- | Forces every dimension of the shape.
instance NFData i => NFData (ShR n i) where
  rnf (ShR dims) = rnf dims

-- | Convert a mixed shape whose type-level shape is @Replicate n Nothing@ into a
-- ranked shape of rank @n@. Each dimension's size is read with
-- 'fromSMayNat''.
shCvtXR' :: forall n. IShX (Replicate n Nothing) -> IShR n
shCvtXR' ZSX =
  -- An empty mixed shape implies @n ~ 0@. This is asserted with
  -- 'unsafeCoerceRefl' rather than proven.
  castWith (subst2 (unsafeCoerceRefl :: 0 :~: n))
    ZSR
shCvtXR' (n :$% (idx :: IShX sh))
  | Refl <- lemReplicateSucc @(Nothing @Nat) @(n - 1) =
  castWith (subst2 (lem1 @sh Refl))
    (fromSMayNat' n :$: shCvtXR' (castWith (subst2 (lem2 Refl)) idx))
  where
    -- If @k : sh'@ equals @Replicate n' Nothing@, then @sh'@ has rank
    -- @n' - 1@. Asserted unsafely; not checked by the compiler.
    lem1 :: forall sh' n' k.
            k : sh' :~: Replicate n' Nothing
         -> Rank sh' + 1 :~: n'
    lem1 Refl = unsafeCoerceRefl

    -- The tail of a @Replicate n Nothing@ shape is itself all 'Nothing', of
    -- length @Rank sh@. Asserted unsafely; not checked by the compiler.
    lem2 :: k : sh :~: Replicate n Nothing
         -> sh :~: Replicate (Rank sh) Nothing
    lem2 Refl = unsafeCoerceRefl

-- | Convert a ranked shape into a mixed shape whose type-level shape is
-- @Replicate n Nothing@. Every dimension is wrapped in 'SUnknown'.
shCvtRX :: IShR n -> IShX (Replicate n Nothing)
shCvtRX ZSR = ZSX
shCvtRX (n :$: (idx :: ShR m Int)) =
  -- The recursive result has type-level shape @Nothing : Replicate m Nothing@.
  -- 'lemReplicateSucc' presumably equates this with @Replicate (m + 1) Nothing@.
  -- TODO: confirm against its definition.
  castWith (subst2 @ShX @Int (lemReplicateSucc @(Nothing @Nat) @m))
    (SUnknown n :$% shCvtRX idx)

-- | This checks only whether the ranks are equal, not whether the actual
-- values are.
shrEqRank :: forall n i n'. ShR n i -> ShR n' i -> Maybe (n :~: n')
shrEqRank = coerce (listrEqRank @n @i @n')

-- | This compares the shapes for value equality.
shrEqual :: forall i n n'. Eq i => ShR n i -> ShR n' i -> Maybe (n :~: n')
shrEqual = coerce (listrEqual @i @n @n')

-- | The number of dimensions, as a runtime 'Int'.
shrLength :: ShR n i -> Int
shrLength = length

-- | This function can also be used to conjure up a 'KnownNat' dictionary;
-- pattern matching on the returned 'SNat' with the 'pattern SNat' pattern
-- synonym yields 'KnownNat' evidence.
shrRank :: ShR n i -> SNat n
shrRank shape = case shape of
  ShR dims -> listrRank dims

-- | The number of elements in an array described by this shape: the product
-- of all dimension sizes, which is 1 for the rank-0 shape.
shrSize :: IShR n -> Int
shrSize = foldr (*) 1

-- | The outermost dimension of a non-empty shape.
shrHead :: forall n i. ShR (n + 1) i -> i
shrHead = coerce (listrHead @n @i)

-- | All dimensions but the outermost.
shrTail :: forall n i. ShR (n + 1) i -> ShR n i
shrTail = coerce (listrTail @n @i)

-- | All dimensions but the innermost.
shrInit :: forall n i. ShR (n + 1) i -> ShR n i
shrInit = coerce (listrInit @n @i)

-- | The innermost dimension of a non-empty shape.
shrLast :: forall n i. ShR (n + 1) i -> i
shrLast = coerce (listrLast @n @i)

-- | Concatenate two shapes; the ranks add.
shrAppend :: forall n m i. ShR n i -> ShR m i -> ShR (n + m) i
shrAppend (ShR outer) (ShR inner) = ShR (listrAppend outer inner)

-- | Pair up the dimensions of two shapes of the same rank.
shrZip :: forall n i j. ShR n i -> ShR n j -> ShR n (i, j)
shrZip = coerce (listrZip @n @i @j)

-- | Combine two shapes of the same rank dimension-wise.
shrZipWith :: forall i j k n. (i -> j -> k) -> ShR n i -> ShR n j -> ShR n k
shrZipWith = coerce (listrZipWith @i @j @k @n)

-- | Permute a prefix of the shape's dimensions; see 'listrPermutePrefix'.
shrPermutePrefix :: forall n i. [Int] -> ShR n i -> ShR n i
shrPermutePrefix perm (ShR dims) = ShR (listrPermutePrefix perm dims)


-- | Untyped: length is checked at runtime.
instance KnownNat n => IsList (ListR n i) where
  type Item (ListR n i) = i
  fromList input = build (SNat @n) input
    where
      -- Consume one list element for each successor in the expected length.
      build :: SNat n' -> [i] -> ListR n' i
      build expected xs = case (expected, xs) of
        (SZ, []) -> ZR
        (SS expected', y : ys) -> y ::: build expected' ys
        _ -> error $ "IsList(ListR): Mismatched list length (type says "
                       ++ show (fromSNat (SNat @n)) ++ ", list has length "
                       ++ show (length input) ++ ")"
  toList = Foldable.toList

-- | Untyped: length is checked at runtime (by the 'ListR' instance).
instance KnownNat n => IsList (IxR n i) where
  type Item (IxR n i) = i
  fromList xs = IxR (IsList.fromList xs)
  toList (IxR components) = Foldable.toList components

-- | Untyped: length is checked at runtime (by the 'ListR' instance).
instance KnownNat n => IsList (ShR n i) where
  type Item (ShR n i) = i
  fromList xs = ShR (IsList.fromList xs)
  toList (ShR dims) = Foldable.toList dims