blob: ef9bb8d39b9e28fe05cc64a66c0c78a53e6349ba (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
|
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TupleSections #-}
module Array where
import Control.DeepSeq
import Control.Monad.Trans.State.Strict
import Data.Foldable (traverse_)
import Data.Vector (Vector)
import qualified Data.Vector as V
import GHC.Generics (Generic)
import Data
-- | The extents of a rank-@n@ array, stored as a snoc-list.
--
-- 'ShNil' is the rank-0 shape. @ShCons sh k@ extends @sh@ with one more
-- dimension of extent @k@.
data Shape n where
  ShNil :: Shape Z
  ShCons :: Shape n -> Int -> Shape (S n)

deriving instance Show (Shape n)

deriving instance Eq (Shape n)

-- | Fully evaluates every extent in the shape.
instance NFData (Shape n) where
  rnf shape = case shape of
    ShNil -> ()
    rest `ShCons` extent -> rnf extent `seq` rnf rest
-- | A position inside a rank-@n@ array.
--
-- It has one 'Int' component per dimension, in the same snoc order as
-- 'Shape': the component added by the outermost 'IxCons' pairs with the
-- extent added by the outermost 'ShCons'.
data Index n where
  IxNil :: Index Z
  IxCons :: Index n -> Int -> Index (S n)

deriving instance Show (Index n)

deriving instance Eq (Index n)

-- | Fully evaluates every component of the index.
instance NFData (Index n) where
  rnf index = case index of
    IxNil -> ()
    rest `IxCons` component -> rnf component `seq` rnf rest
-- | Number of elements described by a shape: the product of its extents.
-- The rank-0 shape has exactly one element.
shapeSize :: Shape n -> Int
shapeSize shape = case shape of
  ShNil -> 1
  rest `ShCons` extent -> shapeSize rest * extent
-- | Decode a linear position into a multi-dimensional index of the given
-- shape. For positions in @[0, shapeSize sh)@ this is the inverse of
-- 'toLinearIndex'. The component paired with the outermost 'ShCons' varies
-- fastest.
--
-- Positions outside @[0, shapeSize sh)@ are an error (@"Index out of range"@).
fromLinearIndex :: Shape n -> Int -> Index n
-- Reject negative positions up front. 'quotRem' truncates toward zero, so
-- e.g. -1 would otherwise decode to an index with a negative component
-- instead of failing.
fromLinearIndex _ i
  | i < 0 = error "Index out of range"
fromLinearIndex ShNil 0 = IxNil
fromLinearIndex ShNil _ = error "Index out of range"
fromLinearIndex (sh `ShCons` n) i
  -- A zero extent means no position is valid. Fail with the same error
  -- rather than crashing on a division by zero.
  | n == 0 = error "Index out of range"
  | otherwise =
      let (q, r) = i `quotRem` n
       in fromLinearIndex sh q `IxCons` r
-- | Encode a multi-dimensional index as a linear position (Horner's scheme).
-- The component paired with the outermost 'ShCons' has stride 1, and each
-- inner dimension's stride is the product of the extents added after it.
-- No bounds checking is done here.
toLinearIndex :: Shape n -> Index n -> Int
toLinearIndex ShNil IxNil = 0
toLinearIndex (shRest `ShCons` extent) (ixRest `IxCons` component) =
  component + extent * toLinearIndex shRest ixRest
-- | The shape of the given rank whose every extent is 0.
emptyShape :: SNat n -> Shape n
emptyShape sn = case sn of
  SZ -> ShNil
  SS m -> ShCons (emptyShape m) 0
-- | Every index of a shape, listed in linear order (position 0 first).
enumShape :: Shape n -> [Index n]
enumShape sh = [fromLinearIndex sh pos | pos <- [0 .. shapeSize sh - 1]]
-- | A rank-@n@ array: its 'Shape' plus a flat vector of elements. 'arrayIndex'
-- maps an 'Index' to a vector position via 'toLinearIndex'.
--
-- TODO: this Vector is a boxed vector, which is horrendously inefficient.
data Array (n :: Nat) t = Array (Shape n) (Vector t)
  deriving (Show, Functor, Foldable, Traversable, Generic)

-- | Forces the shape, then every element of the vector.
instance NFData t => NFData (Array n t) where
  rnf (Array sh v) = rnf sh `seq` rnf v
-- | The shape an array was built with.
arrayShape :: Array n t -> Shape n
arrayShape arr = case arr of
  Array sh _ -> sh
-- | Element count according to the array's shape. This is computed from the
-- shape, not from the backing vector's length.
arraySize :: Array n t -> Int
arraySize = shapeSize . arrayShape
-- | An array of the given rank with every extent 0 and an empty vector.
--
-- NOTE(review): for rank 0 ('SZ') the shape is 'ShNil', whose 'shapeSize' is
-- 1, yet the vector is still empty. Such an array claims one element that
-- it does not have, and indexing it fails.
emptyArray :: SNat n -> Array n t
emptyArray sn = Array shape V.empty
  where
    shape = emptyShape sn
-- | Build an array of the given shape from a list of elements. List element
-- @j@ becomes the element at linear position @j@.
--
-- Only the first @shapeSize sh@ elements are used, so a longer (even
-- infinite) list is accepted. A list that is too short is an error.
-- Otherwise the result would have a vector shorter than its shape claims,
-- and the failure would surface later as an unrelated indexing error.
-- The check runs when the vector is demanded, so 'arrayShape' of the result
-- remains available without consuming the list.
arrayFromList :: Shape n -> [t] -> Array n t
arrayFromList sh l = Array sh checked
  where
    size = shapeSize sh
    vec = V.fromListN size l
    checked
      | V.length vec < size = error "arrayFromList: list too short for shape"
      | otherwise = vec
-- | A rank-0 array holding exactly the given element.
arrayUnit :: t -> Array Z t
arrayUnit = Array ShNil . V.singleton
-- | Look up the element at a multi-dimensional index.
--
-- Every component is checked against the matching extent of the array's
-- shape, and an out-of-range component is an error. Without this check,
-- the linearisation done by 'toLinearIndex' could map an invalid index onto
-- some other valid element. For example, component 5 in a dimension of
-- extent 3 spills into the neighbouring dimension.
arrayIndex :: Array n t -> Index n -> t
arrayIndex arr@(Array sh _) idx
  | inBounds sh idx = arrayIndexLinear arr (toLinearIndex sh idx)
  | otherwise = error "arrayIndex: index out of range"
  where
    -- True when every index component lies in [0, extent).
    inBounds :: Shape m -> Index m -> Bool
    inBounds ShNil IxNil = True
    inBounds (sh' `ShCons` n) (idx' `IxCons` i) =
      0 <= i && i < n && inBounds sh' idx'
-- | Look up an element by its linear position in the backing vector.
-- Out-of-range positions fail through the vector's own bounds check.
arrayIndexLinear :: Array n t -> Int -> t
arrayIndexLinear arr i = case arr of
  Array _ vec -> vec V.! i
-- | Fix the index component of the last-added dimension (the extent in the
-- outermost 'ShCons') to @i@, giving an array of one rank lower. For every
-- in-range @idx@:
--
-- > arrayIndex (arrayIndex1 arr i) idx == arrayIndex arr (idx `IxCons` i)
--
-- 'toLinearIndex' gives that last-added dimension stride 1, so the selected
-- elements sit @n@ apart in the backing vector and must be gathered. A
-- contiguous slice of length @shapeSize sh@ at offset @shapeSize sh * i@
-- only matches that layout in degenerate cases such as rank 1.
--
-- @i@ must lie in @[0, n)@, otherwise this is an error.
arrayIndex1 :: Array (S n) t -> Int -> Array n t
arrayIndex1 (Array (sh `ShCons` n) v) i
  | i < 0 || i >= n = error "arrayIndex1: index out of range"
  | otherwise =
      -- Position j of the result has linear index j in 'sh'. In the parent
      -- that is toLinearIndex sh idx * n + i == j * n + i.
      Array sh (V.generate (shapeSize sh) (\j -> v V.! (j * n + i)))
-- | Build an array by applying a function to every multi-dimensional index
-- of the shape.
arrayGenerate :: Shape n -> (Index n -> t) -> Array n t
arrayGenerate sh f = arrayGenerateLin sh atLinear
  where
    atLinear pos = f (fromLinearIndex sh pos)
-- | Build an array by applying a function to every linear position
-- @0 .. shapeSize sh - 1@.
arrayGenerateLin :: Shape n -> (Int -> t) -> Array n t
arrayGenerateLin sh f = Array sh elems
  where
    elems = V.generate (shapeSize sh) f
-- | Monadic 'arrayGenerate'. Effects run in linear order of the indices.
arrayGenerateM :: Monad m => Shape n -> (Index n -> m t) -> m (Array n t)
arrayGenerateM sh f = arrayGenerateLinM sh atLinear
  where
    atLinear pos = f (fromLinearIndex sh pos)
-- | Monadic 'arrayGenerateLin'. The function is run for each linear position
-- in ascending order.
arrayGenerateLinM :: Monad m => Shape n -> (Int -> m t) -> m (Array n t)
arrayGenerateLinM sh f = fmap (Array sh) (V.generateM (shapeSize sh) f)
-- | Apply a function to every element, keeping the shape. The result has
-- 'shapeSize' elements, each read back by linear position from the input.
arrayMap :: (a -> b) -> Array n a -> Array n b
arrayMap f arr = arrayGenerateLin shape elemAt
  where
    shape = arrayShape arr
    elemAt pos = f (arrayIndexLinear arr pos)
-- | Monadic 'arrayMap'. Effects run in ascending linear order.
arrayMapM :: Monad m => (a -> m b) -> Array n a -> m (Array n b)
arrayMapM f arr = arrayGenerateLinM shape elemAtM
  where
    shape = arrayShape arr
    elemAtM pos = f (arrayIndexLinear arr pos)
-- | Run an action on every element of the backing vector, in order,
-- discarding the results. The Int is the linear index of the value.
--
-- 'V.imapM_' supplies the index directly. A 'StateT' counter built from lazy
-- @i+1@ updates would otherwise retain a chain of unevaluated additions,
-- one per element, whenever the action does not force its index argument.
traverseArray_ :: Monad m => (Int -> t -> m ()) -> Array n t -> m ()
traverseArray_ f (Array _ v) = V.imapM_ f v
|