aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/Data/Array/Nested/Mixed.hs72
-rw-r--r--src/Data/Array/Nested/Mixed/Shape.hs3
-rw-r--r--src/Data/Array/Nested/Ranked/Base.hs23
-rw-r--r--src/Data/Array/Nested/Shaped/Base.hs23
4 files changed, 79 insertions, 42 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs
index 182943d..fc1c108 100644
--- a/src/Data/Array/Nested/Mixed.hs
+++ b/src/Data/Array/Nested/Mixed.hs
@@ -367,18 +367,17 @@ class Elt a where
-- this mixed array.
marrayStrides :: Mixed sh a -> Bag [Int]
- -- | Given the shape of this array, an index and a value, write the value at
+ -- | Given a linear index and a value, write the value at
-- that index in the vectors.
- mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s ()
+ mvecsWriteLinear :: Int -> a -> MixedVecs s sh a -> ST s ()
- -- | Given the shape of this array, an index and a value, write the value at
+ -- | Given a linear index and a value, write the value at
-- that index in the vectors.
- mvecsWritePartial :: IShX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s ()
+ mvecsWritePartialLinear :: Proxy sh -> Int -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s ()
-- | Given the shape of this array, finalise the vectors into 'XArray's.
mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a)
-
-- | Element types for which we have evidence of the (static part of the) shape
-- in a type class constraint. Compare the instance contexts of the instances
-- of this class with those of 'Elt': some instances have an additional
@@ -393,6 +392,10 @@ class Elt a => KnownElt a where
-- this vector and an example for the contents.
mvecsUnsafeNew :: IShX sh -> a -> ST s (MixedVecs s sh a)
+ -- | Create initialised vectors for this array type, given the shape of
+ -- this vector and the chosen element.
+ mvecsReplicate :: IShX sh -> a -> ST s (MixedVecs s sh a)
+
mvecsNewEmpty :: Proxy a -> ST s (MixedVecs s sh a)
@@ -464,15 +467,15 @@ instance Storable a => Elt (Primitive a) where
mshapeTreeIsEmpty _ () = False
mshowShapeTree _ () = "()"
marrayStrides (M_Primitive _ arr) = BOne (X.arrayStrides arr)
- mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x
+ mvecsWriteLinear i (Primitive x) (MV_Primitive v) = VSM.write v i x
-- TODO: this use of toVector is suboptimal
- mvecsWritePartial
+ mvecsWritePartialLinear
:: forall sh' sh s.
- IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s ()
- mvecsWritePartial sh i (M_Primitive sh' arr) (MV_Primitive v) = do
+ Proxy sh -> Int -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s ()
+ mvecsWritePartialLinear _ i (M_Primitive sh' arr) (MV_Primitive v) = do
let arrsh = X.shape (ssxFromShX sh') arr
- offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh))
+ offset = i * shxSize arrsh
VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr)
mvecsFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.freeze v
@@ -492,6 +495,7 @@ deriving via Primitive () instance Elt ()
instance Storable a => KnownElt (Primitive a) where
memptyArrayUnsafe sh = M_Primitive sh (X.empty sh)
mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (shxSize sh)
+ mvecsReplicate sh (Primitive a) = MV_Primitive <$> VSM.replicate (shxSize sh) a
mvecsNewEmpty _ = MV_Primitive <$> VSM.unsafeNew 0
-- [PRIMITIVE ELEMENT TYPES LIST]
@@ -542,17 +546,18 @@ instance (Elt a, Elt b) => Elt (a, b) where
mshapeTreeIsEmpty _ (t1, t2) = mshapeTreeIsEmpty (Proxy @a) t1 && mshapeTreeIsEmpty (Proxy @b) t2
mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")"
marrayStrides (M_Tup2 a b) = marrayStrides a <> marrayStrides b
- mvecsWrite sh i (x, y) (MV_Tup2 a b) = do
- mvecsWrite sh i x a
- mvecsWrite sh i y b
- mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do
- mvecsWritePartial sh i x a
- mvecsWritePartial sh i y b
+ mvecsWriteLinear i (x, y) (MV_Tup2 a b) = do
+ mvecsWriteLinear i x a
+ mvecsWriteLinear i y b
+ mvecsWritePartialLinear proxy i (M_Tup2 x y) (MV_Tup2 a b) = do
+ mvecsWritePartialLinear proxy i x a
+ mvecsWritePartialLinear proxy i y b
mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b
instance (KnownElt a, KnownElt b) => KnownElt (a, b) where
memptyArrayUnsafe sh = M_Tup2 (memptyArrayUnsafe sh) (memptyArrayUnsafe sh)
mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y
+ mvecsReplicate sh (x, y) = MV_Tup2 <$> mvecsReplicate sh x <*> mvecsReplicate sh y
mvecsNewEmpty _ = MV_Tup2 <$> mvecsNewEmpty (Proxy @a) <*> mvecsNewEmpty (Proxy @b)
-- Arrays of arrays are just arrays, but with more dimensions.
@@ -676,15 +681,17 @@ instance Elt a => Elt (Mixed sh' a) where
marrayStrides (M_Nest _ arr) = marrayStrides arr
- mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs
+ mvecsWriteLinear :: forall s sh. Int -> Mixed sh' a -> MixedVecs s sh (Mixed sh' a) -> ST s ()
+ mvecsWriteLinear idx val (MV_Nest _ vecs) = mvecsWritePartialLinear (Proxy @sh) idx val vecs
- mvecsWritePartial :: forall sh1 sh2 s.
- IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)
- -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a)
- -> ST s ()
- mvecsWritePartial sh12 idx (M_Nest _ arr) (MV_Nest sh' vecs)
+ mvecsWritePartialLinear
+ :: forall sh1 sh2 s.
+ Proxy sh1 -> Int -> Mixed sh2 (Mixed sh' a)
+ -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a)
+ -> ST s ()
+ mvecsWritePartialLinear proxy idx (M_Nest _ arr) (MV_Nest _ vecs)
| Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
- = mvecsWritePartial (shxAppend sh12 sh') idx arr vecs
+ = mvecsWritePartialLinear proxy idx arr vecs
mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (shxAppend sh sh') vecs
@@ -697,9 +704,28 @@ instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where
where
sh' = mshape example
+ mvecsReplicate sh example = do
+ vecs <- mvecsUnsafeNew sh example
+ forM_ (shxEnum sh) $ \idx -> mvecsWrite sh idx example vecs
+ -- this is a slow case, but the alternative, mvecsUnsafeNew with manual
+ -- writing in a loop, leads to every case being as slow
+ return vecs
+
mvecsNewEmpty _ = MV_Nest (shxCompleteZeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a)
+-- | Given the shape of this array, an index and a value, write the value at
+-- that index in the vectors.
+{-# INLINE mvecsWrite #-}
+mvecsWrite :: Elt a => IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s ()
+mvecsWrite sh idx val vecs = mvecsWriteLinear (ixxToLinear sh idx) val vecs
+
+-- | Given the shape of this array, an index and a value, write the value at
+-- that index in the vectors.
+{-# INLINE mvecsWritePartial #-}
+mvecsWritePartial :: forall sh sh' s a. Elt a => IShX sh -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s ()
+mvecsWritePartial sh idx val vecs = mvecsWritePartialLinear (Proxy @sh) (ixxToLinear sh idx) val vecs
+
-- TODO: should we provide a function that's just memptyArrayUnsafe but with a size==0 check? That may save someone a transpose somewhere
memptyArray :: forall sh a. KnownElt a => IShX sh -> Mixed (Just 0 : sh) a
memptyArray sh = memptyArrayUnsafe (SKnown SNat :$% sh)
diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs
index c999853..145ea5f 100644
--- a/src/Data/Array/Nested/Mixed/Shape.hs
+++ b/src/Data/Array/Nested/Mixed/Shape.hs
@@ -36,7 +36,7 @@ import Data.Functor.Product
import Data.Kind (Constraint, Type)
import Data.Monoid (Sum(..))
import Data.Type.Equality
-import GHC.Exts (Int(..), Int#, quotRemInt#, withDict, build)
+import GHC.Exts (Int(..), Int#, build, quotRemInt#, withDict)
import GHC.Generics (Generic)
import GHC.IsList (IsList)
import GHC.IsList qualified as IsList
@@ -284,6 +284,7 @@ ixxZipWith :: (i -> j -> k) -> IxX sh i -> IxX sh j -> IxX sh k
ixxZipWith _ ZIX ZIX = ZIX
ixxZipWith f (i :.% is) (j :.% js) = f i j :.% ixxZipWith f is js
+{-# INLINEABLE ixxToLinear #-}
ixxToLinear :: IShX sh -> IIxX sh -> Int
ixxToLinear = \sh i -> fst (go sh i)
where
diff --git a/src/Data/Array/Nested/Ranked/Base.hs b/src/Data/Array/Nested/Ranked/Base.hs
index 11a8ffb..ed194a8 100644
--- a/src/Data/Array/Nested/Ranked/Base.hs
+++ b/src/Data/Array/Nested/Ranked/Base.hs
@@ -149,18 +149,19 @@ instance Elt a => Elt (Ranked n a) where
marrayStrides (M_Ranked arr) = marrayStrides arr
- mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s ()
- mvecsWrite sh idx (Ranked arr) vecs =
- mvecsWrite sh idx arr
+ mvecsWriteLinear :: forall sh s. Int -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s ()
+ mvecsWriteLinear idx (Ranked arr) vecs =
+ mvecsWriteLinear idx arr
(coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
vecs)
- mvecsWritePartial :: forall sh sh' s.
- IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a)
- -> MixedVecs s (sh ++ sh') (Ranked n a)
- -> ST s ()
- mvecsWritePartial sh idx arr vecs =
- mvecsWritePartial sh idx
+ mvecsWritePartialLinear
+ :: forall sh sh' s.
+ Proxy sh -> Int -> Mixed sh' (Ranked n a)
+ -> MixedVecs s (sh ++ sh') (Ranked n a)
+ -> ST s ()
+ mvecsWritePartialLinear proxy idx arr vecs =
+ mvecsWritePartialLinear proxy idx
(coerce @(Mixed sh' (Ranked n a))
@(Mixed sh' (Mixed (Replicate n Nothing) a))
arr)
@@ -188,6 +189,10 @@ instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where
| Dict <- lemKnownReplicate (SNat @n)
= MV_Ranked <$> mvecsUnsafeNew idx arr
+ mvecsReplicate idx (Ranked arr)
+ | Dict <- lemKnownReplicate (SNat @n)
+ = MV_Ranked <$> mvecsReplicate idx arr
+
mvecsNewEmpty _
| Dict <- lemKnownReplicate (SNat @n)
= MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a))
diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs
index 98f1241..e5dd852 100644
--- a/src/Data/Array/Nested/Shaped/Base.hs
+++ b/src/Data/Array/Nested/Shaped/Base.hs
@@ -142,18 +142,19 @@ instance Elt a => Elt (Shaped sh a) where
marrayStrides (M_Shaped arr) = marrayStrides arr
- mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s ()
- mvecsWrite sh idx (Shaped arr) vecs =
- mvecsWrite sh idx arr
+ mvecsWriteLinear :: forall sh' s. Int -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s ()
+ mvecsWriteLinear idx (Shaped arr) vecs =
+ mvecsWriteLinear idx arr
(coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a))
vecs)
- mvecsWritePartial :: forall sh1 sh2 s.
- IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a)
- -> MixedVecs s (sh1 ++ sh2) (Shaped sh a)
- -> ST s ()
- mvecsWritePartial sh idx arr vecs =
- mvecsWritePartial sh idx
+ mvecsWritePartialLinear
+ :: forall sh1 sh2 s.
+ Proxy sh1 -> Int -> Mixed sh2 (Shaped sh a)
+ -> MixedVecs s (sh1 ++ sh2) (Shaped sh a)
+ -> ST s ()
+ mvecsWritePartialLinear proxy idx arr vecs =
+ mvecsWritePartialLinear proxy idx
(coerce @(Mixed sh2 (Shaped sh a))
@(Mixed sh2 (Mixed (MapJust sh) a))
arr)
@@ -181,6 +182,10 @@ instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where
| Dict <- lemKnownMapJust (Proxy @sh)
= MV_Shaped <$> mvecsUnsafeNew idx arr
+ mvecsReplicate idx (Shaped arr)
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = MV_Shaped <$> mvecsReplicate idx arr
+
mvecsNewEmpty _
| Dict <- lemKnownMapJust (Proxy @sh)
= MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a))