blob: f358225e29df2635251b076e5d77eb9c28de214c (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
|
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TupleSections #-}
module Interpreter.Array where
import Control.Monad.Trans.State.Strict
import Data.Foldable (traverse_)
import Data.Vector (Vector)
import qualified Data.Vector as V
import Data
-- | The shape of an array whose rank @n@ is tracked at the type level.
--
-- A shape is built up one dimension at a time: 'ShNil' is the rank-0 shape,
-- and 'ShCons' extends a rank-@n@ shape with one additional 'Int' extent,
-- yielding a shape of rank @S n@.
data Shape n where
  ShNil :: Shape Z
  ShCons :: Shape n -> Int -> Shape (S n)
-- | A multi-dimensional index whose rank @n@ is tracked at the type level.
--
-- The structure mirrors 'Shape': 'IxNil' is the rank-0 index, and 'IxCons'
-- extends a rank-@n@ index with one additional 'Int' component.
--
-- NOTE(review): presumably each component is the coordinate along the
-- corresponding dimension of a 'Shape' of the same rank -- confirm at use
-- sites.
data Index n where
  IxNil :: Index Z
  IxCons :: Index n -> Int -> Index (S n)
-- | Total number of elements described by a shape: the product of all of its
-- dimension extents.
--
-- The rank-0 shape 'ShNil' has size 1 (the empty product), so a rank-0 array
-- holds exactly one element. Using 0 here would make the size of /every/
-- shape 0, because each 'ShCons' multiplies onto this base case.
shapeSize :: Shape n -> Int
shapeSize ShNil = 1
shapeSize (ShCons sh n) = shapeSize sh * n
-- | An array of rank @n@ with elements of type @t@: a 'Shape' paired with a
-- flat 'Vector' holding the elements.
--
-- NOTE(review): the constructor itself does not check that the length of the
-- vector equals 'shapeSize' of the shape; presumably callers are expected to
-- maintain that -- confirm.
--
-- TODO: this Vector is a boxed vector, which is horrendously inefficient.
data Array (n :: Nat) t = Array (Shape n) (Vector t)
-- | Project out the shape component of an array.
arrayShape :: Array n t -> Shape n
arrayShape arr = case arr of
  Array dims _ -> dims
-- | Number of elements of an array, as determined by its shape.
--
-- This is 'shapeSize' applied to the stored shape; the element vector itself
-- is not inspected.
arraySize :: Array n t -> Int
arraySize arr = case arr of
  Array dims _ -> shapeSize dims
-- | Build an array of the given shape by running a monadic generator once
-- for each linear index.
--
-- The generator is handed to 'V.generateM' together with 'shapeSize' of the
-- shape, and the resulting vector is wrapped up with that same shape.
arrayGenerateLinM :: Monad m => Shape n -> (Int -> m t) -> m (Array n t)
arrayGenerateLinM dims gen = do
  elems <- V.generateM (shapeSize dims) gen
  pure (Array dims elems)
-- | Run an action on every element of the array, in order of increasing
-- linear index, discarding the results.
--
-- The 'Int' passed to the action is the linear index of the value, i.e. its
-- position in the array's underlying vector, counting from 0. The array's
-- shape is not consulted.
traverseArray_ :: Monad m => (Int -> t -> m ()) -> Array n t -> m ()
traverseArray_ f (Array _ v) = V.imapM_ f v
|