summaryrefslogtreecommitdiff
path: root/src/Array.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Array.hs')
-rw-r--r--src/Array.hs9
1 files changed, 9 insertions, 0 deletions
diff --git a/src/Array.hs b/src/Array.hs
index d7dadbf..6473bf0 100644
--- a/src/Array.hs
+++ b/src/Array.hs
@@ -59,6 +59,9 @@ arraySize (Array sh _) = shapeSize sh
emptyArray :: SNat n -> Array n t
emptyArray n = Array (emptyShape n) V.empty
+arrayUnit :: t -> Array Z t
+arrayUnit x = Array ShNil (V.singleton x)
+
arrayIndex :: Array n t -> Index n -> t
arrayIndex arr@(Array sh _) idx = arrayIndexLinear arr (toLinearIndex sh idx)
@@ -80,6 +83,12 @@ arrayGenerateM sh f = arrayGenerateLinM sh (f . fromLinearIndex sh)
arrayGenerateLinM :: Monad m => Shape n -> (Int -> m t) -> m (Array n t)
arrayGenerateLinM sh f = Array sh <$> V.generateM (shapeSize sh) f
+arrayMap :: (a -> b) -> Array n a -> Array n b
+arrayMap f arr = arrayGenerateLin (arrayShape arr) (f . arrayIndexLinear arr)
+
+arrayMapM :: Monad m => (a -> m b) -> Array n a -> m (Array n b)
+arrayMapM f arr = arrayGenerateLinM (arrayShape arr) (f . arrayIndexLinear arr)
+
-- | The Int is the linear index of the value.
traverseArray_ :: Monad m => (Int -> t -> m ()) -> Array n t -> m ()
traverseArray_ f (Array _ v) = evalStateT (traverse_ (\x -> StateT (\i -> (,i+1) <$> f i x)) v) 0