diff options
| -rw-r--r-- | src/Data/Array/Nested/Mixed.hs | 72 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Mixed/Shape.hs | 3 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Ranked/Base.hs | 23 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Shaped/Base.hs | 23 |
4 files changed, 79 insertions, 42 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index 182943d..fc1c108 100644 --- a/src/Data/Array/Nested/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -367,18 +367,17 @@ class Elt a where -- this mixed array. marrayStrides :: Mixed sh a -> Bag [Int] - -- | Given the shape of this array, an index and a value, write the value at + -- | Given a linear index and a value, write the value at -- that index in the vectors. - mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s () + mvecsWriteLinear :: Int -> a -> MixedVecs s sh a -> ST s () - -- | Given the shape of this array, an index and a value, write the value at + -- | Given a linear index and a value, write the value at -- that index in the vectors. - mvecsWritePartial :: IShX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () + mvecsWritePartialLinear :: Proxy sh -> Int -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () -- | Given the shape of this array, finalise the vectors into 'XArray's. mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a) - -- | Element types for which we have evidence of the (static part of the) shape -- in a type class constraint. Compare the instance contexts of the instances -- of this class with those of 'Elt': some instances have an additional @@ -393,6 +392,10 @@ class Elt a => KnownElt a where -- this vector and an example for the contents. mvecsUnsafeNew :: IShX sh -> a -> ST s (MixedVecs s sh a) + -- | Create initialised vectors for this array type, given the shape of + -- this vector and the chosen element. + mvecsReplicate :: IShX sh -> a -> ST s (MixedVecs s sh a) + mvecsNewEmpty :: Proxy a -> ST s (MixedVecs s sh a) @@ -464,15 +467,15 @@ instance Storable a => Elt (Primitive a) where mshapeTreeIsEmpty _ () = False mshowShapeTree _ () = "()" marrayStrides (M_Primitive _ arr) = BOne (X.arrayStrides arr) - mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x + mvecsWriteLinear i (Primitive x) (MV_Primitive v) = VSM.write v i x -- TODO: this use of toVector is suboptimal - mvecsWritePartial + mvecsWritePartialLinear :: forall sh' sh s. - IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () - mvecsWritePartial sh i (M_Primitive sh' arr) (MV_Primitive v) = do + Proxy sh -> Int -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () + mvecsWritePartialLinear _ i (M_Primitive sh' arr) (MV_Primitive v) = do let arrsh = X.shape (ssxFromShX sh') arr - offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh)) + offset = i * shxSize arrsh VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr) mvecsFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.freeze v @@ -492,6 +495,7 @@ deriving via Primitive () instance Elt () instance Storable a => KnownElt (Primitive a) where memptyArrayUnsafe sh = M_Primitive sh (X.empty sh) mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (shxSize sh) + mvecsReplicate sh (Primitive a) = MV_Primitive <$> VSM.replicate (shxSize sh) a mvecsNewEmpty _ = MV_Primitive <$> VSM.unsafeNew 0 -- [PRIMITIVE ELEMENT TYPES LIST] @@ -542,17 +546,18 @@ instance (Elt a, Elt b) => Elt (a, b) where mshapeTreeIsEmpty _ (t1, t2) = mshapeTreeIsEmpty (Proxy @a) t1 && mshapeTreeIsEmpty (Proxy @b) t2 mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")" marrayStrides (M_Tup2 a b) = marrayStrides a <> marrayStrides b - mvecsWrite sh i (x, y) (MV_Tup2 a b) = do - mvecsWrite sh i x a - mvecsWrite sh i y b - mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do - mvecsWritePartial sh i x a - mvecsWritePartial sh i y b + mvecsWriteLinear i (x, y) (MV_Tup2 a b) = do + mvecsWriteLinear i x a + mvecsWriteLinear i y b + mvecsWritePartialLinear proxy i (M_Tup2 x y) (MV_Tup2 a b) = do + mvecsWritePartialLinear proxy i x a + mvecsWritePartialLinear proxy i y b mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b instance (KnownElt a, KnownElt b) => KnownElt (a, b) where memptyArrayUnsafe sh = M_Tup2 (memptyArrayUnsafe sh) (memptyArrayUnsafe sh) mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y + mvecsReplicate sh (x, y) = MV_Tup2 <$> mvecsReplicate sh x <*> mvecsReplicate sh y mvecsNewEmpty _ = MV_Tup2 <$> mvecsNewEmpty (Proxy @a) <*> mvecsNewEmpty (Proxy @b) -- Arrays of arrays are just arrays, but with more dimensions. @@ -676,15 +681,17 @@ instance Elt a => Elt (Mixed sh' a) where marrayStrides (M_Nest _ arr) = marrayStrides arr - mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs + mvecsWriteLinear :: forall s sh. Int -> Mixed sh' a -> MixedVecs s sh (Mixed sh' a) -> ST s () + mvecsWriteLinear idx val (MV_Nest _ vecs) = mvecsWritePartialLinear (Proxy @sh) idx val vecs - mvecsWritePartial :: forall sh1 sh2 s. - IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) - -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a) - -> ST s () - mvecsWritePartial sh12 idx (M_Nest _ arr) (MV_Nest sh' vecs) + mvecsWritePartialLinear + :: forall sh1 sh2 s. + Proxy sh1 -> Int -> Mixed sh2 (Mixed sh' a) + -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a) + -> ST s () + mvecsWritePartialLinear proxy idx (M_Nest _ arr) (MV_Nest _ vecs) | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') - = mvecsWritePartial (shxAppend sh12 sh') idx arr vecs + = mvecsWritePartialLinear proxy idx arr vecs mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (shxAppend sh sh') vecs @@ -697,9 +704,28 @@ instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where where sh' = mshape example + mvecsReplicate sh example = do + vecs <- mvecsUnsafeNew sh example + forM_ (shxEnum sh) $ \idx -> mvecsWrite sh idx example vecs + -- this is a slow case, but the alternative, mvecsUnsafeNew with manual + -- writing in a loop, leads to every case being as slow + return vecs + mvecsNewEmpty _ = MV_Nest (shxCompleteZeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a) +-- | Given the shape of this array, an index and a value, write the value at +-- that index in the vectors. +{-# INLINE mvecsWrite #-} +mvecsWrite :: Elt a => IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s () +mvecsWrite sh idx val vecs = mvecsWriteLinear (ixxToLinear sh idx) val vecs + +-- | Given the shape of this array, an index and a value, write the value at +-- that index in the vectors. +{-# INLINE mvecsWritePartial #-} +mvecsWritePartial :: forall sh sh' s a. Elt a => IShX sh -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () +mvecsWritePartial sh idx val vecs = mvecsWritePartialLinear (Proxy @sh) (ixxToLinear sh idx) val vecs + -- TODO: should we provide a function that's just memptyArrayUnsafe but with a size==0 check? That may save someone a transpose somewhere memptyArray :: forall sh a. KnownElt a => IShX sh -> Mixed (Just 0 : sh) a memptyArray sh = memptyArrayUnsafe (SKnown SNat :$% sh) diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs index c999853..145ea5f 100644 --- a/src/Data/Array/Nested/Mixed/Shape.hs +++ b/src/Data/Array/Nested/Mixed/Shape.hs @@ -36,7 +36,7 @@ import Data.Functor.Product import Data.Kind (Constraint, Type) import Data.Monoid (Sum(..)) import Data.Type.Equality -import GHC.Exts (Int(..), Int#, quotRemInt#, withDict, build) +import GHC.Exts (Int(..), Int#, build, quotRemInt#, withDict) import GHC.Generics (Generic) import GHC.IsList (IsList) import GHC.IsList qualified as IsList @@ -284,6 +284,7 @@ ixxZipWith :: (i -> j -> k) -> IxX sh i -> IxX sh j -> IxX sh k ixxZipWith _ ZIX ZIX = ZIX ixxZipWith f (i :.% is) (j :.% js) = f i j :.% ixxZipWith f is js +{-# INLINEABLE ixxToLinear #-} ixxToLinear :: IShX sh -> IIxX sh -> Int ixxToLinear = \sh i -> fst (go sh i) where diff --git a/src/Data/Array/Nested/Ranked/Base.hs b/src/Data/Array/Nested/Ranked/Base.hs index 11a8ffb..ed194a8 100644 --- a/src/Data/Array/Nested/Ranked/Base.hs +++ b/src/Data/Array/Nested/Ranked/Base.hs @@ -149,18 +149,19 @@ instance Elt a => Elt (Ranked n a) where marrayStrides (M_Ranked arr) = marrayStrides arr - mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s () - mvecsWrite sh idx (Ranked arr) vecs = - mvecsWrite sh idx arr + mvecsWriteLinear :: forall sh s. Int -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s () + mvecsWriteLinear idx (Ranked arr) vecs = + mvecsWriteLinear idx arr (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) vecs) - mvecsWritePartial :: forall sh sh' s. - IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a) - -> MixedVecs s (sh ++ sh') (Ranked n a) - -> ST s () - mvecsWritePartial sh idx arr vecs = - mvecsWritePartial sh idx + mvecsWritePartialLinear + :: forall sh sh' s. + Proxy sh -> Int -> Mixed sh' (Ranked n a) + -> MixedVecs s (sh ++ sh') (Ranked n a) + -> ST s () + mvecsWritePartialLinear proxy idx arr vecs = + mvecsWritePartialLinear proxy idx (coerce @(Mixed sh' (Ranked n a)) @(Mixed sh' (Mixed (Replicate n Nothing) a)) arr) @@ -188,6 +189,10 @@ instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where | Dict <- lemKnownReplicate (SNat @n) = MV_Ranked <$> mvecsUnsafeNew idx arr + mvecsReplicate idx (Ranked arr) + | Dict <- lemKnownReplicate (SNat @n) + = MV_Ranked <$> mvecsReplicate idx arr + mvecsNewEmpty _ | Dict <- lemKnownReplicate (SNat @n) = MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a)) diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs index 98f1241..e5dd852 100644 --- a/src/Data/Array/Nested/Shaped/Base.hs +++ b/src/Data/Array/Nested/Shaped/Base.hs @@ -142,18 +142,19 @@ instance Elt a => Elt (Shaped sh a) where marrayStrides (M_Shaped arr) = marrayStrides arr - mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s () - mvecsWrite sh idx (Shaped arr) vecs = - mvecsWrite sh idx arr + mvecsWriteLinear :: forall sh' s. Int -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s () + mvecsWriteLinear idx (Shaped arr) vecs = + mvecsWriteLinear idx arr (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a)) vecs) - mvecsWritePartial :: forall sh1 sh2 s. - IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) - -> MixedVecs s (sh1 ++ sh2) (Shaped sh a) - -> ST s () - mvecsWritePartial sh idx arr vecs = - mvecsWritePartial sh idx + mvecsWritePartialLinear + :: forall sh1 sh2 s. + Proxy sh1 -> Int -> Mixed sh2 (Shaped sh a) + -> MixedVecs s (sh1 ++ sh2) (Shaped sh a) + -> ST s () + mvecsWritePartialLinear proxy idx arr vecs = + mvecsWritePartialLinear proxy idx (coerce @(Mixed sh2 (Shaped sh a)) @(Mixed sh2 (Mixed (MapJust sh) a)) arr) @@ -181,6 +182,10 @@ instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where | Dict <- lemKnownMapJust (Proxy @sh) = MV_Shaped <$> mvecsUnsafeNew idx arr + mvecsReplicate idx (Shaped arr) + | Dict <- lemKnownMapJust (Proxy @sh) + = MV_Shaped <$> mvecsReplicate idx arr + mvecsNewEmpty _ | Dict <- lemKnownMapJust (Proxy @sh) = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a)) |
