diff options
Diffstat (limited to 'ops/Data/Array/Strided/Array.hs')
-rw-r--r-- | ops/Data/Array/Strided/Array.hs | 42 |
1 files changed, 42 insertions, 0 deletions
diff --git a/ops/Data/Array/Strided/Array.hs b/ops/Data/Array/Strided/Array.hs new file mode 100644 index 0000000..a772aaf --- /dev/null +++ b/ops/Data/Array/Strided/Array.hs @@ -0,0 +1,42 @@ +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +module Data.Array.Strided.Array where + +import qualified Data.List.NonEmpty as NE +import Data.Proxy +import qualified Data.Vector.Storable as VS +import Foreign.Storable +import GHC.TypeLits + + +data Array (n :: Nat) a = Array + { arrShape :: ![Int] + , arrStrides :: ![Int] + , arrOffset :: !Int + , arrValues :: !(VS.Vector a) + } + +-- | Takes a vector in normalised order (inner dimension, i.e. last in the +-- list, iterates fastest). +arrayFromVector :: forall a n. (Storable a, KnownNat n) => [Int] -> VS.Vector a -> Array n a +arrayFromVector sh vec + | VS.length vec == shsize + , length sh == fromIntegral (natVal (Proxy @n)) + = Array sh strides 0 vec + | otherwise = error $ "arrayFromVector: Shape " ++ show sh ++ " does not match vector length " ++ show (VS.length vec) + where + shsize = product sh + strides = NE.tail (NE.scanr (*) 1 sh) + +arrayFromConstant :: (Storable a, KnownNat n) => [Int] -> a -> Array n a +arrayFromConstant sh x = Array sh (0 <$ sh) 0 (VS.singleton x) + +arrayRevDims :: [Bool] -> Array n a -> Array n a +arrayRevDims bs (Array sh strides offset vec) + | length bs == length sh = + Array sh + (zipWith (\b s -> if b then -s else s) bs strides) + (offset + sum (zipWith3 (\b n s -> if b then (n - 1) * s else 0) bs sh strides)) + vec + | otherwise = error $ "arrayRevDims: " ++ show (length bs) ++ " booleans given but rank " ++ show (length sh) |