aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Mixed/Shape.hs
blob: 434357454f00039ffbd8fdda7730945502c7a979 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Mixed.Shape where

import Control.DeepSeq (NFData(..))
import Data.Bifunctor (first)
import Data.Coerce
import Data.Foldable qualified as Foldable
import Data.Functor.Const
import Data.Kind (Type, Constraint)
import Data.Monoid (Sum(..))
import Data.Proxy
import Data.Type.Equality
import GHC.Generics (Generic)
import GHC.IsList (IsList)
import GHC.IsList qualified as IsList
import GHC.TypeLits

import Data.Array.Mixed.Types


-- | The length of a type-level list. If the argument is a shape, then the
-- result is the rank of that shape.
--
-- Defined by structural recursion on the list. 'listxLengthSNat' and
-- 'shxLengthSNat' produce runtime 'SNat' witnesses of this type.
type family Rank sh where
  Rank '[] = 0
  Rank (_ : sh) = Rank sh + 1


-- * Mixed lists

-- | A list indexed by a type-level list @sh@ of @'Maybe' 'Nat'@s, in which
-- the element at entry @n@ of @sh@ has type @f n@. 'IxX', 'ShX' and
-- 'StaticShX' in this module are newtypes over this type.
--
-- @sh@ has a nominal role (it is a type-level index); @f@ is representational.
type role ListX nominal representational
type ListX :: [Maybe Nat] -> (Maybe Nat -> Type) -> Type
data ListX sh f where
  -- | The empty list.
  ZX :: ListX '[] f
  -- | Prepend an element for a new leading entry @n@.
  (::%) :: f n -> ListX sh f -> ListX (n : sh) f
-- Comparing elements needs Eq/Ord at every possible entry @n@, hence the
-- quantified constraints.
deriving instance (forall n. Eq (f n)) => Eq (ListX sh f)
deriving instance (forall n. Ord (f n)) => Ord (ListX sh f)
infixr 3 ::%

-- | Renders as a comma-separated list in square brackets, e.g. @[a,b,c]@.
-- The precedence argument is ignored; every element is shown at precedence 0.
instance (forall n. Show (f n)) => Show (ListX sh f) where
  showsPrec _ list = listxShow (showsPrec 0) list

-- | Fully evaluates every element, front to back.
instance (forall n. NFData (f n)) => NFData (ListX sh f) where
  rnf list = case list of
    ZX -> ()
    x ::% rest -> rnf x `seq` rnf rest

-- | Result of 'listxUncons': the tail and the head of a non-empty 'ListX',
-- packaged with evidence that the list's index @sh1@ equals @n : sh@.
-- The view patterns behind ':.%', ':$%' and ':!%' match on this.
data UnconsListXRes f sh1 =
  forall n sh. (n : sh ~ sh1) => UnconsListXRes (ListX sh f) (f n)
-- | Split a list into tail and head, or 'Nothing' if it is empty.
listxUncons :: ListX sh1 f -> Maybe (UnconsListXRes f sh1)
listxUncons list = case list of
  ZX -> Nothing
  x ::% rest -> Just (UnconsListXRes rest x)

-- | Apply an index-preserving transformation to every element.
listxFmap :: (forall n. f n -> g n) -> ListX sh f -> ListX sh g
listxFmap nat list = case list of
  ZX -> ZX
  x ::% rest -> nat x ::% listxFmap nat rest

-- | Map every element into a monoid and combine the results, right-nested
-- (@f x1 <> (f x2 <> (... <> mempty))@).
listxFold :: Monoid m => (forall n. f n -> m) -> ListX sh f -> m
listxFold visit list = case list of
  ZX -> mempty
  x ::% rest -> visit x <> listxFold visit rest

-- | Number of elements, computed at runtime.
listxLength :: ListX sh f -> Int
listxLength ZX = 0
listxLength (_ ::% rest) = 1 + listxLength rest

-- | Number of elements as a singleton for the type-level 'Rank'. Each step
-- relies on the KnownNat solver plugin to obtain @KnownNat (Rank sh + 1)@.
listxLengthSNat :: ListX sh f -> SNat (Rank sh)
listxLengthSNat list = case list of
  ZX -> SNat
  _ ::% rest -> case listxLengthSNat rest of
    SNat -> SNat

-- | Render a list as @[e1,e2,...]@, showing each element with the given
-- function.
listxShow :: forall sh f. (forall n. f n -> ShowS) -> ListX sh f -> ShowS
listxShow showElem list = showChar '[' . firstElem list . showChar ']'
  where
    -- The first element is printed without a separator...
    firstElem :: ListX sh' f -> ShowS
    firstElem ZX = id
    firstElem (x ::% rest) = showElem x . laterElems rest

    -- ...and every subsequent one is preceded by a comma.
    laterElems :: ListX sh' f -> ShowS
    laterElems ZX = id
    laterElems (x ::% rest) = showChar ',' . showElem x . laterElems rest

-- | Forget the type-level index and return the plain list of payloads.
listxToList :: ListX sh' (Const i) -> [i]
listxToList = listxFold (\(Const x) -> [x])

-- | Drop the first element of a (statically) non-empty list.
listxTail :: ListX (n : sh) i -> ListX sh i
listxTail list = case list of
  _ ::% rest -> rest

-- | Concatenate two lists; the index of the result is the type-level append.
listxAppend :: ListX sh f -> ListX sh' f -> ListX (sh ++ sh') f
listxAppend front back = case front of
  ZX -> back
  x ::% rest -> x ::% listxAppend rest back

-- | Drop as many leading elements from @long@ as @short@ has. Only the
-- length of @short@ matters; its elements are ignored.
listxDrop :: forall f g sh sh'. ListX (sh ++ sh') f -> ListX sh g -> ListX sh' f
listxDrop long short = case short of
  ZX -> long
  _ ::% short' -> case long of
    _ ::% long' -> listxDrop long' short'


-- * Mixed indices

-- | A mixed index: one component of type @i@ for each dimension of the shape
-- @sh@. This is a newtype over 'ListX', with each component wrapped in
-- 'Const'.
type role IxX nominal representational
type IxX :: [Maybe Nat] -> Type -> Type
newtype IxX sh i = IxX (ListX sh (Const i))
  deriving (Eq, Ord, Generic)

-- | The empty index. Matching on it refines @sh@ to @'[]@.
pattern ZIX :: forall sh i. () => sh ~ '[] => IxX sh i
pattern ZIX = IxX ZX

-- | Index cons. @i :.% rest@ matches a non-empty index with first component
-- @i@ and remaining index @rest@, and builds one the same way. Matching goes
-- through 'listxUncons'; building wraps @i@ in 'Const'.
pattern (:.%)
  :: forall {sh1} {i}.
     forall n sh. (n : sh ~ sh1)
  => i -> IxX sh i -> IxX sh1 i
pattern i :.% shl <- IxX (listxUncons -> Just (UnconsListXRes (IxX -> shl) (getConst -> i)))
  where i :.% IxX shl = IxX (Const i ::% shl)
infixr 3 :.%

-- Tells the pattern-match checker that these two synonyms cover all indices.
{-# COMPLETE ZIX, (:.%) #-}

-- | An index whose components are 'Int's.
type IIxX sh = IxX sh Int

-- | Shows an index as a bracketed, comma-separated list of its components.
instance Show i => Show (IxX sh i) where
  showsPrec _ (IxX list) = listxShow (shows . getConst) list

-- | Maps over every index component, preserving the shape index.
instance Functor (IxX sh) where
  fmap f (IxX list) = IxX (listxFmap (\(Const x) -> Const (f x)) list)

-- | Folds over the index components from first dimension to last.
instance Foldable (IxX sh) where
  foldMap f (IxX list) = listxFold (\(Const x) -> f x) list

-- | Uses the default 'rnf', which is implemented via the derived 'Generic'
-- instance of 'IxX'.
instance NFData i => NFData (IxX sh i)

-- | The all-zeros index for a shape whose static part is given.
ixxZero :: StaticShX sh -> IIxX sh
ixxZero (StaticShX list) = IxX (listxFmap (\_ -> Const 0) list)

-- | The all-zeros index for the given shape.
ixxZero' :: IShX sh -> IIxX sh
ixxZero' (ShX list) = IxX (listxFmap (\_ -> Const 0) list)

-- | Drop the first component of a non-empty index.
ixxTail :: IxX (n : sh) i -> IxX sh i
ixxTail (_ :.% rest) = rest

-- | Concatenate two indices.
ixxAppend :: forall sh sh' i. IxX sh i -> IxX sh' i -> IxX (sh ++ sh') i
ixxAppend (IxX front) (IxX back) = IxX (listxAppend front back)

-- | Drop as many leading components from the first index as the second index
-- has. Only the length of the second index is used.
ixxDrop :: forall sh sh' i. IxX (sh ++ sh') i -> IxX sh i -> IxX sh' i
ixxDrop (IxX long) (IxX short) = IxX (listxDrop long short)

-- | Convert a linear (row-major) offset into a multi-dimensional index for an
-- array of the given shape. The last dimension varies fastest.
--
-- Calls 'error' unless @0 <= i < shxSize sh@. The check is done up front.
-- 'quotRem' truncates toward zero, so a negative offset would otherwise
-- decompose into an index with negative components instead of being rejected.
-- A shape with a zero dimension would otherwise fail with a bare
-- divide-by-zero.
ixxFromLinear :: IShX sh -> Int -> IIxX sh
ixxFromLinear = \sh i ->
  if i < 0 || i >= shxSize sh
    then error $ "ixxFromLinear: out of range (" ++ show i ++
                 " in array of shape " ++ show sh ++ ")"
    -- In range, so the element count is positive. Hence every dimension is
    -- non-zero (no division by zero in 'go'), and the carry out of the
    -- outermost dimension is 0.
    else fst (go sh i)
  where
    -- returns (index in subarray, remaining index in enclosing array)
    go :: IShX sh -> Int -> (IIxX sh, Int)
    go ZSX i = (ZIX, i)
    go (n :$% sh) i =
      let (idx, i') = go sh i
          (upi, locali) = i' `quotRem` fromSMayNat' n
      in (locali :.% idx, upi)

-- | Convert a multi-dimensional index into a linear (row-major) offset for an
-- array of the given shape; the last dimension varies fastest. Index
-- components are not bounds-checked.
ixxToLinear :: IShX sh -> IIxX sh -> Int
ixxToLinear shape ix = offset
  where
    (offset, _) = walk shape ix

    -- Yields (offset within the sub-array, number of elements in the sub-array).
    walk :: IShX s -> IIxX s -> (Int, Int)
    walk ZSX ZIX = (0, 1)
    walk (dim :$% dims) (k :.% ks) =
      let (inner, innerSize) = walk dims ks
      in (k * innerSize + inner, innerSize * fromSMayNat' dim)


-- * Mixed shapes

-- | A size that may or may not be statically known.
--
-- * @'SUnknown' x@ stores a runtime size @x@ for a 'Nothing' entry.
-- * @'SKnown' s@ stores an @f n@ for a @'Just' n@ entry. In 'ShX' and
--   'StaticShX', @f@ is 'SNat'.
data SMayNat i f n where
  SUnknown :: i -> SMayNat i f Nothing
  SKnown :: f n -> SMayNat i f (Just n)
deriving instance (Show i, forall m. Show (f m)) => Show (SMayNat i f n)
deriving instance (Eq i, forall m. Eq (f m)) => Eq (SMayNat i f n)
deriving instance (Ord i, forall m. Ord (f m)) => Ord (SMayNat i f n)

-- | Fully evaluates whichever payload is present.
instance (NFData i, forall m. NFData (f m)) => NFData (SMayNat i f n) where
  rnf smn = case smn of
    SUnknown x -> rnf x
    SKnown y -> rnf y

-- | Eliminate an 'SMayNat'. The first continuation handles a runtime size;
-- the second handles a static one. Each receives the type equality refining
-- @n@.
fromSMayNat :: (n ~ Nothing => i -> r)
            -> (forall m. n ~ Just m => f m -> r)
            -> SMayNat i f n -> r
fromSMayNat onUnknown onKnown smn = case smn of
  SUnknown x -> onUnknown x
  SKnown y -> onKnown y

-- | The size as an 'Int', whether it is runtime-known or static.
fromSMayNat' :: SMayNat Int SNat n -> Int
fromSMayNat' (SUnknown size) = size
fromSMayNat' (SKnown sn) = fromSNat' sn

-- | Type-level sum of two possibly-known sizes. The result is known
-- (@'Just'@) only when both operands are known.
type family AddMaybe n m where
  AddMaybe Nothing _ = Nothing
  AddMaybe (Just _) Nothing = Nothing
  AddMaybe (Just n) (Just m) = Just (n + m)

-- | Add two possibly-known sizes. The result stays static only when both
-- inputs are static (see 'AddMaybe').
smnAddMaybe :: SMayNat Int SNat n -> SMayNat Int SNat m -> SMayNat Int SNat (AddMaybe n m)
smnAddMaybe a b = case a of
  SUnknown x -> SUnknown (x + fromSMayNat' b)
  SKnown x -> case b of
    SUnknown y -> SUnknown (fromSNat' x + y)
    SKnown y -> SKnown (snatPlus x y)


-- | A mixed shape. For each dimension of @sh@, it holds either a runtime size
-- of type @i@ (for a 'Nothing' entry) or an 'SNat' (for a @'Just' n@ entry).
-- This is a newtype over 'ListX'.
type role ShX nominal representational
type ShX :: [Maybe Nat] -> Type -> Type
newtype ShX sh i = ShX (ListX sh (SMayNat i SNat))
  deriving (Eq, Ord, Generic)

-- | The empty (rank-0) shape. Matching on it refines @sh@ to @'[]@.
pattern ZSX :: forall sh i. () => sh ~ '[] => ShX sh i
pattern ZSX = ShX ZX

-- | Shape cons. @d :$% rest@ matches a non-empty shape with outermost
-- dimension @d@ and remaining shape @rest@, and builds one the same way.
-- Matching goes through 'listxUncons'.
pattern (:$%)
  :: forall {sh1} {i}.
     forall n sh. (n : sh ~ sh1)
  => SMayNat i SNat n -> ShX sh i -> ShX sh1 i
pattern i :$% shl <- ShX (listxUncons -> Just (UnconsListXRes (ShX -> shl) i))
  where i :$% ShX shl = ShX (i ::% shl)
infixr 3 :$%

-- Tells the pattern-match checker that these two synonyms cover all shapes.
{-# COMPLETE ZSX, (:$%) #-}

-- | A shape whose runtime-known dimensions are 'Int's.
type IShX sh = ShX sh Int

-- | Shows a shape as a bracketed list of its dimension sizes. Static sizes
-- are printed via their 'Integer' value.
instance Show i => Show (ShX sh i) where
  showsPrec _ (ShX list) = listxShow showDim list
    where
      showDim :: SMayNat i SNat n -> ShowS
      showDim (SUnknown x) = shows x
      showDim (SKnown sn) = shows (fromSNat sn)

-- | Maps over the runtime-known sizes only; static sizes are carried over
-- unchanged.
instance Functor (ShX sh) where
  fmap f (ShX list) =
    ShX (listxFmap (\dim -> case dim of { SUnknown x -> SUnknown (f x); SKnown sn -> SKnown sn }) list)

-- | Forces every runtime size and every static 'SNat' witness. Written by
-- hand rather than via the 'ListX' instance because that would additionally
-- require an 'NFData' instance for 'SNat'.
instance NFData i => NFData (ShX sh i) where
  rnf (ShX list) = go list
    where
      go :: ListX sh' (SMayNat i SNat) -> ()
      go ZX = ()
      go (SUnknown x ::% rest) = rnf x `seq` go rest
      go (SKnown SNat ::% rest) = go rest

-- | The rank of a shape, computed at runtime.
shxLength :: ShX sh i -> Int
shxLength ZSX = 0
shxLength (_ :$% rest) = 1 + shxLength rest

-- | The rank of a shape as a type-level singleton.
shxLengthSNat :: ShX sh f -> SNat (Rank sh)
shxLengthSNat shape = case shape of
  ShX list -> listxLengthSNat list

-- | This is more than @geq@: it also checks that the integers (the unknown
-- dimensions) are the same. The result is 'Nothing' as soon as the two shapes
-- differ in rank, in static-vs-dynamic kind at some position, or in any size.
shxEqual :: Eq i => ShX sh i -> ShX sh' i -> Maybe (sh :~: sh')
shxEqual ZSX ZSX = Just Refl
shxEqual (SKnown a@SNat :$% xs) (SKnown b@SNat :$% ys) =
  case sameNat a b of
    Nothing -> Nothing
    Just Refl -> case shxEqual xs ys of
      Nothing -> Nothing
      Just Refl -> Just Refl
shxEqual (SUnknown x :$% xs) (SUnknown y :$% ys)
  | x == y = case shxEqual xs ys of
      Nothing -> Nothing
      Just Refl -> Just Refl
  | otherwise = Nothing
shxEqual _ _ = Nothing

-- | The number of elements in an array described by this shape: the product
-- of all dimension sizes (1 for a rank-0 shape).
shxSize :: IShX sh -> Int
shxSize = product . shxToList

-- | All dimension sizes, outermost first.
shxToList :: IShX sh -> [Int]
shxToList (ShX list) = listxFold (\dim -> [fromSMayNat' dim]) list

-- | Concatenate two shapes.
shxAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i
shxAppend (ShX front) (ShX back) = ShX (listxAppend front back)

-- | Drop the outermost dimension of a non-empty shape.
shxTail :: ShX (n : sh) i -> ShX sh i
shxTail (_ :$% rest) = rest

-- | Drop as many leading dimensions as the static shape has.
shxDropSSX :: forall sh sh' i. ShX (sh ++ sh') i -> StaticShX sh -> ShX sh' i
shxDropSSX (ShX long) (StaticShX short) = ShX (listxDrop long short)

-- | Drop as many leading dimensions as the index has components.
shxDropIx :: forall sh sh' i j. ShX (sh ++ sh') i -> IxX sh j -> ShX sh' i
shxDropIx (ShX long) (IxX short) = ShX (listxDrop long short)

-- | Drop as many leading dimensions as the second shape has. The sizes in
-- the second shape are ignored.
shxDropSh :: forall sh sh' i. ShX (sh ++ sh') i -> ShX sh i -> ShX sh' i
shxDropSh (ShX long) (ShX short) = ShX (listxDrop long short)

-- | Keep the leading dimensions of a shape, as many as the static shape has.
-- The proxy fixes the (dropped) suffix @sh'@.
shxTakeSSX :: forall sh sh' i. Proxy sh' -> ShX (sh ++ sh') i -> StaticShX sh -> ShX sh i
shxTakeSSX _ full prefix = takeLoop prefix full
  where
    takeLoop :: StaticShX pre -> ShX (pre ++ sh') i -> ShX pre i
    takeLoop ZKX _ = ZSX
    takeLoop (_ :!% rest) (dim :$% dims) = dim :$% takeLoop rest dims

-- | Turn a static shape into a full shape by giving every unknown dimension
-- size 0 and keeping every known dimension as-is.
-- This is a weird operation, so it has a long name.
shxCompleteZeros :: StaticShX sh -> IShX sh
shxCompleteZeros (StaticShX list) = ShX (listxFmap complete list)
  where
    complete :: SMayNat () SNat n -> SMayNat Int SNat n
    complete (SUnknown ()) = SUnknown 0
    complete (SKnown sn) = SKnown sn

-- | Split a shape into its first @length ssh@ dimensions and the rest.
shxSplitApp :: Proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> (ShX sh i, ShX sh' i)
shxSplitApp _ ZKX shape = (ZSX, shape)
shxSplitApp p (_ :!% ssh) (dim :$% dims) =
  let (front, back) = shxSplitApp p ssh dims
  in (dim :$% front, back)

-- | All valid indices of the shape, in row-major order (last dimension
-- varies fastest).
shxEnum :: IShX sh -> [IIxX sh]
shxEnum shape = enumerate shape id []
  where
    -- Difference-list style: @enumerate s k rest@ prepends @k ix@ for every
    -- index @ix@ of @s@, in row-major order, onto @rest@.
    enumerate :: IShX s -> (IIxX s -> a) -> [a] -> [a]
    enumerate ZSX k rest = k ZIX : rest
    enumerate (dim :$% dims) k rest =
      foldr (\i acc -> enumerate dims (k . (i :.%)) acc) rest [0 .. fromSMayNat' dim - 1]

-- | The rank of a shape as a type-level singleton. This is the same
-- computation as 'shxLengthSNat', so it simply delegates to it.
shxRank :: ShX sh f -> SNat (Rank sh)
shxRank = shxLengthSNat


-- * Static mixed shapes

-- | The part of a shape that is statically known. (A newtype over 'ListX'.)
-- An unknown dimension carries only @()@; a known dimension carries its
-- 'SNat'.
type StaticShX :: [Maybe Nat] -> Type
newtype StaticShX sh = StaticShX (ListX sh (SMayNat () SNat))
  deriving (Eq, Ord)

-- | The empty static shape. Matching on it refines @sh@ to @'[]@.
pattern ZKX :: forall sh. () => sh ~ '[] => StaticShX sh
pattern ZKX = StaticShX ZX

-- | Static-shape cons. @d :!% rest@ matches a non-empty static shape with
-- outermost dimension @d@ and remaining static shape @rest@, and builds one
-- the same way. Matching goes through 'listxUncons'.
pattern (:!%)
  :: forall {sh1}.
     forall n sh. (n : sh ~ sh1)
  => SMayNat () SNat n -> StaticShX sh -> StaticShX sh1
pattern i :!% shl <- StaticShX (listxUncons -> Just (UnconsListXRes (StaticShX -> shl) i))
  where i :!% StaticShX shl = StaticShX (i ::% shl)
infixr 3 :!%

-- Tells the pattern-match checker that these two synonyms cover all static
-- shapes.
{-# COMPLETE ZKX, (:!%) #-}

-- | Shows a static shape as a bracketed list. Unknown dimensions print as
-- @()@ and known ones as their 'Integer' value.
instance Show (StaticShX sh) where
  showsPrec _ (StaticShX list) = listxShow showDim list
    where
      showDim :: SMayNat () SNat n -> ShowS
      showDim (SUnknown u) = shows u
      showDim (SKnown sn) = shows (fromSNat sn)

-- | The rank of a static shape, computed at runtime.
ssxLength :: StaticShX sh -> Int
ssxLength ZKX = 0
ssxLength (_ :!% rest) = 1 + ssxLength rest

-- | This suffices as an implementation of @geq@ in the @Data.GADT.Compare@
-- class of the @some@ package. Two static shapes are equal when they have the
-- same rank, agree on known-vs-unknown at every position, and agree on every
-- known size.
ssxGeq :: StaticShX sh -> StaticShX sh' -> Maybe (sh :~: sh')
ssxGeq ZKX ZKX = Just Refl
ssxGeq (SKnown a@SNat :!% xs) (SKnown b@SNat :!% ys) =
  case sameNat a b of
    Nothing -> Nothing
    Just Refl -> case ssxGeq xs ys of
      Nothing -> Nothing
      Just Refl -> Just Refl
ssxGeq (SUnknown () :!% xs) (SUnknown () :!% ys) =
  case ssxGeq xs ys of
    Nothing -> Nothing
    Just Refl -> Just Refl
ssxGeq _ _ = Nothing

-- | Concatenate two static shapes.
ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh')
ssxAppend (StaticShX front) (StaticShX back) = StaticShX (listxAppend front back)

-- | Drop the outermost dimension of a non-empty static shape.
ssxTail :: StaticShX (n : sh) -> StaticShX sh
ssxTail (StaticShX list) = StaticShX (listxTail list)

-- | Drop as many leading dimensions as the index has components.
ssxDropIx :: forall sh sh' i. StaticShX (sh ++ sh') -> IxX sh i -> StaticShX sh'
ssxDropIx (StaticShX long) (IxX short) = StaticShX (listxDrop long short)

-- | Convert a static shape into a full shape. This may fail if @sh@ has
-- @Nothing@s in it: any unknown dimension yields 'Nothing'.
ssxToShX' :: StaticShX sh -> Maybe (IShX sh)
ssxToShX' ZKX = Just ZSX
ssxToShX' (dim :!% rest) = case dim of
  SUnknown _ -> Nothing
  SKnown n -> fmap (SKnown n :$%) (ssxToShX' rest)

-- | A static shape of the given rank in which every dimension is unknown.
ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing)
ssxReplicate SZ = ZKX
ssxReplicate (SS (rest :: SNat k)) =
  -- The lemma rewrites the result type into a cons so ':!%' applies.
  case lemReplicateSucc @(Nothing @Nat) @k of
    Refl -> SUnknown () :!% ssxReplicate rest

-- | Consecutive integers starting at the given value, one per dimension of
-- the static shape.
ssxIotaFrom :: Int -> StaticShX sh -> [Int]
ssxIotaFrom start ssh = take (ssxLength ssh) (iterate (+ 1) start)

-- | Forget the runtime sizes of a shape, keeping only its static part.
ssxFromShape :: IShX sh -> StaticShX sh
ssxFromShape (ShX list) = StaticShX (listxFmap forget list)
  where
    forget :: SMayNat Int SNat n -> SMayNat () SNat n
    forget (SUnknown _) = SUnknown ()
    forget (SKnown sn) = SKnown sn

-- | A static shape of the given rank with all dimensions unknown. This is
-- the same computation as 'ssxReplicate', so it delegates to it.
ssxFromSNat :: SNat n -> StaticShX (Replicate n Nothing)
ssxFromSNat = ssxReplicate


-- | Evidence for the static part of a shape. This pops up only when you are
-- polymorphic in the element type of an array.
--
-- The instances build 'knownShX' by recursion on @sh@:
--
-- * a 'Nothing' entry becomes @'SUnknown' ()@;
-- * a @'Just' n@ entry becomes 'SKnown' of the 'KnownNat' singleton.
type KnownShX :: [Maybe Nat] -> Constraint
class KnownShX sh where knownShX :: StaticShX sh
instance KnownShX '[] where knownShX = ZKX
instance (KnownNat n, KnownShX sh) => KnownShX (Just n : sh) where knownShX = SKnown natSing :!% knownShX
instance KnownShX sh => KnownShX (Nothing : sh) where knownShX = SUnknown () :!% knownShX


-- * Flattening

-- | The element count of a shape, as a type-level @'Maybe' 'Nat'@. It is
-- @'Just'@ the product of all dimensions when every dimension is statically
-- known, and 'Nothing' as soon as any one is not.
type Flatten sh = Flatten' 1 sh

-- | Worker for 'Flatten'. @acc@ is the running product of the known
-- dimensions seen so far.
type family Flatten' acc sh where
  Flatten' acc '[] = Just acc
  Flatten' acc (Nothing : sh) = Nothing
  Flatten' acc (Just n : sh) = Flatten' (acc * n) sh

-- This function is currently unused
-- | The static element count of a static shape (see 'Flatten'): the product
-- of all dimensions if every one is known, otherwise unknown.
ssxFlatten :: StaticShX sh -> SMayNat () SNat (Flatten sh)
ssxFlatten ssh = walk (SNat @1) ssh
  where
    walk :: SNat acc -> StaticShX sh' -> SMayNat () SNat (Flatten' acc sh')
    walk acc ZKX = SKnown acc
    walk acc (dim :!% rest) = case dim of
      SUnknown () -> SUnknown ()
      SKnown sn -> walk (snatMul acc sn) rest

-- | The element count of a shape (see 'Flatten'). The count is static if
-- every dimension is static. Otherwise it is the runtime product of all
-- sizes.
shxFlatten :: IShX sh -> SMayNat Int SNat (Flatten sh)
shxFlatten shape = walk (SNat @1) shape
  where
    -- Multiply static dimensions at the type level until the first unknown
    -- one. From there on, the rest of the product is computed at runtime.
    walk :: SNat acc -> IShX sh' -> SMayNat Int SNat (Flatten' acc sh')
    walk acc ZSX = SKnown acc
    walk acc (dim :$% rest) = case dim of
      SUnknown n -> SUnknown (fromSNat' acc * n * shxSize rest)
      SKnown sn -> walk (snatMul acc sn) rest


-- | Very untyped: only length is checked (at runtime).
instance KnownShX sh => IsList (ListX sh (Const i)) where
  type Item (ListX sh (Const i)) = i
  fromList elems = build (knownShX @sh) elems
    where
      -- Consume one list element per dimension of the static shape.
      build :: StaticShX sh' -> [i] -> ListX sh' (Const i)
      build ZKX [] = ZX
      build (_ :!% ssh) (x : xs) = Const x ::% build ssh xs
      build _ _ = error $ "IsList(ListX): Mismatched list length (type says "
                            ++ show (ssxLength (knownShX @sh)) ++ ", list has length "
                            ++ show (length elems) ++ ")"
  toList = listxToList

-- | Very untyped: only length is checked (at runtime), index bounds are __not checked__.
instance KnownShX sh => IsList (IxX sh i) where
  type Item (IxX sh i) = i
  fromList xs = IxX (IsList.fromList xs)
  toList (IxX list) = listxToList list

-- | Untyped: length and known dimensions are checked (at runtime).
instance KnownShX sh => IsList (ShX sh Int) where
  type Item (ShX sh Int) = Int
  fromList dims = ShX (build (knownShX @sh) dims)
    where
      -- Pair each list entry with the corresponding static dimension. Unknown
      -- dimensions accept any value; known ones must match exactly.
      build :: StaticShX sh' -> [Int] -> ListX sh' (SMayNat Int SNat)
      build ZKX [] = ZX
      build (SUnknown () :!% ssh) (d : ds) = SUnknown d ::% build ssh ds
      build (SKnown sn :!% ssh) (d : ds)
        | d == fromSNat' sn = SKnown sn ::% build ssh ds
        | otherwise = error $ "IsList(ShX): Value does not match typing (type says "
                                ++ show (fromSNat' sn) ++ ", list contains " ++ show d ++ ")"
      build _ _ = error $ "IsList(ShX): Mismatched list length (type says "
                            ++ show (ssxLength (knownShX @sh)) ++ ", list has length "
                            ++ show (length dims) ++ ")"
  toList = shxToList