summaryrefslogtreecommitdiff
path: root/src/Array.hs
blob: 9a770c4327f147718862a4b226f5b0f94f6c2ac0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TupleSections #-}
module Array where

import Control.Monad.Trans.State.Strict
import Data.Foldable (traverse_)
import Data.Vector (Vector)
import qualified Data.Vector as V

import Data


-- | The extent of an @n@-dimensional array, built like a snoc-list: each
-- 'ShCons' appends one more dimension (with its size) to the shape on its
-- left. The type index counts the number of dimensions.
data Shape (n :: Nat) where
  -- | The rank-0 shape: no dimensions at all.
  ShNil :: Shape Z
  -- | Extend a shape with one further dimension of the given size.
  ShCons
    :: Shape n      -- the dimensions added so far
    -> Int          -- size of the newly added dimension
    -> Shape (S n)
deriving instance Show (Shape n)

-- | A position inside an @n@-dimensional array, with one 'Int' component per
-- dimension. Its structure mirrors 'Shape': each 'IxCons' appends the
-- component for one more dimension.
data Index (n :: Nat) where
  -- | The (only) position in a rank-0 array: no components.
  IxNil :: Index Z
  -- | Extend an index with the component for one further dimension.
  IxCons
    :: Index n      -- components for the dimensions added so far
    -> Int          -- component for the newly added dimension
    -> Index (S n)
deriving instance Show (Index n)

-- | Total number of elements described by a shape: the product of all of its
-- dimension sizes.
--
-- The rank-0 shape 'ShNil' describes a single scalar, so its size is 1 (the
-- identity of the product). Returning 0 here would make every shape's size 0,
-- since 0 absorbs the whole product.
shapeSize :: Shape n -> Int
shapeSize ShNil = 1
shapeSize (ShCons sh n) = shapeSize sh * n

-- | Convert a linear position in the underlying vector into a
-- multi-dimensional index. This is the inverse of 'toLinearIndex': the
-- dimension added last (outermost 'ShCons') varies fastest, so for a shape
-- @ShNil `ShCons` a `ShCons` b@ a position @p < a * b@ maps to the index
-- @(p `div` b, p `mod` b)@.
--
-- A position that does not address an element of the shape (too large, or
-- negative) yields an index containing @error "Index out of range"@. Because
-- 'Index' fields are lazy, that error surfaces only when the affected
-- component is evaluated. A zero-sized dimension makes 'divMod' throw a
-- divide-by-zero exception instead.
fromLinearIndex :: Shape n -> Int -> Index n
fromLinearIndex ShNil 0 = IxNil
fromLinearIndex ShNil _ = error "Index out of range"
fromLinearIndex (sh `ShCons` n) i =
  -- 'divMod' (not 'quotRem') rounds towards negative infinity, so a negative
  -- position leaves a negative quotient. That quotient travels down to the
  -- 'ShNil' case and is rejected there, instead of silently turning into a
  -- negative component. For non-negative positions both give the same result.
  let (q, r) = i `divMod` n
  in fromLinearIndex sh q `IxCons` r

-- | Flatten a multi-dimensional index into a position in the underlying
-- vector. The component added last (outermost 'IxCons') varies fastest: for a
-- shape @ShNil `ShCons` a `ShCons` b@, the index @(x, y)@ maps to @x * b + y@.
-- Components are not bounds-checked against the shape.
toLinearIndex :: Shape n -> Index n -> Int
toLinearIndex shape ix = case (shape, ix) of
  (ShNil, IxNil) -> 0
  (ShCons outer dim, IxCons outerIx component) ->
    component + dim * toLinearIndex outer outerIx


-- | An @n@-dimensional array with elements of type @t@, stored as a flat
-- vector that 'arrayIndex' addresses through 'toLinearIndex'. The vector is
-- presumably of length @shapeSize@ of the shape; the constructor itself does
-- not check this.
--
-- TODO: this Vector is a boxed vector, which is horrendously inefficient.
data Array (n :: Nat) t
  = Array
      (Shape n)   -- ^ extent of every dimension
      (Vector t)  -- ^ the elements, addressed by linear position
  deriving Show

-- | The shape stored in the array.
arrayShape :: Array n t -> Shape n
arrayShape arr = case arr of
  Array shape _ -> shape

-- | Number of elements of the array, computed from its shape with
-- 'shapeSize' (not from the length of the underlying vector).
arraySize :: Array n t -> Int
arraySize = shapeSize . arrayShape

-- | Look up the element at a multi-dimensional index. The index is flattened
-- with 'toLinearIndex', so only the resulting linear position is
-- bounds-checked (by 'arrayIndexLinear'), not each individual component.
arrayIndex :: Array n t -> Index n -> t
arrayIndex arr ix = arrayIndexLinear arr linear
  where
    linear = toLinearIndex (arrayShape arr) ix

-- | Read the element at a linear position of the underlying vector.
-- Positions outside the vector are rejected by the bounds check of 'V.!'.
arrayIndexLinear :: Array n t -> Int -> t
arrayIndexLinear (Array _ elems) pos = (V.!) elems pos

-- | Fix the component of the last-added dimension (the 'Int' stored in the
-- outermost 'ShCons') to @i@ and return the remaining @n@-dimensional
-- sub-array, so that
-- @arrayIndex (arrayIndex1 arr i) ix == arrayIndex arr (ix `IxCons` i)@.
--
-- 'toLinearIndex' makes that dimension vary fastest. Its elements for a
-- fixed @i@ therefore sit at every @k@-th position starting at @i@, not in
-- one contiguous run, so they are gathered into a fresh vector rather than
-- sliced.
--
-- Calls 'error' if @i@ is outside @[0, k)@, where @k@ is the size of that
-- dimension.
arrayIndex1 :: Array (S n) t -> Int -> Array n t
arrayIndex1 (Array (sh `ShCons` k) v) i
  | i < 0 || i >= k = error "arrayIndex1: index out of range"
  | otherwise = Array sh (V.generate (shapeSize sh) pick)
  where
    -- Linear position @j@ within @sh@ corresponds to position @j * k + i@ in
    -- the full array (see 'toLinearIndex').
    pick j = v V.! (j * k + i)

-- | Build an array of the given shape by running a monadic action for every
-- multi-dimensional index. Each linear position is converted back to an
-- index with 'fromLinearIndex' before the action is called.
arrayGenerateM :: Monad m => Shape n -> (Index n -> m t) -> m (Array n t)
arrayGenerateM sh f = arrayGenerateLinM sh (\lin -> f (fromLinearIndex sh lin))

-- | Build an array of the given shape by running a monadic action for every
-- linear position @0 .. shapeSize sh - 1@ (via 'V.generateM').
arrayGenerateLinM :: Monad m => Shape n -> (Int -> m t) -> m (Array n t)
arrayGenerateLinM sh f = fmap (Array sh) (V.generateM count f)
  where
    count = shapeSize sh

-- | Run an action on every element of the array, in vector order, discarding
-- the results. The 'Int' passed to the action is the linear index of the
-- value.
--
-- 'V.imapM_' supplies the index directly. The hand-threaded @StateT@ counter
-- it replaces stored an unevaluated @i+1@ chain whenever the action ignored
-- its index argument.
traverseArray_ :: Monad m => (Int -> t -> m ()) -> Array n t -> m ()
traverseArray_ f (Array _ v) = V.imapM_ f v