{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TupleSections #-}
-- | Rank-indexed shapes, indices and boxed arrays.
module Array where

import Control.DeepSeq
import Control.Monad.Trans.State.Strict
import Data.Foldable (traverse_)
import Data.Vector (Vector)
import qualified Data.Vector as V
import GHC.Generics (Generic)

import Data


-- | The extent of a rank-@n@ array, one 'Int' per dimension. In the linear
-- layout of 'toLinearIndex' / 'fromLinearIndex', the dimension added by the
-- outermost 'ShCons' is the fastest-varying one (stride 1).
data Shape n where
  ShNil :: Shape Z
  ShCons :: Shape n -> Int -> Shape (S n)

deriving instance Show (Shape n)
deriving instance Eq (Shape n)

instance NFData (Shape n) where
  rnf shape = case shape of
    ShNil -> ()
    ShCons rest dim -> rnf dim `seq` rnf rest

-- | A position in an array of rank @n@, built component by component in the
-- same order as the corresponding 'Shape'.
data Index n where
  IxNil :: Index Z
  IxCons :: Index n -> Int -> Index (S n)

deriving instance Show (Index n)
deriving instance Eq (Index n)

instance NFData (Index n) where
  rnf index = case index of
    IxNil -> ()
    IxCons rest component -> rnf component `seq` rnf rest

-- | Number of elements described by a shape: the product of its dimensions
-- (1 for the rank-0 shape 'ShNil').
shapeSize :: Shape n -> Int
shapeSize shape = case shape of
  ShNil -> 1
  ShCons rest dim -> shapeSize rest * dim

-- | Decode a linear position into a multi-dimensional index; the inverse of
-- 'toLinearIndex'. Errors with @"Index out of range"@ when the position is
-- too large for the shape.
--
-- NOTE(review): small negative positions are not rejected (e.g. @-1@ in a
-- 1-D shape yields component @-1@), and a zero dimension would make
-- 'quotRem' divide by zero.
fromLinearIndex :: Shape n -> Int -> Index n
fromLinearIndex ShNil lin
  | lin == 0 = IxNil
  | otherwise = error "Index out of range"
fromLinearIndex (ShCons rest dim) lin =
  IxCons (fromLinearIndex rest outer) inner
  where
    -- The last component varies fastest, so it is the remainder.
    (outer, inner) = lin `quotRem` dim

-- | Encode a multi-dimensional index as a linear position. Components are
-- not range-checked: an out-of-range component silently maps onto some
-- other element's position.
toLinearIndex :: Shape n -> Index n -> Int
toLinearIndex ShNil IxNil = 0
toLinearIndex (ShCons rest dim) (IxCons idx component) =
  component + dim * toLinearIndex rest idx

-- | A shape of the given rank whose every dimension is 0.
emptyShape :: SNat n -> Shape n
emptyShape SZ = ShNil
emptyShape (SS m) = ShCons (emptyShape m) 0

-- | Every index of a shape, in ascending linear order.
enumShape :: Shape n -> [Index n]
enumShape sh = [fromLinearIndex sh lin | lin <- [0 .. shapeSize sh - 1]]

-- | A rank-@n@ array: its 'Shape' plus its elements stored flat, in the
-- order given by 'toLinearIndex'.
--
-- TODO: this Vector is a boxed vector, which is horrendously inefficient.
data Array (n :: Nat) t = Array (Shape n) (Vector t)
  deriving (Show, Functor, Foldable, Traversable, Generic)
-- Expected invariant (not enforced by the type): the vector holds exactly
-- 'shapeSize' of the shape elements, laid out as 'toLinearIndex' describes.

instance NFData t => NFData (Array n t)

-- | The shape of an array.
arrayShape :: Array n t -> Shape n
arrayShape (Array sh _) = sh

-- | Number of elements, as given by the array's shape.
arraySize :: Array n t -> Int
arraySize (Array sh _) = shapeSize sh

-- | An array whose every dimension is 0, holding no elements.
--
-- NOTE(review): for rank 0 ('SZ') the shape is 'ShNil', whose 'shapeSize'
-- is 1, yet the element vector is empty, so such an array cannot be indexed.
emptyArray :: SNat n -> Array n t
emptyArray n = Array (emptyShape n) V.empty

-- | Build an array from its elements in linear-index order. Elements beyond
-- @'shapeSize' sh@ are ignored. A list that is too short is an error, raised
-- when the element vector is forced; the shape itself stays lazily available.
arrayFromList :: Shape n -> [t] -> Array n t
arrayFromList sh l = Array sh checked
  where
    sz = shapeSize sh
    v = V.fromListN sz l
    -- 'V.fromListN' silently stops early on a short list, which would leave
    -- the vector smaller than the shape claims.
    checked
      | V.length v == sz = v
      | otherwise =
          error ("arrayFromList: list has " ++ show (V.length v)
                 ++ " elements, but the shape needs " ++ show sz)

-- | A rank-0 array holding a single element.
arrayUnit :: t -> Array Z t
arrayUnit x = Array ShNil (V.singleton x)

-- | The element at a multi-dimensional index. Only the resulting linear
-- position is bounds-checked (by 'V.!'), not each component.
arrayIndex :: Array n t -> Index n -> t
arrayIndex arr@(Array sh _) idx = arrayIndexLinear arr (toLinearIndex sh idx)

-- | The element at a linear position; out of range is an error from 'V.!'.
arrayIndexLinear :: Array n t -> Int -> t
arrayIndexLinear (Array _ v) i = v V.! i

-- | Index the outermost (slowest-varying) dimension. The result is the
-- sub-array of all elements whose first index component is @i@, with that
-- dimension removed from the shape. Because that dimension is the slowest
-- varying in the 'toLinearIndex' layout, those elements are contiguous and
-- this is an O(1) 'V.slice'. @i@ must be in range.
arrayIndex1 :: Array (S n) t -> Int -> Array n t
arrayIndex1 (Array sh v) i =
  let inner = dropOuter sh
      sz = shapeSize inner
  in Array inner (V.slice (sz * i) sz v)
  where
    -- Remove the innermost-consed (i.e. outermost in memory) dimension,
    -- keeping the remaining dimensions in their original order.
    dropOuter :: Shape (S m) -> Shape m
    dropOuter (ShNil `ShCons` _) = ShNil
    dropOuter (rest@(_ `ShCons` _) `ShCons` k) = dropOuter rest `ShCons` k

-- | Build an array by computing each element from its index.
arrayGenerate :: Shape n -> (Index n -> t) -> Array n t
arrayGenerate sh f = arrayGenerateLin sh (f . fromLinearIndex sh)

-- | Build an array by computing each element from its linear position.
arrayGenerateLin :: Shape n -> (Int -> t) -> Array n t
arrayGenerateLin sh f = Array sh (V.generate (shapeSize sh) f)

-- | Monadic 'arrayGenerate'; effects run in linear-index order.
arrayGenerateM :: Monad m => Shape n -> (Index n -> m t) -> m (Array n t)
arrayGenerateM sh f = arrayGenerateLinM sh (f . fromLinearIndex sh)

-- | Monadic 'arrayGenerateLin'; effects run in linear-index order.
arrayGenerateLinM :: Monad m => Shape n -> (Int -> m t) -> m (Array n t)
arrayGenerateLinM sh f = Array sh <$> V.generateM (shapeSize sh) f

-- | Apply a function to every element, keeping the shape.
arrayMap :: (a -> b) -> Array n a -> Array n b
arrayMap f arr = arrayGenerateLin (arrayShape arr) (f . arrayIndexLinear arr)

-- | Monadic 'arrayMap'; effects run in linear-index order.
arrayMapM :: Monad m => (a -> m b) -> Array n a -> m (Array n b)
arrayMapM f arr = arrayGenerateLinM (arrayShape arr) (f . arrayIndexLinear arr)

-- | Run an action on every element in linear-index order. The Int is the
-- linear index of the value.
traverseArray_ :: Monad m => (Int -> t -> m ()) -> Array n t -> m ()
-- 'V.imapM_' supplies the index directly, avoiding a lazily incremented
-- counter that could accumulate thunks when the callback ignores it.
traverseArray_ f (Array _ v) = V.imapM_ f v