From 3e74b0673caba7c04353c0cedb1d6e02de1fd007 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Sun, 14 Apr 2024 10:12:41 +0200 Subject: Move from unboxed to storable vectors Mikolaj requires this to interface with hmatrix --- src/Data/Array/Nested/Internal.hs | 43 ++++++++++++++++++++------------------- 1 file changed, 22 insertions(+), 21 deletions(-) (limited to 'src/Data/Array/Nested/Internal.hs') diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index bdded69..41fb1fd 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -35,8 +35,9 @@ import Data.Coerce (coerce, Coercible) import Data.Kind import Data.Proxy import Data.Type.Equality -import qualified Data.Vector.Unboxed as VU -import qualified Data.Vector.Unboxed.Mutable as VUM +import qualified Data.Vector.Storable as VS +import qualified Data.Vector.Storable.Mutable as VSM +import Foreign.Storable (Storable) import qualified GHC.TypeLits as GHC import Data.Array.Mixed (XArray, IxX(..), KnownShapeX(..), StaticShapeX(..), type (++), pattern GHC_SNat) @@ -119,11 +120,11 @@ deriving instance Show (Mixed (sh1 ++ sh2) a) => Show (Mixed sh1 (Mixed sh2 a)) type MixedVecs :: Type -> [Maybe GHC.Nat] -> Type -> Type data family MixedVecs s sh a -newtype instance MixedVecs s sh (Primitive a) = MV_Primitive (VU.MVector s a) +newtype instance MixedVecs s sh (Primitive a) = MV_Primitive (VS.MVector s a) -newtype instance MixedVecs s sh Int = MV_Int (VU.MVector s Int) -newtype instance MixedVecs s sh Double = MV_Double (VU.MVector s Double) -newtype instance MixedVecs s sh () = MV_Nil (VU.MVector s ()) -- no content, MVector optimises this +newtype instance MixedVecs s sh Int = MV_Int (VS.MVector s Int) +newtype instance MixedVecs s sh Double = MV_Double (VS.MVector s Double) +newtype instance MixedVecs s sh () = MV_Nil (VS.MVector s ()) -- no content, MVector optimises this -- etc. data instance MixedVecs s sh (a, b) = MV_Tup2 (MixedVecs s sh a) (MixedVecs s sh b) @@ -143,7 +144,7 @@ class Elt a where mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IxX sh -> Mixed sh' a mlift :: forall sh1 sh2. KnownShapeX sh2 - => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) + => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) -> Mixed sh1 a -> Mixed sh2 a -- ====== PRIVATE METHODS ====== -- @@ -176,7 +177,7 @@ class Elt a where -- Arrays of scalars are basically just arrays of scalars. -instance VU.Unbox a => Elt (Primitive a) where +instance Storable a => Elt (Primitive a) where mshape (M_Primitive a) = X.shape a mindex (M_Primitive a) i = Primitive (X.index a i) mindexPartial (M_Primitive a) i = M_Primitive (X.indexPartial a i) @@ -191,18 +192,18 @@ instance VU.Unbox a => Elt (Primitive a) where memptyArray sh = M_Primitive (X.generate sh (error "memptyArray Int: shape was not empty")) mvecsNumElts _ = 1 - mvecsUnsafeNew sh _ = MV_Primitive <$> VUM.unsafeNew (X.shapeSize sh) - mvecsWrite sh i (Primitive x) (MV_Primitive v) = VUM.write v (X.toLinearIdx sh i) x + mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (X.shapeSize sh) + mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (X.toLinearIdx sh i) x -- TODO: this use of toVector is suboptimal mvecsWritePartial - :: forall sh' sh s. (KnownShapeX sh', VU.Unbox a) + :: forall sh' sh s. KnownShapeX sh' => IxX (sh ++ sh') -> IxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () mvecsWritePartial sh i (M_Primitive arr) (MV_Primitive v) = do let offset = X.toLinearIdx sh (X.ixAppend i (X.zeroIdx' (X.shape arr))) - VU.copy (VUM.slice offset (X.shapeSize (X.shape arr)) v) (X.toVector arr) + VS.copy (VSM.slice offset (X.shapeSize (X.shape arr)) v) (X.toVector arr) - mvecsFreeze sh (MV_Primitive v) = M_Primitive . X.fromVector sh <$> VU.freeze v + mvecsFreeze sh (MV_Primitive v) = M_Primitive . X.fromVector sh <$> VS.freeze v deriving via Primitive Int instance Elt Int deriving via Primitive Double instance Elt Double @@ -247,13 +248,13 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where = M_Nest (mindexPartial @a @sh1 @(sh2 ++ sh') arr i) mlift :: forall sh1 sh2. KnownShapeX sh2 - => (forall sh3 b. (KnownShapeX sh3, VU.Unbox b) => Proxy sh3 -> XArray (sh1 ++ sh3) b -> XArray (sh2 ++ sh3) b) + => (forall sh3 b. (KnownShapeX sh3, Storable b) => Proxy sh3 -> XArray (sh1 ++ sh3) b -> XArray (sh2 ++ sh3) b) -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) mlift f (M_Nest arr) | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh')) = M_Nest (mlift f' arr) where - f' :: forall sh3 b. (KnownShapeX sh3, VU.Unbox b) => Proxy sh3 -> XArray ((sh1 ++ sh') ++ sh3) b -> XArray ((sh2 ++ sh') ++ sh3) b + f' :: forall sh3 b. (KnownShapeX sh3, Storable b) => Proxy sh3 -> XArray ((sh1 ++ sh') ++ sh3) b -> XArray ((sh2 ++ sh') ++ sh3) b f' _ | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @sh3) , Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @sh3) @@ -374,7 +375,7 @@ instance (KnownNat n, Elt a) => Elt (Ranked n a) where mindexPartial arr i mlift :: forall sh1 sh2. KnownShapeX sh2 - => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) + => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) mlift f (M_Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) @@ -469,7 +470,7 @@ instance (KnownShape sh, Elt a) => Elt (Shaped sh a) where mindexPartial arr i mlift :: forall sh1 sh2. KnownShapeX sh2 - => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) + => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) mlift f (M_Shaped arr) | Dict <- lemKnownMapJust (Proxy @sh) @@ -568,14 +569,14 @@ rgenerate sh f = Ranked (mgenerate (ixCvtRX sh) (f . ixCvtXR)) rlift :: forall n1 n2 a. (KnownNat n2, Elt a) - => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b) + => (forall sh' b. KnownShapeX sh' => Proxy sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b) -> Ranked n1 a -> Ranked n2 a rlift f (Ranked arr) | Dict <- lemKnownReplicate (Proxy @n2) = Ranked (mlift f arr) rsumOuter1 :: forall n a. - (VU.Unbox a, Num a, KnownNat n, forall sh. Coercible (Mixed sh a) (XArray sh a)) + (Storable a, Num a, KnownNat n, forall sh. Coercible (Mixed sh a) (XArray sh a)) => Ranked (S n) a -> Ranked n a rsumOuter1 (Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) @@ -636,14 +637,14 @@ sgenerate sh f = Shaped (mgenerate (ixCvtSX sh) (f . ixCvtXS (knownShape @sh))) slift :: forall sh1 sh2 a. (KnownShape sh2, Elt a) - => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b) + => (forall sh' b. KnownShapeX sh' => Proxy sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b) -> Shaped sh1 a -> Shaped sh2 a slift f (Shaped arr) | Dict <- lemKnownMapJust (Proxy @sh2) = Shaped (mlift f arr) ssumOuter1 :: forall sh n a. - (VU.Unbox a, Num a, GHC.KnownNat n, KnownShape sh, forall sh'. Coercible (Mixed sh' a) (XArray sh' a)) + (Storable a, Num a, GHC.KnownNat n, KnownShape sh, forall sh'. Coercible (Mixed sh' a) (XArray sh' a)) => Shaped (n : sh) a -> Shaped sh a ssumOuter1 (Shaped arr) | Dict <- lemKnownMapJust (Proxy @sh) -- cgit v1.2.3-70-g09d2