aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Mixed.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-04-14 10:12:41 +0200
committerTom Smeding <tom@tomsmeding.com>2024-04-14 10:12:41 +0200
commit3e74b0673caba7c04353c0cedb1d6e02de1fd007 (patch)
tree36e700e86199dd07b54cd3c28e8d09dc77f32c42 /src/Data/Array/Mixed.hs
parent8a81f7ea9eed9afaec948910caaf0a5c498de6c6 (diff)
Move from unboxed to storable vectors
Mikolaj requires this to interface with hmatrix
Diffstat (limited to 'src/Data/Array/Mixed.hs')
-rw-r--r--src/Data/Array/Mixed.hs72
1 files changed, 36 insertions, 36 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
index 2875203..2bbf81d 100644
--- a/src/Data/Array/Mixed.hs
+++ b/src/Data/Array/Mixed.hs
@@ -12,11 +12,11 @@
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Mixed where
-import qualified Data.Array.RankedU as U
+import qualified Data.Array.RankedS as S
import Data.Kind
import Data.Proxy
import Data.Type.Equality
-import qualified Data.Vector.Unboxed as VU
+import qualified Data.Vector.Storable as VS
import qualified GHC.TypeLits as GHC
import Unsafe.Coerce (unsafeCoerce)
@@ -77,7 +77,7 @@ type family Rank sh where
Rank (_ : sh) = S (Rank sh)
type XArray :: [Maybe GHC.Nat] -> Type -> Type
-data XArray sh a = XArray (U.Array (GNat (Rank sh)) a)
+data XArray sh a = XArray (S.Array (GNat (Rank sh)) a)
deriving (Show)
zeroIdx :: StaticShapeX sh -> IxX sh
@@ -149,7 +149,7 @@ enumShape = \sh -> go sh id []
go (n ::@ sh) f = foldr (.) id [go sh (f . (i ::@)) | i <- [0 .. n-1]]
go (n ::? sh) f = foldr (.) id [go sh (f . (i ::?)) | i <- [0 .. n-1]]
-shapeLshape :: IxX sh -> U.ShapeL
+shapeLshape :: IxX sh -> S.ShapeL
shapeLshape IZX = []
shapeLshape (n ::@ sh) = n : shapeLshape sh
shapeLshape (n ::? sh) = n : shapeLshape sh
@@ -197,7 +197,7 @@ lemAppKnownShapeX (() :$? ssh) ssh'
= Dict
shape :: forall sh a. KnownShapeX sh => XArray sh a -> IxX sh
-shape (XArray arr) = go (knownShapeX @sh) (U.shapeL arr)
+shape (XArray arr) = go (knownShapeX @sh) (S.shapeL arr)
where
go :: StaticShapeX sh' -> [Int] -> IxX sh'
go SZX [] = IZX
@@ -205,48 +205,48 @@ shape (XArray arr) = go (knownShapeX @sh) (U.shapeL arr)
go (() :$? ssh) (n : l) = n ::? go ssh l
go _ _ = error "Invalid shapeL"
-fromVector :: forall sh a. U.Unbox a => IxX sh -> VU.Vector a -> XArray sh a
+fromVector :: forall sh a. S.Unbox a => IxX sh -> VS.Vector a -> XArray sh a
fromVector sh v
| Dict <- lemKnownNatRank sh
, Dict <- gknownNat (Proxy @(Rank sh))
- = XArray (U.fromVector (shapeLshape sh) v)
+ = XArray (S.fromVector (shapeLshape sh) v)
-toVector :: U.Unbox a => XArray sh a -> VU.Vector a
-toVector (XArray arr) = U.toVector arr
+toVector :: S.Unbox a => XArray sh a -> VS.Vector a
+toVector (XArray arr) = S.toVector arr
-scalar :: U.Unbox a => a -> XArray '[] a
-scalar = XArray . U.scalar
+scalar :: S.Unbox a => a -> XArray '[] a
+scalar = XArray . S.scalar
-unScalar :: U.Unbox a => XArray '[] a -> a
-unScalar (XArray a) = U.unScalar a
+unScalar :: S.Unbox a => XArray '[] a -> a
+unScalar (XArray a) = S.unScalar a
-generate :: U.Unbox a => IxX sh -> (IxX sh -> a) -> XArray sh a
-generate sh f = fromVector sh $ VU.generate (shapeSize sh) (f . fromLinearIdx sh)
+generate :: S.Unbox a => IxX sh -> (IxX sh -> a) -> XArray sh a
+generate sh f = fromVector sh $ VS.generate (shapeSize sh) (f . fromLinearIdx sh)
--- generateM :: (Monad m, U.Unbox a) => IxX sh -> (IxX sh -> m a) -> m (XArray sh a)
+-- generateM :: (Monad m, S.Unbox a) => IxX sh -> (IxX sh -> m a) -> m (XArray sh a)
-- generateM sh f | Dict <- lemKnownNatRank sh =
--- XArray . U.fromVector (shapeLshape sh)
--- <$> VU.generateM (shapeSize sh) (f . fromLinearIdx sh)
+-- XArray . S.fromVector (shapeLshape sh)
+-- <$> VS.generateM (shapeSize sh) (f . fromLinearIdx sh)
-indexPartial :: U.Unbox a => XArray (sh ++ sh') a -> IxX sh -> XArray sh' a
+indexPartial :: S.Unbox a => XArray (sh ++ sh') a -> IxX sh -> XArray sh' a
indexPartial (XArray arr) IZX = XArray arr
-indexPartial (XArray arr) (i ::@ idx) = indexPartial (XArray (U.index arr i)) idx
-indexPartial (XArray arr) (i ::? idx) = indexPartial (XArray (U.index arr i)) idx
+indexPartial (XArray arr) (i ::@ idx) = indexPartial (XArray (S.index arr i)) idx
+indexPartial (XArray arr) (i ::? idx) = indexPartial (XArray (S.index arr i)) idx
-index :: forall sh a. U.Unbox a => XArray sh a -> IxX sh -> a
+index :: forall sh a. S.Unbox a => XArray sh a -> IxX sh -> a
index xarr i
| Refl <- lemAppNil @sh
= let XArray arr' = indexPartial xarr i :: XArray '[] a
- in U.unScalar arr'
+ in S.unScalar arr'
-append :: forall sh a. (KnownShapeX sh, U.Unbox a) => XArray sh a -> XArray sh a -> XArray sh a
+append :: forall sh a. (KnownShapeX sh, S.Unbox a) => XArray sh a -> XArray sh a -> XArray sh a
append (XArray a) (XArray b)
| Dict <- lemKnownNatRankSSX (knownShapeX @sh)
, Dict <- gknownNat (Proxy @(Rank sh))
- = XArray (U.append a b)
+ = XArray (S.append a b)
rerank :: forall sh sh1 sh2 a b.
- (U.Unbox a, U.Unbox b)
+ (S.Unbox a, S.Unbox b)
=> StaticShapeX sh -> StaticShapeX sh1 -> StaticShapeX sh2
-> (XArray sh1 a -> XArray sh2 b)
-> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b
@@ -259,21 +259,21 @@ rerank ssh ssh1 ssh2 f (XArray arr)
, Refl <- lemRankApp ssh ssh2
, Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the
, Dict <- gknownNat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough
- = XArray (U.rerank @(GNat (Rank sh)) @(GNat (Rank sh1)) @(GNat (Rank sh2))
+ = XArray (S.rerank @(GNat (Rank sh)) @(GNat (Rank sh1)) @(GNat (Rank sh2))
(\a -> unXArray (f (XArray a)))
arr)
where
unXArray (XArray a) = a
rerankTop :: forall sh sh1 sh2 a b.
- (U.Unbox a, U.Unbox b)
+ (S.Unbox a, S.Unbox b)
=> StaticShapeX sh1 -> StaticShapeX sh2 -> StaticShapeX sh
-> (XArray sh1 a -> XArray sh2 b)
-> XArray (sh1 ++ sh) a -> XArray (sh2 ++ sh) b
rerankTop ssh1 ssh2 ssh f = transpose2 ssh ssh2 . rerank ssh ssh1 ssh2 f . transpose2 ssh1 ssh
rerank2 :: forall sh sh1 sh2 a b c.
- (U.Unbox a, U.Unbox b, U.Unbox c)
+ (S.Unbox a, S.Unbox b, S.Unbox c)
=> StaticShapeX sh -> StaticShapeX sh1 -> StaticShapeX sh2
-> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c)
-> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c
@@ -286,7 +286,7 @@ rerank2 ssh ssh1 ssh2 f (XArray arr1) (XArray arr2)
, Refl <- lemRankApp ssh ssh2
, Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the
, Dict <- gknownNat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough
- = XArray (U.rerank2 @(GNat (Rank sh)) @(GNat (Rank sh1)) @(GNat (Rank sh2))
+ = XArray (S.rerank2 @(GNat (Rank sh)) @(GNat (Rank sh1)) @(GNat (Rank sh2))
(\a b -> unXArray (f (XArray a) (XArray b)))
arr1 arr2)
where
@@ -297,7 +297,7 @@ transpose :: forall sh a. KnownShapeX sh => [Int] -> XArray sh a -> XArray sh a
transpose perm (XArray arr)
| Dict <- lemKnownNatRankSSX (knownShapeX @sh)
, Dict <- gknownNat (Proxy @(Rank sh))
- = XArray (U.transpose perm arr)
+ = XArray (S.transpose perm arr)
transpose2 :: forall sh1 sh2 a.
StaticShapeX sh1 -> StaticShapeX sh2
@@ -311,18 +311,18 @@ transpose2 ssh1 ssh2 (XArray arr)
, Dict <- gknownNat (Proxy @(Rank (sh2 ++ sh1)))
, Refl <- lemRankAppComm ssh1 ssh2
, let n1 = ssxLength ssh1
- = XArray (U.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr)
+ = XArray (S.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr)
-sumFull :: (U.Unbox a, Num a) => XArray sh a -> a
-sumFull (XArray arr) = U.sumA arr
+sumFull :: (S.Unbox a, Num a) => XArray sh a -> a
+sumFull (XArray arr) = S.sumA arr
-sumInner :: forall sh sh' a. (U.Unbox a, Num a)
+sumInner :: forall sh sh' a. (S.Unbox a, Num a)
=> StaticShapeX sh -> StaticShapeX sh' -> XArray (sh ++ sh') a -> XArray sh a
sumInner ssh ssh'
| Refl <- lemAppNil @sh
= rerank ssh ssh' SZX (scalar . sumFull)
-sumOuter :: forall sh sh' a. (U.Unbox a, Num a)
+sumOuter :: forall sh sh' a. (S.Unbox a, Num a)
=> StaticShapeX sh -> StaticShapeX sh' -> XArray (sh ++ sh') a -> XArray sh' a
sumOuter ssh ssh'
| Refl <- lemAppNil @sh