summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-04-14 10:12:41 +0200
committerTom Smeding <tom@tomsmeding.com>2024-04-14 10:12:41 +0200
commit3e74b0673caba7c04353c0cedb1d6e02de1fd007 (patch)
tree36e700e86199dd07b54cd3c28e8d09dc77f32c42
parent8a81f7ea9eed9afaec948910caaf0a5c498de6c6 (diff)
Move from unboxed to storable vectors
Mikolaj requires this to interface with hmatrix
-rw-r--r--src/Data/Array/Mixed.hs72
-rw-r--r--src/Data/Array/Nested/Internal.hs43
2 files changed, 58 insertions, 57 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
index 2875203..2bbf81d 100644
--- a/src/Data/Array/Mixed.hs
+++ b/src/Data/Array/Mixed.hs
@@ -12,11 +12,11 @@
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Mixed where
-import qualified Data.Array.RankedU as U
+import qualified Data.Array.RankedS as S
import Data.Kind
import Data.Proxy
import Data.Type.Equality
-import qualified Data.Vector.Unboxed as VU
+import qualified Data.Vector.Storable as VS
import qualified GHC.TypeLits as GHC
import Unsafe.Coerce (unsafeCoerce)
@@ -77,7 +77,7 @@ type family Rank sh where
Rank (_ : sh) = S (Rank sh)
type XArray :: [Maybe GHC.Nat] -> Type -> Type
-data XArray sh a = XArray (U.Array (GNat (Rank sh)) a)
+data XArray sh a = XArray (S.Array (GNat (Rank sh)) a)
deriving (Show)
zeroIdx :: StaticShapeX sh -> IxX sh
@@ -149,7 +149,7 @@ enumShape = \sh -> go sh id []
go (n ::@ sh) f = foldr (.) id [go sh (f . (i ::@)) | i <- [0 .. n-1]]
go (n ::? sh) f = foldr (.) id [go sh (f . (i ::?)) | i <- [0 .. n-1]]
-shapeLshape :: IxX sh -> U.ShapeL
+shapeLshape :: IxX sh -> S.ShapeL
shapeLshape IZX = []
shapeLshape (n ::@ sh) = n : shapeLshape sh
shapeLshape (n ::? sh) = n : shapeLshape sh
@@ -197,7 +197,7 @@ lemAppKnownShapeX (() :$? ssh) ssh'
= Dict
shape :: forall sh a. KnownShapeX sh => XArray sh a -> IxX sh
-shape (XArray arr) = go (knownShapeX @sh) (U.shapeL arr)
+shape (XArray arr) = go (knownShapeX @sh) (S.shapeL arr)
where
go :: StaticShapeX sh' -> [Int] -> IxX sh'
go SZX [] = IZX
@@ -205,48 +205,48 @@ shape (XArray arr) = go (knownShapeX @sh) (U.shapeL arr)
go (() :$? ssh) (n : l) = n ::? go ssh l
go _ _ = error "Invalid shapeL"
-fromVector :: forall sh a. U.Unbox a => IxX sh -> VU.Vector a -> XArray sh a
+fromVector :: forall sh a. S.Unbox a => IxX sh -> VS.Vector a -> XArray sh a
fromVector sh v
| Dict <- lemKnownNatRank sh
, Dict <- gknownNat (Proxy @(Rank sh))
- = XArray (U.fromVector (shapeLshape sh) v)
+ = XArray (S.fromVector (shapeLshape sh) v)
-toVector :: U.Unbox a => XArray sh a -> VU.Vector a
-toVector (XArray arr) = U.toVector arr
+toVector :: S.Unbox a => XArray sh a -> VS.Vector a
+toVector (XArray arr) = S.toVector arr
-scalar :: U.Unbox a => a -> XArray '[] a
-scalar = XArray . U.scalar
+scalar :: S.Unbox a => a -> XArray '[] a
+scalar = XArray . S.scalar
-unScalar :: U.Unbox a => XArray '[] a -> a
-unScalar (XArray a) = U.unScalar a
+unScalar :: S.Unbox a => XArray '[] a -> a
+unScalar (XArray a) = S.unScalar a
-generate :: U.Unbox a => IxX sh -> (IxX sh -> a) -> XArray sh a
-generate sh f = fromVector sh $ VU.generate (shapeSize sh) (f . fromLinearIdx sh)
+generate :: S.Unbox a => IxX sh -> (IxX sh -> a) -> XArray sh a
+generate sh f = fromVector sh $ VS.generate (shapeSize sh) (f . fromLinearIdx sh)
--- generateM :: (Monad m, U.Unbox a) => IxX sh -> (IxX sh -> m a) -> m (XArray sh a)
+-- generateM :: (Monad m, S.Unbox a) => IxX sh -> (IxX sh -> m a) -> m (XArray sh a)
-- generateM sh f | Dict <- lemKnownNatRank sh =
--- XArray . U.fromVector (shapeLshape sh)
--- <$> VU.generateM (shapeSize sh) (f . fromLinearIdx sh)
+-- XArray . S.fromVector (shapeLshape sh)
+-- <$> VS.generateM (shapeSize sh) (f . fromLinearIdx sh)
-indexPartial :: U.Unbox a => XArray (sh ++ sh') a -> IxX sh -> XArray sh' a
+indexPartial :: S.Unbox a => XArray (sh ++ sh') a -> IxX sh -> XArray sh' a
indexPartial (XArray arr) IZX = XArray arr
-indexPartial (XArray arr) (i ::@ idx) = indexPartial (XArray (U.index arr i)) idx
-indexPartial (XArray arr) (i ::? idx) = indexPartial (XArray (U.index arr i)) idx
+indexPartial (XArray arr) (i ::@ idx) = indexPartial (XArray (S.index arr i)) idx
+indexPartial (XArray arr) (i ::? idx) = indexPartial (XArray (S.index arr i)) idx
-index :: forall sh a. U.Unbox a => XArray sh a -> IxX sh -> a
+index :: forall sh a. S.Unbox a => XArray sh a -> IxX sh -> a
index xarr i
| Refl <- lemAppNil @sh
= let XArray arr' = indexPartial xarr i :: XArray '[] a
- in U.unScalar arr'
+ in S.unScalar arr'
-append :: forall sh a. (KnownShapeX sh, U.Unbox a) => XArray sh a -> XArray sh a -> XArray sh a
+append :: forall sh a. (KnownShapeX sh, S.Unbox a) => XArray sh a -> XArray sh a -> XArray sh a
append (XArray a) (XArray b)
| Dict <- lemKnownNatRankSSX (knownShapeX @sh)
, Dict <- gknownNat (Proxy @(Rank sh))
- = XArray (U.append a b)
+ = XArray (S.append a b)
rerank :: forall sh sh1 sh2 a b.
- (U.Unbox a, U.Unbox b)
+ (S.Unbox a, S.Unbox b)
=> StaticShapeX sh -> StaticShapeX sh1 -> StaticShapeX sh2
-> (XArray sh1 a -> XArray sh2 b)
-> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b
@@ -259,21 +259,21 @@ rerank ssh ssh1 ssh2 f (XArray arr)
, Refl <- lemRankApp ssh ssh2
, Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the
, Dict <- gknownNat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough
- = XArray (U.rerank @(GNat (Rank sh)) @(GNat (Rank sh1)) @(GNat (Rank sh2))
+ = XArray (S.rerank @(GNat (Rank sh)) @(GNat (Rank sh1)) @(GNat (Rank sh2))
(\a -> unXArray (f (XArray a)))
arr)
where
unXArray (XArray a) = a
rerankTop :: forall sh sh1 sh2 a b.
- (U.Unbox a, U.Unbox b)
+ (S.Unbox a, S.Unbox b)
=> StaticShapeX sh1 -> StaticShapeX sh2 -> StaticShapeX sh
-> (XArray sh1 a -> XArray sh2 b)
-> XArray (sh1 ++ sh) a -> XArray (sh2 ++ sh) b
rerankTop ssh1 ssh2 ssh f = transpose2 ssh ssh2 . rerank ssh ssh1 ssh2 f . transpose2 ssh1 ssh
rerank2 :: forall sh sh1 sh2 a b c.
- (U.Unbox a, U.Unbox b, U.Unbox c)
+ (S.Unbox a, S.Unbox b, S.Unbox c)
=> StaticShapeX sh -> StaticShapeX sh1 -> StaticShapeX sh2
-> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c)
-> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c
@@ -286,7 +286,7 @@ rerank2 ssh ssh1 ssh2 f (XArray arr1) (XArray arr2)
, Refl <- lemRankApp ssh ssh2
, Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the
, Dict <- gknownNat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough
- = XArray (U.rerank2 @(GNat (Rank sh)) @(GNat (Rank sh1)) @(GNat (Rank sh2))
+ = XArray (S.rerank2 @(GNat (Rank sh)) @(GNat (Rank sh1)) @(GNat (Rank sh2))
(\a b -> unXArray (f (XArray a) (XArray b)))
arr1 arr2)
where
@@ -297,7 +297,7 @@ transpose :: forall sh a. KnownShapeX sh => [Int] -> XArray sh a -> XArray sh a
transpose perm (XArray arr)
| Dict <- lemKnownNatRankSSX (knownShapeX @sh)
, Dict <- gknownNat (Proxy @(Rank sh))
- = XArray (U.transpose perm arr)
+ = XArray (S.transpose perm arr)
transpose2 :: forall sh1 sh2 a.
StaticShapeX sh1 -> StaticShapeX sh2
@@ -311,18 +311,18 @@ transpose2 ssh1 ssh2 (XArray arr)
, Dict <- gknownNat (Proxy @(Rank (sh2 ++ sh1)))
, Refl <- lemRankAppComm ssh1 ssh2
, let n1 = ssxLength ssh1
- = XArray (U.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr)
+ = XArray (S.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr)
-sumFull :: (U.Unbox a, Num a) => XArray sh a -> a
-sumFull (XArray arr) = U.sumA arr
+sumFull :: (S.Unbox a, Num a) => XArray sh a -> a
+sumFull (XArray arr) = S.sumA arr
-sumInner :: forall sh sh' a. (U.Unbox a, Num a)
+sumInner :: forall sh sh' a. (S.Unbox a, Num a)
=> StaticShapeX sh -> StaticShapeX sh' -> XArray (sh ++ sh') a -> XArray sh a
sumInner ssh ssh'
| Refl <- lemAppNil @sh
= rerank ssh ssh' SZX (scalar . sumFull)
-sumOuter :: forall sh sh' a. (U.Unbox a, Num a)
+sumOuter :: forall sh sh' a. (S.Unbox a, Num a)
=> StaticShapeX sh -> StaticShapeX sh' -> XArray (sh ++ sh') a -> XArray sh' a
sumOuter ssh ssh'
| Refl <- lemAppNil @sh
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index bdded69..41fb1fd 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -35,8 +35,9 @@ import Data.Coerce (coerce, Coercible)
import Data.Kind
import Data.Proxy
import Data.Type.Equality
-import qualified Data.Vector.Unboxed as VU
-import qualified Data.Vector.Unboxed.Mutable as VUM
+import qualified Data.Vector.Storable as VS
+import qualified Data.Vector.Storable.Mutable as VSM
+import Foreign.Storable (Storable)
import qualified GHC.TypeLits as GHC
import Data.Array.Mixed (XArray, IxX(..), KnownShapeX(..), StaticShapeX(..), type (++), pattern GHC_SNat)
@@ -119,11 +120,11 @@ deriving instance Show (Mixed (sh1 ++ sh2) a) => Show (Mixed sh1 (Mixed sh2 a))
type MixedVecs :: Type -> [Maybe GHC.Nat] -> Type -> Type
data family MixedVecs s sh a
-newtype instance MixedVecs s sh (Primitive a) = MV_Primitive (VU.MVector s a)
+newtype instance MixedVecs s sh (Primitive a) = MV_Primitive (VS.MVector s a)
-newtype instance MixedVecs s sh Int = MV_Int (VU.MVector s Int)
-newtype instance MixedVecs s sh Double = MV_Double (VU.MVector s Double)
-newtype instance MixedVecs s sh () = MV_Nil (VU.MVector s ()) -- no content, MVector optimises this
+newtype instance MixedVecs s sh Int = MV_Int (VS.MVector s Int)
+newtype instance MixedVecs s sh Double = MV_Double (VS.MVector s Double)
+newtype instance MixedVecs s sh () = MV_Nil (VS.MVector s ()) -- no content, MVector optimises this
-- etc.
data instance MixedVecs s sh (a, b) = MV_Tup2 (MixedVecs s sh a) (MixedVecs s sh b)
@@ -143,7 +144,7 @@ class Elt a where
mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IxX sh -> Mixed sh' a
mlift :: forall sh1 sh2. KnownShapeX sh2
- => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
+ => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
-> Mixed sh1 a -> Mixed sh2 a
-- ====== PRIVATE METHODS ====== --
@@ -176,7 +177,7 @@ class Elt a where
-- Arrays of scalars are basically just arrays of scalars.
-instance VU.Unbox a => Elt (Primitive a) where
+instance Storable a => Elt (Primitive a) where
mshape (M_Primitive a) = X.shape a
mindex (M_Primitive a) i = Primitive (X.index a i)
mindexPartial (M_Primitive a) i = M_Primitive (X.indexPartial a i)
@@ -191,18 +192,18 @@ instance VU.Unbox a => Elt (Primitive a) where
memptyArray sh = M_Primitive (X.generate sh (error "memptyArray Int: shape was not empty"))
mvecsNumElts _ = 1
- mvecsUnsafeNew sh _ = MV_Primitive <$> VUM.unsafeNew (X.shapeSize sh)
- mvecsWrite sh i (Primitive x) (MV_Primitive v) = VUM.write v (X.toLinearIdx sh i) x
+ mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (X.shapeSize sh)
+ mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (X.toLinearIdx sh i) x
-- TODO: this use of toVector is suboptimal
mvecsWritePartial
- :: forall sh' sh s. (KnownShapeX sh', VU.Unbox a)
+ :: forall sh' sh s. KnownShapeX sh'
=> IxX (sh ++ sh') -> IxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s ()
mvecsWritePartial sh i (M_Primitive arr) (MV_Primitive v) = do
let offset = X.toLinearIdx sh (X.ixAppend i (X.zeroIdx' (X.shape arr)))
- VU.copy (VUM.slice offset (X.shapeSize (X.shape arr)) v) (X.toVector arr)
+ VS.copy (VSM.slice offset (X.shapeSize (X.shape arr)) v) (X.toVector arr)
- mvecsFreeze sh (MV_Primitive v) = M_Primitive . X.fromVector sh <$> VU.freeze v
+ mvecsFreeze sh (MV_Primitive v) = M_Primitive . X.fromVector sh <$> VS.freeze v
deriving via Primitive Int instance Elt Int
deriving via Primitive Double instance Elt Double
@@ -247,13 +248,13 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where
= M_Nest (mindexPartial @a @sh1 @(sh2 ++ sh') arr i)
mlift :: forall sh1 sh2. KnownShapeX sh2
- => (forall sh3 b. (KnownShapeX sh3, VU.Unbox b) => Proxy sh3 -> XArray (sh1 ++ sh3) b -> XArray (sh2 ++ sh3) b)
+ => (forall sh3 b. (KnownShapeX sh3, Storable b) => Proxy sh3 -> XArray (sh1 ++ sh3) b -> XArray (sh2 ++ sh3) b)
-> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a)
mlift f (M_Nest arr)
| Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh'))
= M_Nest (mlift f' arr)
where
- f' :: forall sh3 b. (KnownShapeX sh3, VU.Unbox b) => Proxy sh3 -> XArray ((sh1 ++ sh') ++ sh3) b -> XArray ((sh2 ++ sh') ++ sh3) b
+ f' :: forall sh3 b. (KnownShapeX sh3, Storable b) => Proxy sh3 -> XArray ((sh1 ++ sh') ++ sh3) b -> XArray ((sh2 ++ sh') ++ sh3) b
f' _
| Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @sh3)
, Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @sh3)
@@ -374,7 +375,7 @@ instance (KnownNat n, Elt a) => Elt (Ranked n a) where
mindexPartial arr i
mlift :: forall sh1 sh2. KnownShapeX sh2
- => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
+ => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
-> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a)
mlift f (M_Ranked arr)
| Dict <- lemKnownReplicate (Proxy @n)
@@ -469,7 +470,7 @@ instance (KnownShape sh, Elt a) => Elt (Shaped sh a) where
mindexPartial arr i
mlift :: forall sh1 sh2. KnownShapeX sh2
- => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
+ => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
-> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a)
mlift f (M_Shaped arr)
| Dict <- lemKnownMapJust (Proxy @sh)
@@ -568,14 +569,14 @@ rgenerate sh f
= Ranked (mgenerate (ixCvtRX sh) (f . ixCvtXR))
rlift :: forall n1 n2 a. (KnownNat n2, Elt a)
- => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b)
+ => (forall sh' b. KnownShapeX sh' => Proxy sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b)
-> Ranked n1 a -> Ranked n2 a
rlift f (Ranked arr)
| Dict <- lemKnownReplicate (Proxy @n2)
= Ranked (mlift f arr)
rsumOuter1 :: forall n a.
- (VU.Unbox a, Num a, KnownNat n, forall sh. Coercible (Mixed sh a) (XArray sh a))
+ (Storable a, Num a, KnownNat n, forall sh. Coercible (Mixed sh a) (XArray sh a))
=> Ranked (S n) a -> Ranked n a
rsumOuter1 (Ranked arr)
| Dict <- lemKnownReplicate (Proxy @n)
@@ -636,14 +637,14 @@ sgenerate sh f
= Shaped (mgenerate (ixCvtSX sh) (f . ixCvtXS (knownShape @sh)))
slift :: forall sh1 sh2 a. (KnownShape sh2, Elt a)
- => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b)
+ => (forall sh' b. KnownShapeX sh' => Proxy sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b)
-> Shaped sh1 a -> Shaped sh2 a
slift f (Shaped arr)
| Dict <- lemKnownMapJust (Proxy @sh2)
= Shaped (mlift f arr)
ssumOuter1 :: forall sh n a.
- (VU.Unbox a, Num a, GHC.KnownNat n, KnownShape sh, forall sh'. Coercible (Mixed sh' a) (XArray sh' a))
+ (Storable a, Num a, GHC.KnownNat n, KnownShape sh, forall sh'. Coercible (Mixed sh' a) (XArray sh' a))
=> Shaped (n : sh) a -> Shaped sh a
ssumOuter1 (Shaped arr)
| Dict <- lemKnownMapJust (Proxy @sh)