aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-04-14 16:17:15 +0200
committerTom Smeding <tom@tomsmeding.com>2024-04-14 16:17:15 +0200
commit754526c1ed56d3eb10106af3e9981863ef8c9d0b (patch)
tree56bd48c55eaa2fa61b37a6699ddef3f4ea11fa67
parent4070876a20afbb0c6bc11fb0a12ee17f8febc047 (diff)
fromVector
-rw-r--r--src/Data/Array/Nested.hs6
-rw-r--r--src/Data/Array/Nested/Internal.hs29
2 files changed, 27 insertions, 8 deletions
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index 22f65a6..23d20eb 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -4,7 +4,7 @@ module Data.Array.Nested (
Ranked,
IxR(..),
rshape, rindex, rindexPartial, rgenerate, rsumOuter1,
- rtranspose, rappend, rscalar,
+ rtranspose, rappend, rscalar, rfromVector,
-- ** Lifting orthotope operations to 'Ranked' arrays
rlift,
@@ -13,7 +13,7 @@ module Data.Array.Nested (
IxS(..),
KnownShape(..), SShape(..),
sshape, sindex, sindexPartial, sgenerate, ssumOuter1,
- stranspose, sappend, sscalar,
+ stranspose, sappend, sscalar, sfromVector,
-- ** Lifting orthotope operations to 'Shaped' arrays
slift,
@@ -21,7 +21,7 @@ module Data.Array.Nested (
Mixed,
IxX(..),
KnownShapeX(..), StaticShapeX(..),
- mgenerate, mtranspose, mappend,
+ mgenerate, mtranspose, mappend, mfromVector,
-- * Array elements
Elt(mshape, mindex, mindexPartial, mscalar, mlift, mlift2),
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index d9b2c86..c85b1fc 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -327,6 +327,12 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where
mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest <$> mvecsFreeze (X.ixAppend sh sh') vecs
+-- | Check whether a given shape corresponds on the statically-known components of the shape.
+checkBounds :: IxX sh' -> StaticShapeX sh' -> Bool
+checkBounds IZX SZX = True
+checkBounds (n ::@ sh') (n' :$@ ssh') = n == fromIntegral (fromSNat n') && checkBounds sh' ssh'
+checkBounds (_ ::? sh') (() :$? ssh') = checkBounds sh' ssh'
+
-- Public method. Turns out this doesn't have to be in the type class!
-- | Create an array given a size and a function that computes the element at a
-- given index.
@@ -351,11 +357,6 @@ mgenerate sh f
forM_ (tail (X.enumShape sh)) $ \idx ->
mvecsWrite sh idx (f idx) vecs
mvecsFreeze sh vecs
- where
- checkBounds :: IxX sh' -> StaticShapeX sh' -> Bool
- checkBounds IZX SZX = True
- checkBounds (n ::@ sh') (n' :$@ ssh') = n == fromIntegral (fromSNat n') && checkBounds sh' ssh'
- checkBounds (_ ::? sh') (() :$? ssh') = checkBounds sh' ssh'
mtranspose :: forall sh a. (KnownShapeX sh, Elt a) => [Int] -> Mixed sh a -> Mixed sh a
mtranspose perm =
@@ -369,6 +370,13 @@ mappend = mlift2 go
=> Proxy sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (X.AddMaybe n m : sh ++ sh') b
go Proxy | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh') = X.append
+mfromVector :: forall sh a. (KnownShapeX sh, Storable a) => IxX sh -> VS.Vector a -> Mixed sh (Primitive a)
+mfromVector sh v
+ | not (checkBounds sh (knownShapeX @sh)) =
+ error $ "mfromVector: Shape " ++ show sh ++ " not valid for shape type " ++ show (knownShapeX @sh)
+ | otherwise =
+ M_Primitive (X.fromVector sh v)
+
mliftPrim :: (KnownShapeX sh, Storable a)
=> (a -> a)
-> Mixed sh (Primitive a) -> Mixed sh (Primitive a)
@@ -719,6 +727,11 @@ rappend | Dict <- lemKnownReplicate (Proxy @n) = coerce mappend
rscalar :: Elt a => a -> Ranked I0 a
rscalar x = Ranked (mscalar x)
+rfromVector :: forall n a. (KnownINat n, Storable a) => IxR n -> VS.Vector a -> Ranked n (Primitive a)
+rfromVector sh v
+ | Dict <- lemKnownReplicate (Proxy @n)
+ = Ranked (mfromVector (ixCvtRX sh) v)
+
-- ====== API OF SHAPED ARRAYS ====== --
@@ -769,6 +782,7 @@ ixCvtSX IZS = IZX
ixCvtSX (n ::$ sh) = n ::@ ixCvtSX sh
+-- | This does not touch the passed array, all information comes from 'KnownShape'.
sshape :: forall sh a. (KnownShape sh, Elt a) => Shaped sh a -> IxS sh
sshape _ = cvtSShapeIxS (knownShape @sh)
@@ -815,3 +829,8 @@ sappend | Dict <- lemKnownMapJust (Proxy @sh) = coerce mappend
sscalar :: Elt a => a -> Shaped '[] a
sscalar x = Shaped (mscalar x)
+
+sfromVector :: forall sh a. (KnownShape sh, Storable a) => VS.Vector a -> Shaped sh (Primitive a)
+sfromVector v
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = Shaped (mfromVector (ixCvtSX (cvtSShapeIxS (knownShape @sh))) v)