diff options
Diffstat (limited to 'src/Array.hs')
-rw-r--r-- | src/Array.hs | 9 |
1 files changed, 9 insertions, 0 deletions
diff --git a/src/Array.hs b/src/Array.hs index d7dadbf..6473bf0 100644 --- a/src/Array.hs +++ b/src/Array.hs @@ -59,6 +59,9 @@ arraySize (Array sh _) = shapeSize sh emptyArray :: SNat n -> Array n t emptyArray n = Array (emptyShape n) V.empty +arrayUnit :: t -> Array Z t +arrayUnit x = Array ShNil (V.singleton x) + arrayIndex :: Array n t -> Index n -> t arrayIndex arr@(Array sh _) idx = arrayIndexLinear arr (toLinearIndex sh idx) @@ -80,6 +83,12 @@ arrayGenerateM sh f = arrayGenerateLinM sh (f . fromLinearIndex sh) arrayGenerateLinM :: Monad m => Shape n -> (Int -> m t) -> m (Array n t) arrayGenerateLinM sh f = Array sh <$> V.generateM (shapeSize sh) f +arrayMap :: (a -> b) -> Array n a -> Array n b +arrayMap f arr = arrayGenerateLin (arrayShape arr) (f . arrayIndexLinear arr) + +arrayMapM :: Monad m => (a -> m b) -> Array n a -> m (Array n b) +arrayMapM f arr = arrayGenerateLinM (arrayShape arr) (f . arrayIndexLinear arr) + -- | The Int is the linear index of the value. traverseArray_ :: Monad m => (Int -> t -> m ()) -> Array n t -> m () traverseArray_ f (Array _ v) = evalStateT (traverse_ (\x -> StateT (\i -> (,i+1) <$> f i x)) v) 0 |