aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Ranked/Base.hs
blob: beb5b0e355b60fb95ea2be36761bfeae464f0a14 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_HADDOCK not-home #-}
module Data.Array.Nested.Ranked.Base where

import Prelude hiding (mappend, mconcat)

import Control.DeepSeq (NFData(..))
import Control.Monad.ST
import Data.Bifunctor (first)
import Data.Coerce (coerce)
import Data.Kind (Type)
import Data.List.NonEmpty (NonEmpty)
import Data.Proxy
import Foreign.Storable (Storable)
import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp)
import GHC.Generics (Generic)
import GHC.TypeLits

#ifndef OXAR_DEFAULT_SHOW_INSTANCES
import Data.Foldable (toList)
#endif

import Data.Array.Nested.Lemmas
import Data.Array.Nested.Types
import Data.Array.XArray (XArray(..))
import Data.Array.Nested.Mixed
import Data.Array.Nested.Mixed.Shape
import Data.Array.Nested.Ranked.Shape
import Data.Array.Strided.Arith


-- | An array whose /rank/ (its number of dimensions) is tracked at the type
-- level as a 'Nat'.
--
-- The types a ranked array may hold as elements are the instances of the
-- 'Elt' class. 'Ranked' is itself such an instance, so arrays can be nested;
-- a whole nest is stored internally as a single, flattened, struct-of-arrays
-- array.
--
-- Representation: a newtype around a 'Mixed' array whose shape is made up
-- entirely of 'Nothing's (@'Replicate' n 'Nothing'@).
type Ranked :: Nat -> Type -> Type
newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a)

-- Equality and ordering are those of the wrapped 'Mixed' array.
deriving instance Eq (Mixed (Replicate n Nothing) a) => Eq (Ranked n a)
deriving instance Ord (Mixed (Replicate n Nothing) a) => Ord (Ranked n a)
#ifdef OXAR_DEFAULT_SHOW_INSTANCES
deriving instance Show (Mixed (Replicate n Nothing) a) => Show (Ranked n a)
#endif

#ifndef OXAR_DEFAULT_SHOW_INSTANCES
-- | Rendering is delegated to 'showsMixedArray'. It receives two expression
-- prefixes that both embed the array's shape (shown as a list): one headed by
-- @rfromListLinear@ and one headed by @rreplicate@.
instance (Show a, Elt a) => Show (Ranked n a) where
  showsPrec prec arr@(Ranked inner) =
    showsMixedArray fromListPrefix replicatePrefix prec inner
    where
      shapeText = show (toList (rshape arr))
      fromListPrefix = "rfromListLinear " ++ shapeText
      replicatePrefix = "rreplicate " ++ shapeText
#endif

-- | Full evaluation is delegated to the wrapped 'Mixed' array.
instance Elt a => NFData (Ranked n a) where
  rnf ranked = case ranked of
    Ranked inner -> rnf inner

-- A 'Mixed' array of 'Ranked' elements is just a 'Mixed' array of the
-- underlying @'Mixed' ('Replicate' n 'Nothing') a@ elements: the 'Ranked'
-- newtype is peeled off so the general nested-array instance can be reused.
newtype instance Mixed sh (Ranked n a) = M_Ranked (Mixed sh (Mixed (Replicate n Nothing) a))
  deriving Generic

deriving instance Eq (Mixed sh (Mixed (Replicate n Nothing) a)) => Eq (Mixed sh (Ranked n a))
#ifdef OXAR_DEFAULT_SHOW_INSTANCES
deriving instance Show (Mixed sh (Mixed (Replicate n Nothing) a)) => Show (Mixed sh (Ranked n a))
#endif

-- The mutable vectors for 'Ranked' elements likewise wrap those of the
-- underlying 'Mixed' element type.
newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed (Replicate n Nothing) a))

-- 'Ranked' and 'Shaped' can already be used at the top level of an array nest;
-- these instances allow them to also be used as elements of arrays, thus
-- making them first-class in the API.
--
-- Every method below strips the 'M_Ranked' / 'MV_Ranked' / 'Ranked' wrapper,
-- either by matching or re-applying the constructor, or with a zero-cost
-- 'coerce' when the wrapper sits inside a list or 'NonEmpty'. It then defers
-- to the 'Elt' instance of @'Mixed' ('Replicate' n 'Nothing') a@.
instance Elt a => Elt (Ranked n a) where
  -- The shape tree of one element: its own ranked shape, paired with the
  -- shape tree of the elements it contains.
  type ShapeTree (Ranked n a) = (IShR n, ShapeTree a)

  mshape (M_Ranked inner) = mshape inner

  mindex (M_Ranked inner) ix = Ranked (mindex inner ix)

  mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IIxX sh -> Mixed sh' (Ranked n a)
  mindexPartial (M_Ranked inner) ix = M_Ranked (mindexPartial inner ix)

  mscalar (Ranked inner) = M_Ranked (M_Nest ZSX inner)

  mfromListOuter :: forall sh. NonEmpty (Mixed sh (Ranked n a)) -> Mixed (Nothing : sh) (Ranked n a)
  mfromListOuter elems =
    M_Ranked (mfromListOuter
                (coerce @(NonEmpty (Mixed sh (Ranked n a)))
                        @(NonEmpty (Mixed sh (Mixed (Replicate n Nothing) a)))
                        elems))

  mtoListOuter :: forall m sh. Mixed (m : sh) (Ranked n a) -> [Mixed sh (Ranked n a)]
  mtoListOuter (M_Ranked inner) =
    coerce @[Mixed sh (Mixed (Replicate n Nothing) a)] @[Mixed sh (Ranked n a)]
      (mtoListOuter inner)

  mlift :: forall sh1 sh2.
           StaticShX sh2
        -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
        -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a)
  mlift ssh2 f (M_Ranked inner) = M_Ranked (mlift ssh2 f inner)

  mlift2 :: forall sh1 sh2 sh3.
            StaticShX sh3
         -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
         -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) -> Mixed sh3 (Ranked n a)
  mlift2 ssh3 f (M_Ranked lhs) (M_Ranked rhs) = M_Ranked (mlift2 ssh3 f lhs rhs)

  mliftL :: forall sh1 sh2.
            StaticShX sh2
         -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b))
         -> NonEmpty (Mixed sh1 (Ranked n a)) -> NonEmpty (Mixed sh2 (Ranked n a))
  mliftL ssh2 f arrs =
    let unwrapped = coerce @(NonEmpty (Mixed sh1 (Ranked n a)))
                           @(NonEmpty (Mixed sh1 (Mixed (Replicate n Nothing) a)))
                           arrs
    in coerce @(NonEmpty (Mixed sh2 (Mixed (Replicate n Nothing) a)))
              @(NonEmpty (Mixed sh2 (Ranked n a)))
              (mliftL ssh2 f unwrapped)

  mcastPartial ssh1 ssh2 psh' (M_Ranked inner) = M_Ranked (mcastPartial ssh1 ssh2 psh' inner)

  mtranspose perm (M_Ranked inner) = M_Ranked (mtranspose perm inner)

  mconcat arrs = M_Ranked (mconcat (coerce arrs))

  mrnf (M_Ranked inner) = mrnf inner

  -- Convert the outer 'Mixed' shape to a ranked shape; keep the inner tree.
  mshapeTree (Ranked inner) = first shrFromShX2 (mshapeTree inner)

  mshapeTreeEq _ (lsh, ltree) (rsh, rtree) = lsh == rsh && mshapeTreeEq (Proxy @a) ltree rtree

  mshapeTreeEmpty _ (sh, tree) = shrSize sh == 0 && mshapeTreeEmpty (Proxy @a) tree

  mshowShapeTree _ (sh, tree) = concat ["(", show sh, ", ", mshowShapeTree (Proxy @a) tree, ")"]

  marrayStrides (M_Ranked inner) = marrayStrides inner

  mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s ()
  mvecsWrite sh idx (Ranked inner) (MV_Ranked vecs) = mvecsWrite sh idx inner vecs

  mvecsWritePartial :: forall sh sh' s.
                       IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a)
                    -> MixedVecs s (sh ++ sh') (Ranked n a)
                    -> ST s ()
  mvecsWritePartial sh idx (M_Ranked inner) (MV_Ranked vecs) = mvecsWritePartial sh idx inner vecs

  mvecsFreeze :: forall sh s. IShX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a))
  mvecsFreeze sh (MV_Ranked vecs) = M_Ranked <$> mvecsFreeze sh vecs

-- | Allocation for 'Ranked' elements. Each method first brings into scope the
-- evidence produced by @'lemKnownReplicate' ('SNat' \@n)@. It then defers to
-- the 'KnownElt' instance of @'Mixed' ('Replicate' n 'Nothing') a@ and
-- re-wraps the result.
instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where
  memptyArrayUnsafe :: forall sh. IShX sh -> Mixed sh (Ranked n a)
  memptyArrayUnsafe shape = case lemKnownReplicate (SNat @n) of
    Dict -> M_Ranked (memptyArrayUnsafe shape)

  mvecsUnsafeNew idx (Ranked inner) = case lemKnownReplicate (SNat @n) of
    Dict -> MV_Ranked <$> mvecsUnsafeNew idx inner

  mvecsNewEmpty _ = case lemKnownReplicate (SNat @n) of
    Dict -> MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a))


-- | Turn a unary operation on the underlying 'Mixed' arrays into one on
-- 'Ranked' arrays. The whole function is converted with a zero-cost 'coerce'.
liftRanked1 :: forall n a b.
               (Mixed (Replicate n Nothing) a -> Mixed (Replicate n Nothing) b)
            -> Ranked n a -> Ranked n b
liftRanked1 f = coerce f

-- | Turn a binary operation on the underlying 'Mixed' arrays into one on
-- 'Ranked' arrays. The whole function is converted with a zero-cost 'coerce'.
liftRanked2 :: forall n a b c.
               (Mixed (Replicate n Nothing) a -> Mixed (Replicate n Nothing) b -> Mixed (Replicate n Nothing) c)
            -> Ranked n a -> Ranked n b -> Ranked n c
liftRanked2 f = coerce f

-- | Arithmetic lifted from the 'Num' instance of the underlying 'Mixed'
-- array. 'fromInteger' is not supported and raises an error: no shape
-- singleton is available here, and the message points users to
-- @rreplicateScal@.
instance (NumElt a, PrimElt a) => Num (Ranked n a) where
  negate = liftRanked1 negate
  abs    = liftRanked1 abs
  signum = liftRanked1 signum

  (+) = liftRanked2 (+)
  (-) = liftRanked2 (-)
  (*) = liftRanked2 (*)

  fromInteger =
    error "Data.Array.Nested(Ranked).fromInteger: No singletons available, use explicit rreplicateScal"

-- | Division lifted from the 'Fractional' instance of the underlying 'Mixed'
-- array. 'fromRational' is not supported and raises an error pointing to
-- @rreplicateScal@.
instance (FloatElt a, PrimElt a) => Fractional (Ranked n a) where
  (/)   = liftRanked2 (/)
  recip = liftRanked1 recip

  fromRational _ =
    error "Data.Array.Nested(Ranked).fromRational: No singletons available, use explicit rreplicateScal"

-- | Floating-point functions lifted from the 'Floating' instance of the
-- underlying 'Mixed' array. 'pi' is not supported and raises an error
-- pointing to @rreplicateScal@.
instance (FloatElt a, PrimElt a) => Floating (Ranked n a) where
  -- Unary functions.
  exp   = liftRanked1 exp
  log   = liftRanked1 log
  sqrt  = liftRanked1 sqrt
  sin   = liftRanked1 sin
  cos   = liftRanked1 cos
  tan   = liftRanked1 tan
  asin  = liftRanked1 asin
  acos  = liftRanked1 acos
  atan  = liftRanked1 atan
  sinh  = liftRanked1 sinh
  cosh  = liftRanked1 cosh
  tanh  = liftRanked1 tanh
  asinh = liftRanked1 asinh
  acosh = liftRanked1 acosh
  atanh = liftRanked1 atanh

  -- Methods whose element-level versions are taken from "GHC.Float".
  log1p    = liftRanked1 GHC.Float.log1p
  expm1    = liftRanked1 GHC.Float.expm1
  log1pexp = liftRanked1 GHC.Float.log1pexp
  log1mexp = liftRanked1 GHC.Float.log1mexp

  -- Binary functions.
  (**)    = liftRanked2 (**)
  logBase = liftRanked2 logBase

  pi = error "Data.Array.Nested(Ranked).pi: No singletons available, use explicit rreplicateScal"

-- | 'mquotArray' from 'Mixed', lifted to 'Ranked' arrays.
rquotArray :: (IntElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a
rquotArray = liftRanked2 mquotArray

-- | 'mremArray' from 'Mixed', lifted to 'Ranked' arrays.
rremArray :: (IntElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a
rremArray = liftRanked2 mremArray

-- | 'matan2Array' from 'Mixed', lifted to 'Ranked' arrays.
ratan2Array :: (FloatElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a
ratan2Array = liftRanked2 matan2Array


-- | The shape of a ranked array: the shape of the wrapped 'Mixed' array,
-- converted with 'shrFromShX2'.
rshape :: Elt a => Ranked n a -> IShR n
rshape (Ranked inner) = shrFromShX2 mixedShape
  where
    mixedShape = mshape inner

-- | The rank of a ranked array as a singleton, read off its shape.
rrank :: Elt a => Ranked n a -> SNat n
rrank arr = shrRank (rshape arr)