aboutsummaryrefslogtreecommitdiff
path: root/ops/Data/Array/Strided/Array.hs
diff options
context:
space:
mode:
Diffstat (limited to 'ops/Data/Array/Strided/Array.hs')
-rw-r--r--ops/Data/Array/Strided/Array.hs42
1 files changed, 42 insertions, 0 deletions
diff --git a/ops/Data/Array/Strided/Array.hs b/ops/Data/Array/Strided/Array.hs
new file mode 100644
index 0000000..a772aaf
--- /dev/null
+++ b/ops/Data/Array/Strided/Array.hs
@@ -0,0 +1,42 @@
+{-# LANGUAGE KindSignatures #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
+module Data.Array.Strided.Array where
+
+import qualified Data.List.NonEmpty as NE
+import Data.Proxy
+import qualified Data.Vector.Storable as VS
+import Foreign.Storable
+import GHC.TypeLits
+
+
+data Array (n :: Nat) a = Array
+ { arrShape :: ![Int]
+ , arrStrides :: ![Int]
+ , arrOffset :: !Int
+ , arrValues :: !(VS.Vector a)
+ }
+
+-- | Takes a vector in normalised order (inner dimension, i.e. last in the
+-- list, iterates fastest).
+arrayFromVector :: forall a n. (Storable a, KnownNat n) => [Int] -> VS.Vector a -> Array n a
+arrayFromVector sh vec
+ | VS.length vec == shsize
+ , length sh == fromIntegral (natVal (Proxy @n))
+ = Array sh strides 0 vec
+ | otherwise = error $ "arrayFromVector: Shape " ++ show sh ++ " does not match vector length " ++ show (VS.length vec)
+ where
+ shsize = product sh
+ strides = NE.tail (NE.scanr (*) 1 sh)
+
+arrayFromConstant :: (Storable a, KnownNat n) => [Int] -> a -> Array n a
+arrayFromConstant sh x = Array sh (0 <$ sh) 0 (VS.singleton x)
+
+arrayRevDims :: [Bool] -> Array n a -> Array n a
+arrayRevDims bs (Array sh strides offset vec)
+ | length bs == length sh =
+ Array sh
+ (zipWith (\b s -> if b then -s else s) bs strides)
+ (offset + sum (zipWith3 (\b n s -> if b then (n - 1) * s else 0) bs sh strides))
+ vec
+ | otherwise = error $ "arrayRevDims: " ++ show (length bs) ++ " booleans given but rank " ++ show (length sh)