aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Shaped/Base.hs
blob: 74c231d25f852f606f6a6413ad9c225e4de3223f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_HADDOCK not-home #-}
module Data.Array.Nested.Shaped.Base where

import Prelude hiding (mappend, mconcat)

import Control.DeepSeq (NFData(..))
import Control.Monad.ST
import Data.Bifunctor (first)
import Data.Coerce (coerce)
import Data.Kind (Type)
import Data.List.NonEmpty (NonEmpty)
import Data.Proxy
import Foreign.Storable (Storable)
import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp)
import GHC.Generics (Generic)
import GHC.TypeLits

import Data.Array.Mixed.Types
import Data.Array.XArray (XArray)
import Data.Array.Nested.Internal.Lemmas
import Data.Array.Nested.Mixed
import Data.Array.Nested.Mixed.Shape
import Data.Array.Nested.Shaped.Shape
import Data.Array.Strided.Arith


-- | A shape-typed array: the full shape of the array (the sizes of its
-- dimensions) is represented on the type level as a list of 'Nat's. Note that
-- these are "GHC.TypeLits" naturals, because we do not need induction over
-- them and we want very large arrays to be possible.
--
-- Like for 'Ranked', the valid elements are described by the 'Elt' type class,
-- and 'Shaped' itself is again an instance of 'Elt' as well.
--
-- 'Shaped' is a newtype around a 'Mixed' of 'Just's: the wrapped array has
-- the mixed shape @MapJust sh@.
type Shaped :: [Nat] -> Type -> Type
newtype Shaped sh a = Shaped (Mixed (MapJust sh) a)
-- Equality and ordering are delegated to the wrapped 'Mixed' array: each
-- instance below exists exactly when the corresponding instance exists for
-- @Mixed (MapJust sh) a@.
deriving instance Eq (Mixed (MapJust sh) a) => Eq (Shaped sh a)
deriving instance Ord (Mixed (MapJust sh) a) => Ord (Shaped sh a)

-- | Shows a 'Shaped' array by handing the wrapped 'Mixed' array to
-- 'showsMixedArray', together with two prefix strings that both end in the
-- array's shape rendered as a list (one starting with @sfromListLinear@, the
-- other with @sreplicate@).
instance (Show a, Elt a) => Show (Shaped sh a) where
  showsPrec prec shaped@(Shaped inner) =
    showsMixedArray fromListPrefix replicatePrefix prec inner
    where
      -- The shape, printed as a plain list of dimension sizes.
      shapeStr = show (shsToList (sshape shaped))
      fromListPrefix = "sfromListLinear " ++ shapeStr
      replicatePrefix = "sreplicate " ++ shapeStr

-- | Fully evaluating a 'Shaped' array is fully evaluating the 'Mixed' array
-- that it wraps.
instance Elt a => NFData (Shaped sh a) where
  rnf shaped = case shaped of
    Shaped inner -> rnf inner

-- A 'Mixed' array whose elements are 'Shaped' arrays is stored as a 'Mixed'
-- array whose elements are the underlying 'Mixed' arrays: just unwrap the
-- newtype and defer to the general instance for nested arrays.
newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh') a))
  deriving (Generic)

-- Equality is inherited from the underlying nested 'Mixed' representation.
deriving instance Eq (Mixed sh (Mixed (MapJust sh') a)) => Eq (Mixed sh (Shaped sh' a))

-- The 'MixedVecs' family instance for 'Shaped' elements: a newtype over the
-- 'MixedVecs' of the underlying 'Mixed' arrays, mirroring 'M_Shaped'.
newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh') a))

-- | 'Shaped' arrays are valid array elements. Every method strips the
-- 'M_Shaped' \/ 'MV_Shaped' \/ 'Shaped' newtype wrappers (by pattern matching
-- or by 'coerce'), calls the same method at the underlying element type
-- @Mixed (MapJust sh) a@, and wraps the result again where needed.
instance Elt a => Elt (Shaped sh a) where
  -- The outer shape is the shape of the underlying nested array.
  mshape (M_Shaped arr) = mshape arr
  -- Index the underlying array; the element found is a 'Mixed' array that is
  -- re-wrapped as 'Shaped'.
  mindex (M_Shaped arr) i = Shaped (mindex arr i)

  -- Partial indexing on the underlying array; the result is coerced back to
  -- an array of 'Shaped' elements. The type applications pin down both sides
  -- of the 'coerce'.
  mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IIxX sh1 -> Mixed sh2 (Shaped sh a)
  mindexPartial (M_Shaped arr) i =
    coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
      mindexPartial arr i

  -- A single 'Shaped' value becomes a nested array built with 'M_Nest' and
  -- the outer shape 'ZSX' (presumably the empty, rank-0 shape -- confirm in
  -- "Data.Array.Nested.Mixed.Shape").
  mscalar (Shaped x) = M_Shaped (M_Nest ZSX x)

  -- 'coerce' turns the list of wrapped arrays into a list of underlying
  -- arrays before delegating.
  mfromListOuter :: forall sh'. NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Nothing : sh') (Shaped sh a)
  mfromListOuter l = M_Shaped (mfromListOuter (coerce l))

  -- Delegates to the underlying array and coerces the resulting list
  -- elementwise back to arrays of 'Shaped'.
  mtoListOuter :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)]
  mtoListOuter (M_Shaped arr)
    = coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoListOuter arr)

  -- The lifting functions pass the given shape witness and the polymorphic
  -- 'XArray' function straight through to the underlying instance.
  mlift :: forall sh1 sh2.
           StaticShX sh2
        -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
        -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a)
  mlift ssh2 f (M_Shaped arr) =
    coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
      mlift ssh2 f arr

  -- Binary version of 'mlift'; both arguments are unwrapped.
  mlift2 :: forall sh1 sh2 sh3.
            StaticShX sh3
         -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
         -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) -> Mixed sh3 (Shaped sh a)
  mlift2 ssh3 f (M_Shaped arr1) (M_Shaped arr2) =
    coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $
      mlift2 ssh3 f arr1 arr2

  -- List version of 'mlift'; the whole non-empty list is coerced on the way
  -- in and on the way out.
  mliftL :: forall sh1 sh2.
            StaticShX sh2
         -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b))
         -> NonEmpty (Mixed sh1 (Shaped sh a)) -> NonEmpty (Mixed sh2 (Shaped sh a))
  mliftL ssh2 f l =
    coerce @(NonEmpty (Mixed sh2 (Mixed (MapJust sh) a)))
           @(NonEmpty (Mixed sh2 (Shaped sh a))) $
      mliftL ssh2 f (coerce l)

  -- Unwrap, delegate with the same arguments, re-wrap.
  mcastPartial ssh1 ssh2 psh' (M_Shaped arr) = M_Shaped (mcastPartial ssh1 ssh2 psh' arr)

  -- Unwrap, delegate with the same permutation, re-wrap.
  mtranspose perm (M_Shaped arr) = M_Shaped (mtranspose perm arr)

  -- The inner 'mconcat' is the 'Elt' method at the underlying element type;
  -- the Prelude's 'mconcat' is hidden by this module's Prelude import.
  mconcat l = M_Shaped (mconcat (coerce l))

  -- Forcing is delegated to the underlying array.
  mrnf (M_Shaped arr) = mrnf arr

  -- The shape tree of a 'Shaped' array pairs its own (shaped) shape with the
  -- shape tree of its elements.
  type ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a)

  -- Take the shape tree of the wrapped 'Mixed' array and convert its first
  -- component with 'shCvtXS''; the second component is left as is.
  mshapeTree (Shaped arr) = first shCvtXS' (mshapeTree arr)

  -- Two trees are equal when the shapes are equal and the element trees are
  -- equal according to the element type's own comparison.
  mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2

  -- Reported empty only when this level has size 0 and the element tree is
  -- reported empty as well.
  mshapeTreeEmpty _ (sh, t) = shsSize sh == 0 && mshapeTreeEmpty (Proxy @a) t

  -- Rendered as a parenthesised pair: the shown shape, then the element
  -- tree's rendering.
  mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"

  -- Strides are those of the underlying array.
  marrayStrides (M_Shaped arr) = marrayStrides arr

  -- Write one 'Shaped' value: the value is unwrapped by pattern matching and
  -- the vectors by 'coerce' (through the 'MV_Shaped' newtype), then the
  -- underlying instance performs the write.
  mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s ()
  mvecsWrite sh idx (Shaped arr) vecs =
    mvecsWrite sh idx arr
      (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a))
         vecs)

  -- Same as 'mvecsWrite', but for a whole sub-array; here both the sub-array
  -- and the vectors are unwrapped with 'coerce'.
  mvecsWritePartial :: forall sh1 sh2 s.
                       IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a)
                    -> MixedVecs s (sh1 ++ sh2) (Shaped sh a)
                    -> ST s ()
  mvecsWritePartial sh idx arr vecs =
    mvecsWritePartial sh idx
      (coerce @(Mixed sh2 (Shaped sh a))
              @(Mixed sh2 (Mixed (MapJust sh) a))
         arr)
      (coerce @(MixedVecs s (sh1 ++ sh2) (Shaped sh a))
              @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a))
         vecs)

  -- Freeze the underlying vectors and coerce the resulting array, inside
  -- 'ST', back to an array of 'Shaped' elements.
  mvecsFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a))
  mvecsFreeze sh vecs =
    coerce @(Mixed sh' (Mixed (MapJust sh) a))
           @(Mixed sh' (Shaped sh a))
      <$> mvecsFreeze sh
            (coerce @(MixedVecs s sh' (Shaped sh a))
                    @(MixedVecs s sh' (Mixed (MapJust sh) a))
                    vecs)

-- | 'Shaped' elements with a statically known shape. Each method first
-- matches on the 'Dict' produced by 'lemKnownMapJust' for @sh@, which brings
-- into scope the evidence needed to use the 'KnownElt' instance of the
-- underlying @Mixed (MapJust sh) a@ (presumably @KnownShX (MapJust sh)@ --
-- see "Data.Array.Nested.Internal.Lemmas"), and then delegates to it.
instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where
  memptyArrayUnsafe :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a)
  memptyArrayUnsafe outerSh =
    case lemKnownMapJust (Proxy @sh) of
      Dict ->
        -- Build the empty array at the underlying element type, then relabel
        -- its elements as 'Shaped'.
        coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a))
          (memptyArrayUnsafe outerSh)

  mvecsUnsafeNew outer (Shaped inner) =
    case lemKnownMapJust (Proxy @sh) of
      Dict -> fmap MV_Shaped (mvecsUnsafeNew outer inner)

  mvecsNewEmpty _ =
    case lemKnownMapJust (Proxy @sh) of
      Dict -> fmap MV_Shaped (mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a)))


-- | Turn a unary function on the underlying 'Mixed' arrays into the same
-- function on 'Shaped' arrays. Only the newtype wrapper differs between the
-- two function types, so this is a 'coerce' and does no work of its own.
liftShaped1 :: forall sh a b.
               (Mixed (MapJust sh) a -> Mixed (MapJust sh) b)
            -> Shaped sh a -> Shaped sh b
liftShaped1 =
  coerce @(Mixed (MapJust sh) a -> Mixed (MapJust sh) b)
         @(Shaped sh a -> Shaped sh b)

-- | Turn a binary function on the underlying 'Mixed' arrays into the same
-- function on 'Shaped' arrays. Only the newtype wrapper differs between the
-- two function types, so this is a 'coerce' and does no work of its own.
liftShaped2 :: forall sh a b c.
               (Mixed (MapJust sh) a -> Mixed (MapJust sh) b -> Mixed (MapJust sh) c)
            -> Shaped sh a -> Shaped sh b -> Shaped sh c
liftShaped2 =
  coerce @(Mixed (MapJust sh) a -> Mixed (MapJust sh) b -> Mixed (MapJust sh) c)
         @(Shaped sh a -> Shaped sh b -> Shaped sh c)

-- | Arithmetic on 'Shaped' arrays is the arithmetic of the underlying
-- @Mixed (MapJust sh) a@, reused through 'coerce'.
--
-- 'fromInteger' is deliberately an error: it receives no shape to build an
-- array from, and the message points to the explicit alternative.
instance (NumElt a, PrimElt a) => Num (Shaped sh a) where
  (+) = coerce ((+) @(Mixed (MapJust sh) a))
  (-) = coerce ((-) @(Mixed (MapJust sh) a))
  (*) = coerce ((*) @(Mixed (MapJust sh) a))
  negate = coerce (negate @(Mixed (MapJust sh) a))
  abs = coerce (abs @(Mixed (MapJust sh) a))
  signum = coerce (signum @(Mixed (MapJust sh) a))
  fromInteger = error "Data.Array.Nested.fromInteger: No singletons available, use explicit sreplicateScal"

-- | Division and reciprocal on 'Shaped' arrays are those of the underlying
-- @Mixed (MapJust sh) a@, reused through 'coerce'.
--
-- 'fromRational' is deliberately an error: it receives no shape to build an
-- array from, and the message points to the explicit alternative.
instance (FloatElt a, PrimElt a) => Fractional (Shaped sh a) where
  (/) = coerce ((/) @(Mixed (MapJust sh) a))
  recip = coerce (recip @(Mixed (MapJust sh) a))
  fromRational = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicateScal"

-- | Floating-point functions on 'Shaped' arrays are those of the underlying
-- @Mixed (MapJust sh) a@, reused through 'coerce'.
--
-- 'pi' is deliberately an error: it receives no shape to build an array
-- from, and the message points to the explicit alternative.
instance (FloatElt a, PrimElt a) => Floating (Shaped sh a) where
  pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicateScal"

  -- Exponentials, logarithms and powers.
  exp = coerce (exp @(Mixed (MapJust sh) a))
  log = coerce (log @(Mixed (MapJust sh) a))
  sqrt = coerce (sqrt @(Mixed (MapJust sh) a))
  (**) = coerce ((**) @(Mixed (MapJust sh) a))
  logBase = coerce (logBase @(Mixed (MapJust sh) a))

  -- Trigonometric functions and their inverses.
  sin = coerce (sin @(Mixed (MapJust sh) a))
  cos = coerce (cos @(Mixed (MapJust sh) a))
  tan = coerce (tan @(Mixed (MapJust sh) a))
  asin = coerce (asin @(Mixed (MapJust sh) a))
  acos = coerce (acos @(Mixed (MapJust sh) a))
  atan = coerce (atan @(Mixed (MapJust sh) a))

  -- Hyperbolic functions and their inverses.
  sinh = coerce (sinh @(Mixed (MapJust sh) a))
  cosh = coerce (cosh @(Mixed (MapJust sh) a))
  tanh = coerce (tanh @(Mixed (MapJust sh) a))
  asinh = coerce (asinh @(Mixed (MapJust sh) a))
  acosh = coerce (acosh @(Mixed (MapJust sh) a))
  atanh = coerce (atanh @(Mixed (MapJust sh) a))

  -- The extra 'Floating' methods, which this module imports qualified from
  -- "GHC.Float".
  log1p = coerce (GHC.Float.log1p @(Mixed (MapJust sh) a))
  expm1 = coerce (GHC.Float.expm1 @(Mixed (MapJust sh) a))
  log1pexp = coerce (GHC.Float.log1pexp @(Mixed (MapJust sh) a))
  log1mexp = coerce (GHC.Float.log1mexp @(Mixed (MapJust sh) a))

-- | 'mquotArray' lifted from the underlying 'Mixed' arrays to 'Shaped'
-- arrays of the same shape (presumably elementwise 'quot' -- see
-- 'mquotArray').
squotArray :: forall a sh. (IntElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped sh a
squotArray = liftShaped2 @sh @a @a @a mquotArray

-- | 'mremArray' lifted from the underlying 'Mixed' arrays to 'Shaped' arrays
-- of the same shape (presumably elementwise 'rem' -- see 'mremArray').
sremArray :: forall a sh. (IntElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped sh a
sremArray = liftShaped2 @sh @a @a @a mremArray

-- | 'matan2Array' lifted from the underlying 'Mixed' arrays to 'Shaped'
-- arrays of the same shape (presumably elementwise 'atan2' -- see
-- 'matan2Array').
satan2Array :: forall a sh. (FloatElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped sh a
satan2Array = liftShaped2 @sh @a @a @a matan2Array


-- | The shape of a 'Shaped' array as a value: take the mixed shape of the
-- wrapped 'Mixed' array and convert it to an 'ShS' with 'shCvtXS''.
sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh
sshape (Shaped inner) = shCvtXS' mixedShape
  where
    mixedShape = mshape inner