aboutsummaryrefslogtreecommitdiff
path: root/src/Numeric/ADDual/VectorOps.hs
blob: 9bedebed8a81e28c255fddff40b70a8cd89d76f8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
module Numeric.ADDual.VectorOps where

import Data.Kind (Type)
import qualified Data.Vector as V
import qualified Data.Vector.Strict as VSr
import qualified Data.Vector.Storable as VS
import qualified Data.Vector.Unboxed as VU
import Foreign.Storable (Storable)


-- | Construction, inspection and mask-driven selection for a vector-like
-- container @v@ whose element type is given by 'VectorOpsScalar'.
class VectorOps v where
  -- | Element type stored in @v@.
  type VectorOpsScalar v :: Type
  -- | Build a vector from a list, given an element count.
  -- NOTE(review): the instances in this module forward to the @fromListN@ of
  -- the underlying vector type; confirm the intended behaviour when the list
  -- length differs from the count.
  vfromListN :: Int -> [VectorOpsScalar v] -> v
  -- | Build a vector from a list.
  vfromList :: [VectorOpsScalar v] -> v
  -- | List the elements of a vector.
  vtoList :: v -> [VectorOpsScalar v]
  -- | Number of elements in the vector.
  vlength :: v -> Int
  -- | Vector of the given length in which every element is the given scalar.
  vreplicate :: Int -> VectorOpsScalar v -> v
  -- | Choose between two vectors according to a Boolean mask: True selects
  -- the first argument, False the second.
  vselect :: VS.Vector Bool -> v -> v -> v

-- | Arithmetic on vectors whose scalar type is a 'Num'.
class (VectorOps v, Num (VectorOpsScalar v)) => VectorOpsNum v where
  -- | Sum of two vectors (element-wise in this module's instances).
  vadd :: v -> v -> v
  -- | Difference of two vectors (element-wise in this module's instances).
  vsub :: v -> v -> v
  -- | Product of two vectors (element-wise in this module's instances).
  vmul :: v -> v -> v
  -- | Reduce a vector to the sum of its elements.
  vsum :: v -> VectorOpsScalar v

-- | Transcendental operations on vectors whose scalar type is 'Floating'.
class (VectorOpsNum v, Floating (VectorOpsScalar v)) => VectorOpsFloating v where
  -- | Exponential of a vector; this module's instances apply 'exp' to every
  -- element.
  vexp :: v -> v

-- | Ordering comparisons and maximum for vectors whose scalar type is 'Ord'.
-- Comparisons return a Boolean mask of the kind accepted by 'vselect'.
--
-- NOTE(review): both operands of a comparison are presumably expected to have
-- the same length; confirm against callers.
class (VectorOps v, Ord (VectorOpsScalar v)) => VectorOpsOrd v where
  -- | Element-wise @<=@.
  vcmpLE :: v -> v -> VS.Vector Bool
  -- | Largest element of the vector.
  vmaximum :: v -> VectorOpsScalar v

  -- | Element-wise @<@, @>@ and @>=@. The defaults are derived from 'vcmpLE'
  -- and may be overridden by an instance.
  vcmpLT, vcmpGT, vcmpGE :: v -> v -> VS.Vector Bool
  -- @a < b@ is computed as @a <= b && not (b <= a)@ instead of the shorter
  -- @not (b <= a)@. Both agree for total orders, but plain negation reports
  -- True for unordered pairs (for example a NaN operand of a floating-point
  -- scalar), where @<@ itself is False. The price is a second 'vcmpLE' call.
  vcmpLT a b = VS.zipWith (\le ge -> le && not ge) (vcmpLE a b) (vcmpLE b a)
  -- @a > b@ is @b < a@; going through 'vcmpLT' also picks up any instance
  -- override of it.
  vcmpGT a b = vcmpLT b a
  -- @a >= b@ is @b <= a@.
  vcmpGE a b = vcmpLE b a

-- | 'VectorOps' for boxed vectors; every method forwards to "Data.Vector".
instance VectorOps (V.Vector a) where
  type VectorOpsScalar (V.Vector a) = a
  vfromListN = V.fromListN
  vfromList = V.fromList
  vtoList = V.toList
  vlength = V.length
  vreplicate = V.replicate
  -- The result has as many elements as the mask. It is built directly with
  -- 'V.generate' (no intermediate list), matching how the comparison masks
  -- in this module are produced.
  vselect bs a b = V.generate (VS.length bs) pick
    where
      -- Element i comes from the first vector where the mask is True,
      -- otherwise from the second.
      pick i
        | bs VS.! i = a V.! i
        | otherwise = b V.! i

-- | Element-wise arithmetic on boxed vectors.
instance Num a => VectorOpsNum (V.Vector a) where
  -- The binary operations pair elements up with 'V.zipWith'.
  vadd lhs rhs = V.zipWith (+) lhs rhs
  vsub lhs rhs = V.zipWith (-) lhs rhs
  vmul lhs rhs = V.zipWith (*) lhs rhs
  vsum xs = V.sum xs

-- | Floating-point operations on boxed vectors.
instance Floating a => VectorOpsFloating (V.Vector a) where
  -- Applies 'exp' to every element.
  vexp xs = V.map exp xs

-- | Comparisons and maximum on boxed vectors.
instance Ord a => VectorOpsOrd (V.Vector a) where
  -- The mask has the length of the first operand; each entry compares the
  -- elements at the same index.
  vcmpLE xs ys = VS.generate (V.length xs) leAt
    where
      leAt ix = (xs V.! ix) <= (ys V.! ix)
  -- Delegates to 'V.maximum'.
  -- NOTE(review): 'V.maximum' is documented as requiring a non-empty vector.
  vmaximum xs = V.maximum xs

-- | 'VectorOps' for strict boxed vectors; every method forwards to
-- "Data.Vector.Strict".
instance VectorOps (VSr.Vector a) where
  type VectorOpsScalar (VSr.Vector a) = a
  vfromListN = VSr.fromListN
  vfromList = VSr.fromList
  vtoList = VSr.toList
  vlength = VSr.length
  vreplicate = VSr.replicate
  -- The result has as many elements as the mask. It is built directly with
  -- 'VSr.generate' (no intermediate list), matching how the comparison masks
  -- in this module are produced.
  vselect bs a b = VSr.generate (VS.length bs) pick
    where
      -- Element i comes from the first vector where the mask is True,
      -- otherwise from the second.
      pick i
        | bs VS.! i = a VSr.! i
        | otherwise = b VSr.! i

-- | Element-wise arithmetic on strict boxed vectors.
instance Num a => VectorOpsNum (VSr.Vector a) where
  -- The binary operations pair elements up with 'VSr.zipWith'.
  vadd lhs rhs = VSr.zipWith (+) lhs rhs
  vsub lhs rhs = VSr.zipWith (-) lhs rhs
  vmul lhs rhs = VSr.zipWith (*) lhs rhs
  vsum xs = VSr.sum xs

-- | Floating-point operations on strict boxed vectors.
instance Floating a => VectorOpsFloating (VSr.Vector a) where
  -- Applies 'exp' to every element.
  vexp xs = VSr.map exp xs

-- | Comparisons and maximum on strict boxed vectors.
instance Ord a => VectorOpsOrd (VSr.Vector a) where
  -- The mask has the length of the first operand; each entry compares the
  -- elements at the same index.
  vcmpLE xs ys = VS.generate (VSr.length xs) leAt
    where
      leAt ix = (xs VSr.! ix) <= (ys VSr.! ix)
  -- Delegates to 'VSr.maximum'.
  -- NOTE(review): presumably requires a non-empty vector, like the other
  -- vector flavours; confirm.
  vmaximum xs = VSr.maximum xs

-- | 'VectorOps' for storable vectors; every method forwards to
-- "Data.Vector.Storable".
instance Storable a => VectorOps (VS.Vector a) where
  type VectorOpsScalar (VS.Vector a) = a
  vfromListN = VS.fromListN
  vfromList = VS.fromList
  vtoList = VS.toList
  vlength = VS.length
  vreplicate = VS.replicate
  -- The result has as many elements as the mask. It is built directly with
  -- 'VS.generate' (no intermediate list), matching how the comparison masks
  -- in this module are produced.
  vselect bs a b = VS.generate (VS.length bs) pick
    where
      -- Element i comes from the first vector where the mask is True,
      -- otherwise from the second.
      pick i
        | bs VS.! i = a VS.! i
        | otherwise = b VS.! i

-- | Element-wise arithmetic on storable vectors.
instance (Storable a, Num a) => VectorOpsNum (VS.Vector a) where
  -- The binary operations pair elements up with 'VS.zipWith'.
  vadd lhs rhs = VS.zipWith (+) lhs rhs
  vsub lhs rhs = VS.zipWith (-) lhs rhs
  vmul lhs rhs = VS.zipWith (*) lhs rhs
  vsum xs = VS.sum xs

-- | Floating-point operations on storable vectors.
instance (Storable a, Floating a) => VectorOpsFloating (VS.Vector a) where
  -- Applies 'exp' to every element.
  vexp xs = VS.map exp xs

-- | Comparisons and maximum on storable vectors.
instance (Storable a, Ord a) => VectorOpsOrd (VS.Vector a) where
  -- The mask has the length of the first operand; each entry compares the
  -- elements at the same index.
  vcmpLE xs ys = VS.generate (VS.length xs) leAt
    where
      leAt ix = (xs VS.! ix) <= (ys VS.! ix)
  -- Delegates to 'VS.maximum'.
  -- NOTE(review): 'VS.maximum' is documented as requiring a non-empty vector.
  vmaximum xs = VS.maximum xs

-- | 'VectorOps' for unboxed vectors; every method forwards to
-- "Data.Vector.Unboxed".
instance VU.Unbox a => VectorOps (VU.Vector a) where
  type VectorOpsScalar (VU.Vector a) = a
  vfromListN = VU.fromListN
  vfromList = VU.fromList
  vtoList = VU.toList
  vlength = VU.length
  vreplicate = VU.replicate
  -- The result has as many elements as the mask. It is built directly with
  -- 'VU.generate' (no intermediate list), matching how the comparison masks
  -- in this module are produced.
  vselect bs a b = VU.generate (VS.length bs) pick
    where
      -- Element i comes from the first vector where the mask is True,
      -- otherwise from the second.
      pick i
        | bs VS.! i = a VU.! i
        | otherwise = b VU.! i

-- | Element-wise arithmetic on unboxed vectors.
instance (VU.Unbox a, Num a) => VectorOpsNum (VU.Vector a) where
  -- The binary operations pair elements up with 'VU.zipWith'.
  vadd lhs rhs = VU.zipWith (+) lhs rhs
  vsub lhs rhs = VU.zipWith (-) lhs rhs
  vmul lhs rhs = VU.zipWith (*) lhs rhs
  vsum xs = VU.sum xs

-- | Floating-point operations on unboxed vectors.
instance (VU.Unbox a, Floating a) => VectorOpsFloating (VU.Vector a) where
  -- Applies 'exp' to every element.
  vexp xs = VU.map exp xs

-- | Comparisons and maximum on unboxed vectors.
instance (VU.Unbox a, Ord a) => VectorOpsOrd (VU.Vector a) where
  -- The mask has the length of the first operand; each entry compares the
  -- elements at the same index.
  vcmpLE xs ys = VS.generate (VU.length xs) leAt
    where
      leAt ix = (xs VU.! ix) <= (ys VU.! ix)
  -- Delegates to 'VU.maximum'.
  -- NOTE(review): 'VU.maximum' is documented as requiring a non-empty vector.
  vmaximum xs = VU.maximum xs