diff options
Diffstat (limited to 'src/Array.hs')
-rw-r--r-- | src/Array.hs | 69 |
1 files changed, 69 insertions, 0 deletions
diff --git a/src/Array.hs b/src/Array.hs new file mode 100644 index 0000000..9a770c4 --- /dev/null +++ b/src/Array.hs @@ -0,0 +1,69 @@ +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TupleSections #-} +module Array where + +import Control.Monad.Trans.State.Strict +import Data.Foldable (traverse_) +import Data.Vector (Vector) +import qualified Data.Vector as V + +import Data + + +data Shape n where + ShNil :: Shape Z + ShCons :: Shape n -> Int -> Shape (S n) +deriving instance Show (Shape n) + +data Index n where + IxNil :: Index Z + IxCons :: Index n -> Int -> Index (S n) +deriving instance Show (Index n) + +shapeSize :: Shape n -> Int +shapeSize ShNil = 0 +shapeSize (ShCons sh n) = shapeSize sh * n + +fromLinearIndex :: Shape n -> Int -> Index n +fromLinearIndex ShNil 0 = IxNil +fromLinearIndex ShNil _ = error "Index out of range" +fromLinearIndex (sh `ShCons` n) i = + let (q, r) = i `quotRem` n + in fromLinearIndex sh q `IxCons` r + +toLinearIndex :: Shape n -> Index n -> Int +toLinearIndex ShNil IxNil = 0 +toLinearIndex (sh `ShCons` n) (idx `IxCons` i) = toLinearIndex sh idx * n + i + + +-- | TODO: this Vector is a boxed vector, which is horrendously inefficient. +data Array (n :: Nat) t = Array (Shape n) (Vector t) + deriving (Show) + +arrayShape :: Array n t -> Shape n +arrayShape (Array sh _) = sh + +arraySize :: Array n t -> Int +arraySize (Array sh _) = shapeSize sh + +arrayIndex :: Array n t -> Index n -> t +arrayIndex arr@(Array sh _) idx = arrayIndexLinear arr (toLinearIndex sh idx) + +arrayIndexLinear :: Array n t -> Int -> t +arrayIndexLinear (Array _ v) i = v V.! i + +arrayIndex1 :: Array (S n) t -> Int -> Array n t +arrayIndex1 (Array (sh `ShCons` _) v) i = let sz = shapeSize sh in Array sh (V.slice (sz * i) sz v) + +arrayGenerateM :: Monad m => Shape n -> (Index n -> m t) -> m (Array n t) +arrayGenerateM sh f = arrayGenerateLinM sh (f . fromLinearIndex sh) + +arrayGenerateLinM :: Monad m => Shape n -> (Int -> m t) -> m (Array n t) +arrayGenerateLinM sh f = Array sh <$> V.generateM (shapeSize sh) f + +-- | The Int is the linear index of the value. +traverseArray_ :: Monad m => (Int -> t -> m ()) -> Array n t -> m () +traverseArray_ f (Array _ v) = evalStateT (traverse_ (\x -> StateT (\i -> (,i+1) <$> f i x)) v) 0 |