aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Mixed/XArray.hs
blob: 999ccc36ca84a76ba5530cfe50c7ed50adaab86b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Mixed.XArray where

import Control.DeepSeq (NFData(..))
import Data.Array.Ranked qualified as ORB
import Data.Array.RankedS qualified as S
import Data.Coerce
import Data.Foldable (toList)
import Data.Kind
import Data.List.NonEmpty (NonEmpty)
import Data.Proxy
import Data.Type.Equality
import Data.Type.Ord
import Data.Vector.Storable qualified as VS
import Foreign.Storable (Storable)
import GHC.Generics (Generic)
import GHC.TypeLits

import Data.Array.Mixed.Internal.Arith
import Data.Array.Mixed.Lemmas
import Data.Array.Mixed.Permutation
import Data.Array.Mixed.Shape
import Data.Array.Mixed.Types


type XArray :: [Maybe Nat] -> Type -> Type
-- | A multi-dimensional array whose shape is described at the type level by a
-- list of @'Maybe' 'Nat'@.
--
-- * A @'Just' n@ entry marks a dimension whose size @n@ is statically known.
-- * A @'Nothing'@ entry marks a dimension whose size is only known at runtime
--   (see 'shape', which reads those sizes from the wrapped array).
--
-- The representation is an orthotope storable array of rank @'Rank' sh@. The
-- type-level shape list itself is not stored at runtime.
newtype XArray sh a = XArray (S.Array (Rank sh) a)
  deriving (Show, Eq, Generic)

-- | Ordering delegated to the 'Ord' instance of the wrapped orthotope array.
--
-- NOTE(review): an earlier comment said this instance is "only on scalars,
-- because lexicographical ordering is strange on multi-dimensional arrays",
-- but the instance head below covers every shape @sh@, not just @'[]@.
-- Either restrict the head to @XArray '[] a@ or accept the underlying
-- array's ordering for rank > 0. Confirm the intended scope before relying
-- on it.
deriving instance (Ord a, Storable a) => Ord (XArray sh a)

-- | Deep evaluation via the 'Generic'-based default methods of 'NFData'
-- (no methods are given explicitly; 'Generic' is derived on 'XArray').
instance NFData a => NFData (XArray sh a)


-- | Recover the full runtime shape of an array.
--
-- Dimensions that are statically known in @sh@ ('SKnown') are taken from the
-- 'StaticShX' and are not checked against the runtime sizes. All other
-- dimensions ('SUnknown') read their size from the underlying orthotope
-- array. Calls 'error' if the array's rank disagrees with the 'StaticShX'.
shape :: forall sh a. StaticShX sh -> XArray sh a -> IShX sh
shape ssh (XArray arr) = build ssh (S.shapeL arr)
  where
    build :: StaticShX sh' -> [Int] -> IShX sh'
    build ZKX [] = ZSX
    build (dim :!% rest) (len : lens) =
      fromSMayNat (const (SUnknown len)) SKnown dim :$% build rest lens
    build _ _ = error "Invalid shapeL"

-- | Build an array of shape @sh@ from a flat vector of its elements.
--
-- NOTE(review): presumably the elements are expected in row-major order, and
-- orthotope rejects a vector whose length differs from the shape's size.
-- Confirm.
fromVector :: forall sh a. Storable a => IShX sh -> VS.Vector a -> XArray sh a
fromVector sh v =
  case lemKnownNatRank sh of
    Dict -> XArray (S.fromVector (shxToList sh) v)

-- | The elements of the array as a flat storable vector (delegates to
-- orthotope's 'S.toVector').
toVector :: Storable a => XArray sh a -> VS.Vector a
toVector = \(XArray underlying) -> S.toVector underlying

-- | Wrap a single value as a rank-0 array.
scalar :: Storable a => a -> XArray '[] a
scalar x = XArray (S.scalar x)

-- | Reinterpret the leading @sh1@ part of an array's shape as @sh2@, keeping
-- the trailing @sh'@ part untouched.
--
-- Will throw if the array does not have the casted-to shape, i.e. if the
-- runtime sizes of the @sh1@ prefix differ from those of @sh2@.
cast :: forall sh1 sh2 sh' a. Rank sh1 ~ Rank sh2
     => StaticShX sh1 -> IShX sh2 -> StaticShX sh'
     -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a
cast ssh1 sh2 ssh' xarr@(XArray arr)
  | Refl <- lemRankApp ssh1 ssh'
  , Refl <- lemRankApp (ssxFromShape sh2) ssh'
  , let actual :: IShX sh1
        actual = fst (shxSplitApp (Proxy @sh') ssh1 (shape (ssxAppend ssh1 ssh') xarr))
  = if shxToList actual /= shxToList sh2
      then error $ "Data.Array.Mixed.cast: Cannot cast (" ++ show actual ++ ") to (" ++ show sh2 ++ ")"
      else XArray arr

-- | Extract the single element of a rank-0 array.
unScalar :: Storable a => XArray '[] a -> a
unScalar xarr = case xarr of
  XArray underlying -> S.unScalar underlying

-- | Broadcast an @sh'@-shaped array along new outer dimensions of shape @sh@.
--
-- The input is first reshaped to have a size-1 dimension for each entry of
-- @sh@, and then stretched to the requested sizes.
replicate :: forall sh sh' a. Storable a => IShX sh -> StaticShX sh' -> XArray sh' a -> XArray (sh ++ sh') a
replicate sh ssh' (XArray arr)
  | Dict <- lemKnownNatRankSSX ssh'
  , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShape sh) ssh')
  , Refl <- lemRankApp (ssxFromShape sh) ssh'
  , let outerDims = shxToList sh
        innerDims = S.shapeL arr
        unitDims = [1 | _ <- outerDims]
  = XArray (S.stretch (outerDims ++ innerDims) (S.reshape (unitDims ++ innerDims) arr))

-- | An array of the given shape with every element equal to @x@.
replicateScal :: forall sh a. Storable a => IShX sh -> a -> XArray sh a
replicateScal sh x =
  case lemKnownNatRank sh of
    Dict -> XArray (S.constant (shxToList sh) x)

-- | Build an array of the given shape by calling @f@ on every index.
--
-- Linear positions @0 .. shxSize sh - 1@ are converted to multi-indices with
-- 'ixxFromLinear'.
generate :: Storable a => IShX sh -> (IIxX sh -> a) -> XArray sh a
generate sh f =
  let count = shxSize sh
      elemAt pos = f (ixxFromLinear sh pos)
  in fromVector sh (VS.generate count elemAt)

-- generateM :: (Monad m, Storable a) => IShX sh -> (IIxX sh -> m a) -> m (XArray sh a)
-- generateM sh f | Dict <- lemKnownNatRank sh =
--   XArray . S.fromVector (shxShapeL sh)
--     <$> VS.generateM (shxSize sh) (f . ixxFromLinear sh)

-- | Index the outer @sh@ dimensions, returning the @sh'@-shaped subarray at
-- that position. Peels off one outer dimension per index component.
indexPartial :: Storable a => XArray (sh ++ sh') a -> IIxX sh -> XArray sh' a
indexPartial (XArray arr) idx = case idx of
  ZIX -> XArray arr
  i :.% rest -> indexPartial (XArray (S.index arr i)) rest

-- | Read the single element at a full index. This is 'indexPartial' with an
-- empty remaining shape, followed by extracting the scalar.
index :: forall sh a. Storable a => XArray sh a -> IIxX sh -> a
index xarr i
  | Refl <- lemAppNil @sh
  = unScalar (indexPartial xarr i :: XArray '[] a)

-- | Concatenate two arrays along their outermost dimension. The inner shapes
-- must agree; the outer type-level size of the result is @AddMaybe n m@.
append :: forall n m sh a. Storable a
       => StaticShX sh -> XArray (n : sh) a -> XArray (m : sh) a -> XArray (AddMaybe n m : sh) a
append ssh (XArray a) (XArray b) =
  case lemKnownNatRankSSX ssh of
    Dict -> XArray (S.append a b)

-- | Concatenate a non-empty list of arrays along their outermost dimension.
--
-- All arrays must have the same shape, except possibly for the outermost
-- dimension.
concat :: Storable a
       => StaticShX sh -> NonEmpty (XArray (Nothing : sh) a) -> XArray (Nothing : sh) a
concat ssh l =
  case lemKnownNatRankSSX ssh of
    Dict -> XArray (S.concatOuter (coerce (toList l)))

-- | Apply @f@ to every @sh1@-shaped subarray found under the @sh@-shaped outer
-- prefix of the input. The outer prefix is kept, and each subarray is replaced
-- by the @sh2@-shaped result of @f@.
--
-- If the prefix of the shape of the input array (@sh@) is empty (i.e.
-- contains a zero), then there is no way to deduce the full shape of the output
-- array (more precisely, the @sh2@ part): that could only come from calling
-- @f@, and there are no subarrays to call @f@ on. @orthotope@ errors out in
-- this case; we choose to fill the shape with zeros wherever we cannot deduce
-- what it should be.
--
-- For example, if:
--
-- @
-- arr :: XArray '[Just 3, Just 0, Just 4, Just 2, Nothing] Int   -- of shape [3, 0, 4, 2, 21]
-- f :: XArray '[Just 2, Nothing] Int -> XArray '[Just 5, Nothing, Just 17] Float
-- @
--
-- then:
--
-- @
-- rerank _ _ _ f arr :: XArray '[Just 3, Just 0, Just 4, Just 5, Nothing, Just 17] Float
-- @
--
-- and this result will have shape @[3, 0, 4, 5, 0, 17]@. Note the second @0@
-- in this shape: we don't know if @f@ intended to return an array with shape 0
-- here (it probably didn't), but there is no better number to put here absent
-- a subarray of the input to pass to @f@.
--
-- In this particular case the fact that @sh@ is empty was evident from the
-- type-level information, but the same situation occurs when @sh@ consists of
-- @Nothing@s, and some of those happen to be zero at runtime.
rerank :: forall sh sh1 sh2 a b.
          (Storable a, Storable b)
       => StaticShX sh -> StaticShX sh1 -> StaticShX sh2
       -> (XArray sh1 a -> XArray sh2 b)
       -> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b
rerank ssh ssh1 ssh2 f xarr@(XArray arr)
  | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2)
  -- 'sh' below is the runtime outer prefix; the inner sh1 part is discarded.
  = let (sh, _) = shxSplitApp (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr)
    in if any (== 0) (shxToList sh)
         -- Empty outer prefix: f cannot be called, so build an element-less
         -- array whose unknown sh2 sizes are filled in by 'shxCompleteZeros'.
         then XArray (S.fromList (shxToList (shxAppend sh (shxCompleteZeros ssh2))) [])
         else case () of
           -- The rank lemmas let orthotope's rerank see the Rank (sh ++ ...)
           -- types as sums of the individual ranks.
           () | Dict <- lemKnownNatRankSSX ssh
              , Dict <- lemKnownNatRankSSX ssh2
              , Refl <- lemRankApp ssh ssh1
              , Refl <- lemRankApp ssh ssh2
              -> XArray (S.rerank @(Rank sh) @(Rank sh1) @(Rank sh2)
                          (\a -> let XArray r = f (XArray a) in r)
                          arr)

-- | Like 'rerank', but the dimensions @f@ operates on are the /outer/ ones
-- (@sh1@), and the @sh@ dimensions stay on the inside. This is implemented by
-- moving @sh@ to the front, reranking, and moving it back.
--
-- The caveat about empty arrays at 'rerank' applies here too.
rerankTop :: forall sh1 sh2 sh a b.
             (Storable a, Storable b)
          => StaticShX sh1 -> StaticShX sh2 -> StaticShX sh
          -> (XArray sh1 a -> XArray sh2 b)
          -> XArray (sh1 ++ sh) a -> XArray (sh2 ++ sh) b
rerankTop ssh1 ssh2 ssh f input =
  let outerFirst = transpose2 ssh1 ssh input
      mapped = rerank ssh ssh1 ssh2 f outerFirst
  in transpose2 ssh ssh2 mapped

-- | Binary version of 'rerank'. @f@ is applied pointwise to matching
-- @sh1@-shaped subarrays of the two inputs under their common @sh@-shaped
-- outer prefix.
--
-- The caveat about empty arrays at @rerank@ applies here too.
--
-- NOTE(review): only the shape of the first array is inspected here, both for
-- the outer prefix and for the empty-prefix shortcut. Any shape mismatch with
-- the second array is presumably left to orthotope's @rerank2@ to detect.
-- Confirm. In the empty-prefix case the second array is not inspected at all.
rerank2 :: forall sh sh1 sh2 a b c.
           (Storable a, Storable b, Storable c)
        => StaticShX sh -> StaticShX sh1 -> StaticShX sh2
        -> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c)
        -> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c
rerank2 ssh ssh1 ssh2 f xarr1@(XArray arr1) (XArray arr2)
  | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2)
  -- Runtime outer prefix, taken from the first argument only.
  = let (sh, _) = shxSplitApp (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr1)
    in if any (== 0) (shxToList sh)
         -- Empty outer prefix: see 'rerank' for why unknown sh2 sizes become 0.
         then XArray (S.fromList (shxToList (shxAppend sh (shxCompleteZeros ssh2))) [])
         else case () of
           () | Dict <- lemKnownNatRankSSX ssh
              , Dict <- lemKnownNatRankSSX ssh2
              , Refl <- lemRankApp ssh ssh1
              , Refl <- lemRankApp ssh ssh2
              -> XArray (S.rerank2 @(Rank sh) @(Rank sh1) @(Rank sh2)
                          (\a b -> let XArray r = f (XArray a) (XArray b) in r)
                          arr1 arr2)

-- | Permute the leading dimensions of an array with a type-level permutation.
--
-- The permutation @perm@ (a 'Perm', not a plain list) gives indices into the
-- original dimension list. It acts on the first @Rank is@ dimensions only
-- (@TakeLen is sh@); the remaining dimensions (@DropLen is sh@) keep their
-- positions, as reflected in the result type @PermutePrefix is sh@.
transpose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh)
          => StaticShX sh
          -> Perm is
          -> XArray sh a
          -> XArray (PermutePrefix is sh) a
transpose ssh perm (XArray arr)
  | Dict <- lemKnownNatRankSSX ssh
  -- The rank lemmas show the result rank equals the input rank, as required
  -- by orthotope's rank-preserving transpose.
  , Refl <- lemRankApp (ssxPermute perm (ssxTakeLen perm ssh)) (ssxDropLen perm ssh)
  , Refl <- lemRankPermute (Proxy @(TakeLen is sh)) perm
  , Refl <- lemRankDropLen ssh perm
  = XArray (S.transpose (permToList' perm) arr)

-- | Permute the leading, type-level unshaped (@Nothing@) dimensions with an
-- untyped permutation. The list argument gives indices into the original
-- dimension list.
--
-- The permutation (the list) must have length <= @n@. If it is longer, this
-- function throws.
transposeUntyped :: forall n sh a.
                    SNat n -> StaticShX sh -> [Int]
                 -> XArray (Replicate n Nothing ++ sh) a -> XArray (Replicate n Nothing ++ sh) a
transposeUntyped sn ssh perm (XArray arr)
  | length perm > fromSNat' sn
  = error "Data.Array.Mixed.transposeUntyped: Permutation longer than length of unshaped prefix of shape type"
  | otherwise
  = case lemKnownNatRankSSX (ssxAppend (ssxReplicate sn) ssh) of
      Dict -> XArray (S.transpose perm arr)

-- | Swap two adjacent groups of dimensions: an array of shape @sh1 ++ sh2@
-- becomes one of shape @sh2 ++ sh1@.
transpose2 :: forall sh1 sh2 a.
              StaticShX sh1 -> StaticShX sh2
           -> XArray (sh1 ++ sh2) a -> XArray (sh2 ++ sh1) a
transpose2 ssh1 ssh2 (XArray arr)
  | Refl <- lemRankApp ssh1 ssh2
  , Refl <- lemRankApp ssh2 ssh1
  , Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh2)
  , Dict <- lemKnownNatRankSSX (ssxAppend ssh2 ssh1)
  , Refl <- lemRankAppComm ssh1 ssh2
  , let n1 = ssxLength ssh1
  -- Output dimension order: the sh2 dimensions (original positions starting at
  -- n1) followed by the sh1 dimensions (original positions starting at 0).
  -- This assumes @ssxIotaFrom k s@ yields @[k .. k + length s - 1]@.
  = XArray (S.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr)

-- | Sum of all elements of the array, regardless of its shape.
--
-- The array is flattened to a rank-1 array of all its elements, and that
-- single dimension is summed away. The 'StaticShX' argument is unused.
sumFull :: (Storable a, NumElt a) => StaticShX sh -> XArray sh a -> a
sumFull _ (XArray arr) =
  let flat = S.toVector arr
  in S.unScalar (numEltSum1Inner (SNat @0) (S.fromVector [VS.length flat] flat))

-- | Sum away the inner @sh'@ dimensions, keeping the outer @sh@ dimensions.
--
-- Strategy: move @sh'@ to the front, flatten it into a single dimension of
-- size @shxFlatten sh'@, move that dimension back to the end, and sum the
-- innermost dimension with 'numEltSum1Inner'. The detour through the front
-- is needed because 'reshapePartial' only reshapes a prefix of the shape.
sumInner :: forall sh sh' a. (Storable a, NumElt a)
         => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh a
sumInner ssh ssh' arr
  | Refl <- lemAppNil @sh
  = let (_, sh') = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr)
        -- The whole of sh' collapsed into one dimension.
        sh'F = shxFlatten sh' :$% ZSX
        ssh'F = ssxFromShape sh'F

        -- Sum the single trailing (flattened) dimension.
        go :: XArray (sh ++ '[Flatten sh']) a -> XArray sh a
        go (XArray arr')
          | Refl <- lemRankApp ssh ssh'F
          , let sn = listxRank (let StaticShX l = ssh in l)
          = XArray (numEltSum1Inner sn arr')

    in go $
       transpose2 ssh'F ssh $
       reshapePartial ssh' ssh sh'F $
       transpose2 ssh ssh' $
         arr

-- | Sum away the outer @sh@ dimensions, keeping the inner @sh'@ dimensions.
--
-- The outer dimensions are flattened into one, that dimension is moved to
-- the inside, and it is then summed with 'sumInner'.
sumOuter :: forall sh sh' a. (Storable a, NumElt a)
         => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh' a
sumOuter ssh ssh' arr
  | Refl <- lemAppNil @sh
  = let (sh, _) = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr)
        -- The whole of sh collapsed into one dimension.
        shF = shxFlatten sh :$% ZSX
    in sumInner ssh' (ssxFromShape shF) $
       transpose2 (ssxFromShape shF) ssh' $
       reshapePartial ssh ssh' shF $
         arr

-- | Stack a list of @sh@-shaped arrays along a new outermost dimension.
--
-- When the outer dimension is statically known (@'Just' m@), the list must
-- have exactly @m@ elements; otherwise this calls 'error'. For an unknown
-- outer dimension, the list length becomes its size.
--
-- NOTE(review): for an empty list the inner shape cannot be recovered from
-- the elements. Confirm how 'S.ravel' behaves in that case.
fromListOuter :: forall n sh a. Storable a
              => StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a
fromListOuter ssh l
  | Dict <- lemKnownNatRankSSX ssh
  , let len = length l
  = case ssh of
      SKnown m :!% _ | fromSNat' m /= len ->
        error $ "Data.Array.Mixed.fromListOuter: length of list (" ++ show len ++ ") " ++
                "does not match the type (" ++ show (fromSNat' m) ++ ")"
      _ -> XArray (S.ravel (ORB.fromList [len] (coerce @[XArray sh a] @[S.Array (Rank sh) a] l)))

-- | Split an array into the list of its subarrays along the outermost
-- dimension. An array whose outer dimension has size 0 yields @[]@ directly,
-- without going through orthotope's 'S.unravel'.
toListOuter :: Storable a => XArray (n : sh) a -> [XArray sh a]
toListOuter (XArray arr)
  | (0 : _) <- S.shapeL arr = []
  | otherwise = coerce (ORB.toList (S.unravel arr))

-- | Build a rank-1 array from a list of elements.
--
-- When the single dimension is statically known (@'Just' m@), the list must
-- have exactly @m@ elements; otherwise this calls 'error'. For an unknown
-- dimension, the list length becomes its size.
fromList1 :: Storable a => StaticShX '[n] -> [a] -> XArray '[n] a
fromList1 ssh l =
  let n = length l
  in case ssh of
       SKnown m :!% _ | fromSNat' m /= n ->
         error $ "Data.Array.Mixed.fromList1: length of list (" ++ show n ++ ") " ++
                 "does not match the type (" ++ show (fromSNat' m) ++ ")"
       _ -> XArray (S.fromVector [n] (VS.fromListN n l))

-- | The elements of a rank-1 array, in order.
toList1 :: Storable a => XArray '[n] a -> [a]
toList1 xarr = case xarr of
  XArray underlying -> S.toList underlying

-- | Build an array of the given shape that contains no elements.
--
-- Throws if the given shape is not, in fact, empty. The fill value passed to
-- 'S.constant' is an 'error' thunk. NOTE(review): this relies on 'S.constant'
-- forcing that value only when the shape has at least one element.
empty :: forall sh a. Storable a => IShX sh -> XArray sh a
empty sh =
  case lemKnownNatRank sh of
    Dict -> XArray (S.constant (shxToList sh) notEmpty)
  where
    notEmpty = error $ "Data.Array.Mixed.empty: shape was not empty: " ++ show sh

-- | Take @n@ consecutive entries of the (statically sized) outermost
-- dimension, starting at offset @i@.
slice :: SNat i -> SNat n -> XArray (Just (i + n + k) : sh) a -> XArray (Just n : sh) a
slice i n (XArray arr) = XArray (S.slice [(start, count)] arr)
  where
    start = fromSNat' i
    count = fromSNat' n

-- | Untyped 'slice': take @n@ consecutive entries of the (dynamically sized)
-- outermost dimension, starting at offset @i@.
sliceU :: Int -> Int -> XArray (Nothing : sh) a -> XArray (Nothing : sh) a
sliceU start count = \(XArray arr) -> XArray (S.slice [(start, count)] arr)

-- | Reverse the array along its outermost dimension (dimension 0).
rev1 :: XArray (n : sh) a -> XArray (n : sh) a
rev1 = \(XArray arr) -> XArray (S.rev [0] arr)

-- | Reinterpret the elements of an array under a new shape @sh2@.
--
-- Throws if the given array and the target shape do not have the same number of elements.
reshape :: forall sh1 sh2 a. Storable a => StaticShX sh1 -> IShX sh2 -> XArray sh1 a -> XArray sh2 a
reshape ssh1 sh2 (XArray arr) =
  case (lemKnownNatRankSSX ssh1, lemKnownNatRank sh2) of
    (Dict, Dict) -> XArray (S.reshape (shxToList sh2) arr)

-- | Reshape only the leading @sh1@ part of an array's shape into @sh2@. The
-- trailing @sh'@ dimensions keep their runtime sizes.
--
-- Throws if the given array and the target shape do not have the same number of elements.
reshapePartial :: forall sh1 sh2 sh' a. Storable a => StaticShX sh1 -> StaticShX sh' -> IShX sh2 -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a
reshapePartial ssh1 ssh' sh2 (XArray arr)
  | Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh')
  , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShape sh2) ssh')
  , let suffixDims = drop (ssxLength ssh1) (S.shapeL arr)
        targetDims = shxToList sh2 ++ suffixDims
  = XArray (S.reshape targetDims arr)

-- | The rank-1 array @[toEnum 0, toEnum 1, .., toEnum (n - 1)]@, produced by
-- the 'Enum' instance's @enumFromTo@ on the element type.
--
-- this was benchmarked to be (slightly) faster than S.iota, S.generate and S.fromVector(VS.enumFromTo).
iota :: (Enum a, Storable a) => SNat n -> XArray '[Just n] a
iota sn =
  let len = fromSNat' sn
      elems = [toEnum 0 .. toEnum (len - 1)]
  in XArray (S.fromVector [len] (VS.fromListN len elems))