summaryrefslogtreecommitdiff
path: root/src/Data/Array/Mixed.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Mixed.hs')
-rw-r--r--src/Data/Array/Mixed.hs31
1 files changed, 16 insertions, 15 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
index 049a0c4..040b8d7 100644
--- a/src/Data/Array/Mixed.hs
+++ b/src/Data/Array/Mixed.hs
@@ -17,6 +17,7 @@ import Data.Kind
import Data.Proxy
import Data.Type.Equality
import qualified Data.Vector.Storable as VS
+import Foreign.Storable (Storable)
import GHC.TypeLits
import Unsafe.Coerce (unsafeCoerce)
@@ -205,48 +206,48 @@ shape (XArray arr) = go (knownShapeX @sh) (S.shapeL arr)
go (() :$? ssh) (n : l) = n ::? go ssh l
go _ _ = error "Invalid shapeL"
-fromVector :: forall sh a. S.Unbox a => IxX sh -> VS.Vector a -> XArray sh a
+fromVector :: forall sh a. Storable a => IxX sh -> VS.Vector a -> XArray sh a
fromVector sh v
| Dict <- lemKnownNatRank sh
, Dict <- knownNatFromINat (Proxy @(Rank sh))
= XArray (S.fromVector (shapeLshape sh) v)
-toVector :: S.Unbox a => XArray sh a -> VS.Vector a
+toVector :: Storable a => XArray sh a -> VS.Vector a
toVector (XArray arr) = S.toVector arr
-scalar :: S.Unbox a => a -> XArray '[] a
+scalar :: Storable a => a -> XArray '[] a
scalar = XArray . S.scalar
-unScalar :: S.Unbox a => XArray '[] a -> a
+unScalar :: Storable a => XArray '[] a -> a
unScalar (XArray a) = S.unScalar a
-generate :: S.Unbox a => IxX sh -> (IxX sh -> a) -> XArray sh a
+generate :: Storable a => IxX sh -> (IxX sh -> a) -> XArray sh a
generate sh f = fromVector sh $ VS.generate (shapeSize sh) (f . fromLinearIdx sh)
--- generateM :: (Monad m, S.Unbox a) => IxX sh -> (IxX sh -> m a) -> m (XArray sh a)
+-- generateM :: (Monad m, Storable a) => IxX sh -> (IxX sh -> m a) -> m (XArray sh a)
-- generateM sh f | Dict <- lemKnownNatRank sh =
-- XArray . S.fromVector (shapeLshape sh)
-- <$> VS.generateM (shapeSize sh) (f . fromLinearIdx sh)
-indexPartial :: S.Unbox a => XArray (sh ++ sh') a -> IxX sh -> XArray sh' a
+indexPartial :: Storable a => XArray (sh ++ sh') a -> IxX sh -> XArray sh' a
indexPartial (XArray arr) IZX = XArray arr
indexPartial (XArray arr) (i ::@ idx) = indexPartial (XArray (S.index arr i)) idx
indexPartial (XArray arr) (i ::? idx) = indexPartial (XArray (S.index arr i)) idx
-index :: forall sh a. S.Unbox a => XArray sh a -> IxX sh -> a
+index :: forall sh a. Storable a => XArray sh a -> IxX sh -> a
index xarr i
| Refl <- lemAppNil @sh
= let XArray arr' = indexPartial xarr i :: XArray '[] a
in S.unScalar arr'
-append :: forall sh a. (KnownShapeX sh, S.Unbox a) => XArray sh a -> XArray sh a -> XArray sh a
+append :: forall sh a. (KnownShapeX sh, Storable a) => XArray sh a -> XArray sh a -> XArray sh a
append (XArray a) (XArray b)
| Dict <- lemKnownNatRankSSX (knownShapeX @sh)
, Dict <- knownNatFromINat (Proxy @(Rank sh))
= XArray (S.append a b)
rerank :: forall sh sh1 sh2 a b.
- (S.Unbox a, S.Unbox b)
+ (Storable a, Storable b)
=> StaticShapeX sh -> StaticShapeX sh1 -> StaticShapeX sh2
-> (XArray sh1 a -> XArray sh2 b)
-> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b
@@ -266,14 +267,14 @@ rerank ssh ssh1 ssh2 f (XArray arr)
unXArray (XArray a) = a
rerankTop :: forall sh sh1 sh2 a b.
- (S.Unbox a, S.Unbox b)
+ (Storable a, Storable b)
=> StaticShapeX sh1 -> StaticShapeX sh2 -> StaticShapeX sh
-> (XArray sh1 a -> XArray sh2 b)
-> XArray (sh1 ++ sh) a -> XArray (sh2 ++ sh) b
rerankTop ssh1 ssh2 ssh f = transpose2 ssh ssh2 . rerank ssh ssh1 ssh2 f . transpose2 ssh1 ssh
rerank2 :: forall sh sh1 sh2 a b c.
- (S.Unbox a, S.Unbox b, S.Unbox c)
+ (Storable a, Storable b, Storable c)
=> StaticShapeX sh -> StaticShapeX sh1 -> StaticShapeX sh2
-> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c)
-> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c
@@ -313,16 +314,16 @@ transpose2 ssh1 ssh2 (XArray arr)
, let n1 = ssxLength ssh1
= XArray (S.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr)
-sumFull :: (S.Unbox a, Num a) => XArray sh a -> a
+sumFull :: (Storable a, Num a) => XArray sh a -> a
sumFull (XArray arr) = S.sumA arr
-sumInner :: forall sh sh' a. (S.Unbox a, Num a)
+sumInner :: forall sh sh' a. (Storable a, Num a)
=> StaticShapeX sh -> StaticShapeX sh' -> XArray (sh ++ sh') a -> XArray sh a
sumInner ssh ssh'
| Refl <- lemAppNil @sh
= rerank ssh ssh' SZX (scalar . sumFull)
-sumOuter :: forall sh sh' a. (S.Unbox a, Num a)
+sumOuter :: forall sh sh' a. (Storable a, Num a)
=> StaticShapeX sh -> StaticShapeX sh' -> XArray (sh ++ sh') a -> XArray sh' a
sumOuter ssh ssh'
| Refl <- lemAppNil @sh