{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE TupleSections #-}

-- | Rank-indexed multi-dimensional arrays used by the interpreter. The rank
-- of a 'Shape', 'Index' or 'Array' is tracked at the type level.
module Interpreter.Array where

import Control.Monad.Trans.State.Strict
import Data.Foldable (traverse_)
import Data.Vector (Vector)
import qualified Data.Vector as V

import Data

-- | The extents of an array of rank @n@. 'ShNil' is the rank-0 shape;
-- @'ShCons' sh k@ extends @sh@ with one more dimension of extent @k@.
data Shape n where
  ShNil :: Shape Z
  ShCons :: Shape n -> Int -> Shape (S n)

-- | A multi-dimensional index of rank @n@, built the same way as 'Shape':
-- 'IxNil' is the rank-0 index and 'IxCons' appends one more component.
data Index n where
  IxNil :: Index Z
  IxCons :: Index n -> Int -> Index (S n)

-- | Total number of elements described by a shape, i.e. the product of all
-- of its extents. A rank-0 shape (a scalar) describes exactly one element.
shapeSize :: Shape n -> Int
-- The empty product is 1. Returning 0 here would make every shape, of any
-- rank, report size 0, since the recursive case multiplies through it.
shapeSize ShNil = 1
shapeSize (ShCons sh n) = shapeSize sh * n

-- | An array of rank @n@ with elements of type @t@: a shape plus the
-- elements stored in a flat vector addressed by linear index.
--
-- NOTE(review): the constructor does not check that the vector's length
-- equals @'shapeSize'@ of the shape; callers are expected to maintain that.
--
-- TODO: this Vector is a boxed vector, which is horrendously inefficient.
data Array (n :: Nat) t = Array (Shape n) (Vector t)

-- | The shape of an array.
arrayShape :: Array n t -> Shape n
arrayShape (Array sh _) = sh

-- | Number of elements in an array, as given by its shape.
arraySize :: Array n t -> Int
arraySize (Array sh _) = shapeSize sh

-- | Look up an element by its linear index into the underlying vector.
-- Partial: 'V.!' raises an error if the index is out of bounds.
arrayIndexLinear :: Array n t -> Int -> t
arrayIndexLinear (Array _ v) i = v V.! i

-- | Build an array of the given shape by running the action on every linear
-- index from 0 up to (but excluding) @'shapeSize' sh@.
arrayGenerateLinM :: Monad m => Shape n -> (Int -> m t) -> m (Array n t)
arrayGenerateLinM sh f = Array sh <$> V.generateM (shapeSize sh) f

-- | Run an action on every element in storage order. The Int passed to the
-- action is the linear index of the value, counting from 0.
traverseArray_ :: Monad m => (Int -> t -> m ()) -> Array n t -> m ()
traverseArray_ f (Array _ v) = evalStateT (traverse_ step v) 0
  where
    -- The state is the linear index of the current element. The next index
    -- is forced before it is stored so that a long traversal does not build
    -- up a chain of unevaluated (+1) thunks when 'f' ignores the index.
    step x = StateT $ \i ->
      let i' = i + 1
      in i' `seq` ((, i') <$> f i x)