From cde40eeb9560919fa464f14c76edc1aae1dac43b Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Sun, 14 Apr 2024 12:27:09 +0200 Subject: Fix constraints in Data.Array.Mixed These were still Unbox from before the transition to orthotope's Storable API --- src/Data/Array/Mixed.hs | 31 ++++++++++++++++--------------- src/Data/Array/Nested.hs | 3 --- 2 files changed, 16 insertions(+), 18 deletions(-) (limited to 'src/Data/Array') diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 049a0c4..040b8d7 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -17,6 +17,7 @@ import Data.Kind import Data.Proxy import Data.Type.Equality import qualified Data.Vector.Storable as VS +import Foreign.Storable (Storable) import GHC.TypeLits import Unsafe.Coerce (unsafeCoerce) @@ -205,48 +206,48 @@ shape (XArray arr) = go (knownShapeX @sh) (S.shapeL arr) go (() :$? ssh) (n : l) = n ::? go ssh l go _ _ = error "Invalid shapeL" -fromVector :: forall sh a. S.Unbox a => IxX sh -> VS.Vector a -> XArray sh a +fromVector :: forall sh a. Storable a => IxX sh -> VS.Vector a -> XArray sh a fromVector sh v | Dict <- lemKnownNatRank sh , Dict <- knownNatFromINat (Proxy @(Rank sh)) = XArray (S.fromVector (shapeLshape sh) v) -toVector :: S.Unbox a => XArray sh a -> VS.Vector a +toVector :: Storable a => XArray sh a -> VS.Vector a toVector (XArray arr) = S.toVector arr -scalar :: S.Unbox a => a -> XArray '[] a +scalar :: Storable a => a -> XArray '[] a scalar = XArray . S.scalar -unScalar :: S.Unbox a => XArray '[] a -> a +unScalar :: Storable a => XArray '[] a -> a unScalar (XArray a) = S.unScalar a -generate :: S.Unbox a => IxX sh -> (IxX sh -> a) -> XArray sh a +generate :: Storable a => IxX sh -> (IxX sh -> a) -> XArray sh a generate sh f = fromVector sh $ VS.generate (shapeSize sh) (f . fromLinearIdx sh) --- generateM :: (Monad m, S.Unbox a) => IxX sh -> (IxX sh -> m a) -> m (XArray sh a) +-- generateM :: (Monad m, Storable a) => IxX sh -> (IxX sh -> m a) -> m (XArray sh a) -- generateM sh f | Dict <- lemKnownNatRank sh = -- XArray . S.fromVector (shapeLshape sh) -- <$> VS.generateM (shapeSize sh) (f . fromLinearIdx sh) -indexPartial :: S.Unbox a => XArray (sh ++ sh') a -> IxX sh -> XArray sh' a +indexPartial :: Storable a => XArray (sh ++ sh') a -> IxX sh -> XArray sh' a indexPartial (XArray arr) IZX = XArray arr indexPartial (XArray arr) (i ::@ idx) = indexPartial (XArray (S.index arr i)) idx indexPartial (XArray arr) (i ::? idx) = indexPartial (XArray (S.index arr i)) idx -index :: forall sh a. S.Unbox a => XArray sh a -> IxX sh -> a +index :: forall sh a. Storable a => XArray sh a -> IxX sh -> a index xarr i | Refl <- lemAppNil @sh = let XArray arr' = indexPartial xarr i :: XArray '[] a in S.unScalar arr' -append :: forall sh a. (KnownShapeX sh, S.Unbox a) => XArray sh a -> XArray sh a -> XArray sh a +append :: forall sh a. (KnownShapeX sh, Storable a) => XArray sh a -> XArray sh a -> XArray sh a append (XArray a) (XArray b) | Dict <- lemKnownNatRankSSX (knownShapeX @sh) , Dict <- knownNatFromINat (Proxy @(Rank sh)) = XArray (S.append a b) rerank :: forall sh sh1 sh2 a b. - (S.Unbox a, S.Unbox b) + (Storable a, Storable b) => StaticShapeX sh -> StaticShapeX sh1 -> StaticShapeX sh2 -> (XArray sh1 a -> XArray sh2 b) -> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b @@ -266,14 +267,14 @@ rerank ssh ssh1 ssh2 f (XArray arr) unXArray (XArray a) = a rerankTop :: forall sh sh1 sh2 a b. - (S.Unbox a, S.Unbox b) + (Storable a, Storable b) => StaticShapeX sh1 -> StaticShapeX sh2 -> StaticShapeX sh -> (XArray sh1 a -> XArray sh2 b) -> XArray (sh1 ++ sh) a -> XArray (sh2 ++ sh) b rerankTop ssh1 ssh2 ssh f = transpose2 ssh ssh2 . rerank ssh ssh1 ssh2 f . transpose2 ssh1 ssh rerank2 :: forall sh sh1 sh2 a b c. - (S.Unbox a, S.Unbox b, S.Unbox c) + (Storable a, Storable b, Storable c) => StaticShapeX sh -> StaticShapeX sh1 -> StaticShapeX sh2 -> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c) -> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c @@ -313,16 +314,16 @@ transpose2 ssh1 ssh2 (XArray arr) , let n1 = ssxLength ssh1 = XArray (S.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr) -sumFull :: (S.Unbox a, Num a) => XArray sh a -> a +sumFull :: (Storable a, Num a) => XArray sh a -> a sumFull (XArray arr) = S.sumA arr -sumInner :: forall sh sh' a. (S.Unbox a, Num a) +sumInner :: forall sh sh' a. (Storable a, Num a) => StaticShapeX sh -> StaticShapeX sh' -> XArray (sh ++ sh') a -> XArray sh a sumInner ssh ssh' | Refl <- lemAppNil @sh = rerank ssh ssh' SZX (scalar . sumFull) -sumOuter :: forall sh sh' a. (S.Unbox a, Num a) +sumOuter :: forall sh sh' a. (Storable a, Num a) => StaticShapeX sh -> StaticShapeX sh' -> XArray (sh ++ sh') a -> XArray sh' a sumOuter ssh ssh' | Refl <- lemAppNil @sh diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index 0de3884..9feda61 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -32,11 +32,8 @@ module Data.Array.Nested ( -- * Further utilities / re-exports type (++), - VU.Unbox, ) where -import qualified Data.Vector.Unboxed as VU - import Data.Array.Mixed import Data.Array.Nested.Internal import Data.INat -- cgit v1.2.3-70-g09d2