aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Internal.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Internal.hs')
-rw-r--r--src/Data/Array/Nested/Internal.hs29
1 files changed, 24 insertions, 5 deletions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index d9b2c86..c85b1fc 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -327,6 +327,12 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where
mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest <$> mvecsFreeze (X.ixAppend sh sh') vecs
+-- | Check whether a given shape corresponds on the statically-known components of the shape.
+checkBounds :: IxX sh' -> StaticShapeX sh' -> Bool
+checkBounds IZX SZX = True
+checkBounds (n ::@ sh') (n' :$@ ssh') = n == fromIntegral (fromSNat n') && checkBounds sh' ssh'
+checkBounds (_ ::? sh') (() :$? ssh') = checkBounds sh' ssh'
+
-- Public method. Turns out this doesn't have to be in the type class!
-- | Create an array given a size and a function that computes the element at a
-- given index.
@@ -351,11 +357,6 @@ mgenerate sh f
forM_ (tail (X.enumShape sh)) $ \idx ->
mvecsWrite sh idx (f idx) vecs
mvecsFreeze sh vecs
- where
- checkBounds :: IxX sh' -> StaticShapeX sh' -> Bool
- checkBounds IZX SZX = True
- checkBounds (n ::@ sh') (n' :$@ ssh') = n == fromIntegral (fromSNat n') && checkBounds sh' ssh'
- checkBounds (_ ::? sh') (() :$? ssh') = checkBounds sh' ssh'
mtranspose :: forall sh a. (KnownShapeX sh, Elt a) => [Int] -> Mixed sh a -> Mixed sh a
mtranspose perm =
@@ -369,6 +370,13 @@ mappend = mlift2 go
=> Proxy sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (X.AddMaybe n m : sh ++ sh') b
go Proxy | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh') = X.append
+mfromVector :: forall sh a. (KnownShapeX sh, Storable a) => IxX sh -> VS.Vector a -> Mixed sh (Primitive a)
+mfromVector sh v
+ | not (checkBounds sh (knownShapeX @sh)) =
+ error $ "mfromVector: Shape " ++ show sh ++ " not valid for shape type " ++ show (knownShapeX @sh)
+ | otherwise =
+ M_Primitive (X.fromVector sh v)
+
mliftPrim :: (KnownShapeX sh, Storable a)
=> (a -> a)
-> Mixed sh (Primitive a) -> Mixed sh (Primitive a)
@@ -719,6 +727,11 @@ rappend | Dict <- lemKnownReplicate (Proxy @n) = coerce mappend
rscalar :: Elt a => a -> Ranked I0 a
rscalar x = Ranked (mscalar x)
+rfromVector :: forall n a. (KnownINat n, Storable a) => IxR n -> VS.Vector a -> Ranked n (Primitive a)
+rfromVector sh v
+ | Dict <- lemKnownReplicate (Proxy @n)
+ = Ranked (mfromVector (ixCvtRX sh) v)
+
-- ====== API OF SHAPED ARRAYS ====== --
@@ -769,6 +782,7 @@ ixCvtSX IZS = IZX
ixCvtSX (n ::$ sh) = n ::@ ixCvtSX sh
+-- | This does not touch the passed array, all information comes from 'KnownShape'.
sshape :: forall sh a. (KnownShape sh, Elt a) => Shaped sh a -> IxS sh
sshape _ = cvtSShapeIxS (knownShape @sh)
@@ -815,3 +829,8 @@ sappend | Dict <- lemKnownMapJust (Proxy @sh) = coerce mappend
sscalar :: Elt a => a -> Shaped '[] a
sscalar x = Shaped (mscalar x)
+
+sfromVector :: forall sh a. (KnownShape sh, Storable a) => VS.Vector a -> Shaped sh (Primitive a)
+sfromVector v
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = Shaped (mfromVector (ixCvtSX (cvtSShapeIxS (knownShape @sh))) v)