diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-11-10 21:49:45 +0100 |
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-11-10 21:50:25 +0100 |
| commit | 174af2ba568de66e0d890825b8bda930b8e7bb96 (patch) | |
| tree | 5a20f52662e87ff7cf6a6bef5db0713aa6c7884e /src/CHAD/Array.hs | |
| parent | 92bca235e3aaa287286b6af082d3fce585825a35 (diff) | |
Move module hierarchy under CHAD.
Diffstat (limited to 'src/CHAD/Array.hs')
| -rw-r--r-- | src/CHAD/Array.hs | 131 |
1 files changed, 131 insertions, 0 deletions
diff --git a/src/CHAD/Array.hs b/src/CHAD/Array.hs new file mode 100644 index 0000000..f80f961 --- /dev/null +++ b/src/CHAD/Array.hs @@ -0,0 +1,131 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE DeriveTraversable #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TupleSections #-} +module CHAD.Array where + +import Control.DeepSeq +import Control.Monad.Trans.State.Strict +import Data.Foldable (traverse_) +import Data.Vector (Vector) +import qualified Data.Vector as V +import GHC.Generics (Generic) + +import CHAD.Data + + +data Shape n where + ShNil :: Shape Z + ShCons :: Shape n -> Int -> Shape (S n) +deriving instance Show (Shape n) +deriving instance Eq (Shape n) + +instance NFData (Shape n) where + rnf ShNil = () + rnf (sh `ShCons` n) = rnf n `seq` rnf sh + +data Index n where + IxNil :: Index Z + IxCons :: Index n -> Int -> Index (S n) +deriving instance Show (Index n) +deriving instance Eq (Index n) + +instance NFData (Index n) where + rnf IxNil = () + rnf (sh `IxCons` n) = rnf n `seq` rnf sh + +shapeSize :: Shape n -> Int +shapeSize ShNil = 1 +shapeSize (ShCons sh n) = shapeSize sh * n + +shapeRank :: Shape n -> SNat n +shapeRank ShNil = SZ +shapeRank (sh `ShCons` _) = SS (shapeRank sh) + +fromLinearIndex :: Shape n -> Int -> Index n +fromLinearIndex ShNil 0 = IxNil +fromLinearIndex ShNil _ = error "Index out of range" +fromLinearIndex (sh `ShCons` n) i = + let (q, r) = i `quotRem` n + in fromLinearIndex sh q `IxCons` r + +toLinearIndex :: Shape n -> Index n -> Int +toLinearIndex ShNil IxNil = 0 +toLinearIndex (sh `ShCons` n) (idx `IxCons` i) = toLinearIndex sh idx * n + i + +emptyShape :: SNat n -> Shape n +emptyShape SZ = ShNil +emptyShape (SS m) = emptyShape m `ShCons` 0 + +enumShape :: Shape n -> [Index n] +enumShape sh = map (fromLinearIndex sh) [0 .. shapeSize sh - 1] + +shapeToList :: Shape n -> [Int] +shapeToList = go [] + where + go :: [Int] -> Shape n -> [Int] + go suff ShNil = suff + go suff (sh `ShCons` n) = go (n:suff) sh + + +-- | TODO: this Vector is a boxed vector, which is horrendously inefficient. +data Array (n :: Nat) t = Array (Shape n) (Vector t) + deriving (Show, Functor, Foldable, Traversable, Generic) +instance NFData t => NFData (Array n t) + +arrayShape :: Array n t -> Shape n +arrayShape (Array sh _) = sh + +arraySize :: Array n t -> Int +arraySize (Array sh _) = shapeSize sh + +emptyArray :: SNat n -> Array n t +emptyArray n = Array (emptyShape n) V.empty + +arrayFromList :: Shape n -> [t] -> Array n t +arrayFromList sh l = Array sh (V.fromListN (shapeSize sh) l) + +arrayToList :: Array n t -> [t] +arrayToList (Array _ v) = V.toList v + +arrayReshape :: Shape n -> Array m t -> Array n t +arrayReshape sh (Array sh' v) + | shapeSize sh == shapeSize sh' = Array sh v + | otherwise = error $ "arrayReshape: different shape size than original (" ++ show sh' ++ " -> " ++ show sh ++ ")" + +arrayUnit :: t -> Array Z t +arrayUnit x = Array ShNil (V.singleton x) + +arrayIndex :: Array n t -> Index n -> t +arrayIndex arr@(Array sh _) idx = arrayIndexLinear arr (toLinearIndex sh idx) + +arrayIndexLinear :: Array n t -> Int -> t +arrayIndexLinear (Array _ v) i = v V.! i + +arrayIndex1 :: Array (S n) t -> Int -> Array n t +arrayIndex1 (Array (sh `ShCons` _) v) i = let sz = shapeSize sh in Array sh (V.slice (sz * i) sz v) + +arrayGenerate :: Shape n -> (Index n -> t) -> Array n t +arrayGenerate sh f = arrayGenerateLin sh (f . fromLinearIndex sh) + +arrayGenerateLin :: Shape n -> (Int -> t) -> Array n t +arrayGenerateLin sh f = Array sh (V.generate (shapeSize sh) f) + +arrayGenerateM :: Monad m => Shape n -> (Index n -> m t) -> m (Array n t) +arrayGenerateM sh f = arrayGenerateLinM sh (f . fromLinearIndex sh) + +arrayGenerateLinM :: Monad m => Shape n -> (Int -> m t) -> m (Array n t) +arrayGenerateLinM sh f = Array sh <$> V.generateM (shapeSize sh) f + +arrayMap :: (a -> b) -> Array n a -> Array n b +arrayMap f arr = arrayGenerateLin (arrayShape arr) (f . arrayIndexLinear arr) + +arrayMapM :: Monad m => (a -> m b) -> Array n a -> m (Array n b) +arrayMapM f arr = arrayGenerateLinM (arrayShape arr) (f . arrayIndexLinear arr) + +-- | The Int is the linear index of the value. +traverseArray_ :: Monad m => (Int -> t -> m ()) -> Array n t -> m () +traverseArray_ f (Array _ v) = evalStateT (traverse_ (\x -> StateT (\i -> (,i+1) <$> f i x)) v) 0 |
