aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD/Array.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-11-10 21:49:45 +0100
committerTom Smeding <tom@tomsmeding.com>2025-11-10 21:50:25 +0100
commit174af2ba568de66e0d890825b8bda930b8e7bb96 (patch)
tree5a20f52662e87ff7cf6a6bef5db0713aa6c7884e /src/CHAD/Array.hs
parent92bca235e3aaa287286b6af082d3fce585825a35 (diff)
Move module hierarchy under CHAD.
Diffstat (limited to 'src/CHAD/Array.hs')
-rw-r--r--src/CHAD/Array.hs131
1 files changed, 131 insertions, 0 deletions
diff --git a/src/CHAD/Array.hs b/src/CHAD/Array.hs
new file mode 100644
index 0000000..f80f961
--- /dev/null
+++ b/src/CHAD/Array.hs
@@ -0,0 +1,131 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DeriveGeneric #-}
+{-# LANGUAGE DeriveTraversable #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE KindSignatures #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TupleSections #-}
+module CHAD.Array where
+
+import Control.DeepSeq
+import Control.Monad.Trans.State.Strict
+import Data.Foldable (traverse_)
+import Data.Vector (Vector)
+import qualified Data.Vector as V
+import GHC.Generics (Generic)
+
+import CHAD.Data
+
+
+data Shape n where
+ ShNil :: Shape Z
+ ShCons :: Shape n -> Int -> Shape (S n)
+deriving instance Show (Shape n)
+deriving instance Eq (Shape n)
+
+instance NFData (Shape n) where
+ rnf ShNil = ()
+ rnf (sh `ShCons` n) = rnf n `seq` rnf sh
+
+data Index n where
+ IxNil :: Index Z
+ IxCons :: Index n -> Int -> Index (S n)
+deriving instance Show (Index n)
+deriving instance Eq (Index n)
+
+instance NFData (Index n) where
+ rnf IxNil = ()
+ rnf (sh `IxCons` n) = rnf n `seq` rnf sh
+
+shapeSize :: Shape n -> Int
+shapeSize ShNil = 1
+shapeSize (ShCons sh n) = shapeSize sh * n
+
+shapeRank :: Shape n -> SNat n
+shapeRank ShNil = SZ
+shapeRank (sh `ShCons` _) = SS (shapeRank sh)
+
+fromLinearIndex :: Shape n -> Int -> Index n
+fromLinearIndex ShNil 0 = IxNil
+fromLinearIndex ShNil _ = error "Index out of range"
+fromLinearIndex (sh `ShCons` n) i =
+ let (q, r) = i `quotRem` n
+ in fromLinearIndex sh q `IxCons` r
+
+toLinearIndex :: Shape n -> Index n -> Int
+toLinearIndex ShNil IxNil = 0
+toLinearIndex (sh `ShCons` n) (idx `IxCons` i) = toLinearIndex sh idx * n + i
+
+emptyShape :: SNat n -> Shape n
+emptyShape SZ = ShNil
+emptyShape (SS m) = emptyShape m `ShCons` 0
+
+enumShape :: Shape n -> [Index n]
+enumShape sh = map (fromLinearIndex sh) [0 .. shapeSize sh - 1]
+
+shapeToList :: Shape n -> [Int]
+shapeToList = go []
+ where
+ go :: [Int] -> Shape n -> [Int]
+ go suff ShNil = suff
+ go suff (sh `ShCons` n) = go (n:suff) sh
+
+
+-- | TODO: this Vector is a boxed vector, which is horrendously inefficient.
+data Array (n :: Nat) t = Array (Shape n) (Vector t)
+ deriving (Show, Functor, Foldable, Traversable, Generic)
+instance NFData t => NFData (Array n t)
+
+arrayShape :: Array n t -> Shape n
+arrayShape (Array sh _) = sh
+
+arraySize :: Array n t -> Int
+arraySize (Array sh _) = shapeSize sh
+
+emptyArray :: SNat n -> Array n t
+emptyArray n = Array (emptyShape n) V.empty
+
+arrayFromList :: Shape n -> [t] -> Array n t
+arrayFromList sh l = Array sh (V.fromListN (shapeSize sh) l)
+
+arrayToList :: Array n t -> [t]
+arrayToList (Array _ v) = V.toList v
+
+arrayReshape :: Shape n -> Array m t -> Array n t
+arrayReshape sh (Array sh' v)
+ | shapeSize sh == shapeSize sh' = Array sh v
+ | otherwise = error $ "arrayReshape: different shape size than original (" ++ show sh' ++ " -> " ++ show sh ++ ")"
+
+arrayUnit :: t -> Array Z t
+arrayUnit x = Array ShNil (V.singleton x)
+
+arrayIndex :: Array n t -> Index n -> t
+arrayIndex arr@(Array sh _) idx = arrayIndexLinear arr (toLinearIndex sh idx)
+
+arrayIndexLinear :: Array n t -> Int -> t
+arrayIndexLinear (Array _ v) i = v V.! i
+
+arrayIndex1 :: Array (S n) t -> Int -> Array n t
+arrayIndex1 (Array (sh `ShCons` _) v) i = let sz = shapeSize sh in Array sh (V.slice (sz * i) sz v)
+
+arrayGenerate :: Shape n -> (Index n -> t) -> Array n t
+arrayGenerate sh f = arrayGenerateLin sh (f . fromLinearIndex sh)
+
+arrayGenerateLin :: Shape n -> (Int -> t) -> Array n t
+arrayGenerateLin sh f = Array sh (V.generate (shapeSize sh) f)
+
+arrayGenerateM :: Monad m => Shape n -> (Index n -> m t) -> m (Array n t)
+arrayGenerateM sh f = arrayGenerateLinM sh (f . fromLinearIndex sh)
+
+arrayGenerateLinM :: Monad m => Shape n -> (Int -> m t) -> m (Array n t)
+arrayGenerateLinM sh f = Array sh <$> V.generateM (shapeSize sh) f
+
+arrayMap :: (a -> b) -> Array n a -> Array n b
+arrayMap f arr = arrayGenerateLin (arrayShape arr) (f . arrayIndexLinear arr)
+
+arrayMapM :: Monad m => (a -> m b) -> Array n a -> m (Array n b)
+arrayMapM f arr = arrayGenerateLinM (arrayShape arr) (f . arrayIndexLinear arr)
+
+-- | The Int is the linear index of the value.
+traverseArray_ :: Monad m => (Int -> t -> m ()) -> Array n t -> m ()
+traverseArray_ f (Array _ v) = evalStateT (traverse_ (\x -> StateT (\i -> (,i+1) <$> f i x)) v) 0