diff options
Diffstat (limited to 'src/Data/Array')
-rw-r--r-- | src/Data/Array/Nested.hs | 6 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal.hs | 29 |
2 files changed, 27 insertions, 8 deletions
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index 22f65a6..23d20eb 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -4,7 +4,7 @@ module Data.Array.Nested ( Ranked, IxR(..), rshape, rindex, rindexPartial, rgenerate, rsumOuter1, - rtranspose, rappend, rscalar, + rtranspose, rappend, rscalar, rfromVector, -- ** Lifting orthotope operations to 'Ranked' arrays rlift, @@ -13,7 +13,7 @@ module Data.Array.Nested ( IxS(..), KnownShape(..), SShape(..), sshape, sindex, sindexPartial, sgenerate, ssumOuter1, - stranspose, sappend, sscalar, + stranspose, sappend, sscalar, sfromVector, -- ** Lifting orthotope operations to 'Shaped' arrays slift, @@ -21,7 +21,7 @@ module Data.Array.Nested ( Mixed, IxX(..), KnownShapeX(..), StaticShapeX(..), - mgenerate, mtranspose, mappend, + mgenerate, mtranspose, mappend, mfromVector, -- * Array elements Elt(mshape, mindex, mindexPartial, mscalar, mlift, mlift2), diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index d9b2c86..c85b1fc 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -327,6 +327,12 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest <$> mvecsFreeze (X.ixAppend sh sh') vecs +-- | Check whether a given shape corresponds on the statically-known components of the shape. +checkBounds :: IxX sh' -> StaticShapeX sh' -> Bool +checkBounds IZX SZX = True +checkBounds (n ::@ sh') (n' :$@ ssh') = n == fromIntegral (fromSNat n') && checkBounds sh' ssh' +checkBounds (_ ::? sh') (() :$? ssh') = checkBounds sh' ssh' + -- Public method. Turns out this doesn't have to be in the type class! -- | Create an array given a size and a function that computes the element at a -- given index. @@ -351,11 +357,6 @@ mgenerate sh f forM_ (tail (X.enumShape sh)) $ \idx -> mvecsWrite sh idx (f idx) vecs mvecsFreeze sh vecs - where - checkBounds :: IxX sh' -> StaticShapeX sh' -> Bool - checkBounds IZX SZX = True - checkBounds (n ::@ sh') (n' :$@ ssh') = n == fromIntegral (fromSNat n') && checkBounds sh' ssh' - checkBounds (_ ::? sh') (() :$? ssh') = checkBounds sh' ssh' mtranspose :: forall sh a. (KnownShapeX sh, Elt a) => [Int] -> Mixed sh a -> Mixed sh a mtranspose perm = @@ -369,6 +370,13 @@ mappend = mlift2 go => Proxy sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (X.AddMaybe n m : sh ++ sh') b go Proxy | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh') = X.append +mfromVector :: forall sh a. (KnownShapeX sh, Storable a) => IxX sh -> VS.Vector a -> Mixed sh (Primitive a) +mfromVector sh v + | not (checkBounds sh (knownShapeX @sh)) = + error $ "mfromVector: Shape " ++ show sh ++ " not valid for shape type " ++ show (knownShapeX @sh) + | otherwise = + M_Primitive (X.fromVector sh v) + mliftPrim :: (KnownShapeX sh, Storable a) => (a -> a) -> Mixed sh (Primitive a) -> Mixed sh (Primitive a) @@ -719,6 +727,11 @@ rappend | Dict <- lemKnownReplicate (Proxy @n) = coerce mappend rscalar :: Elt a => a -> Ranked I0 a rscalar x = Ranked (mscalar x) +rfromVector :: forall n a. (KnownINat n, Storable a) => IxR n -> VS.Vector a -> Ranked n (Primitive a) +rfromVector sh v + | Dict <- lemKnownReplicate (Proxy @n) + = Ranked (mfromVector (ixCvtRX sh) v) + -- ====== API OF SHAPED ARRAYS ====== -- @@ -769,6 +782,7 @@ ixCvtSX IZS = IZX ixCvtSX (n ::$ sh) = n ::@ ixCvtSX sh +-- | This does not touch the passed array, all information comes from 'KnownShape'. sshape :: forall sh a. (KnownShape sh, Elt a) => Shaped sh a -> IxS sh sshape _ = cvtSShapeIxS (knownShape @sh) @@ -815,3 +829,8 @@ sappend | Dict <- lemKnownMapJust (Proxy @sh) = coerce mappend sscalar :: Elt a => a -> Shaped '[] a sscalar x = Shaped (mscalar x) + +sfromVector :: forall sh a. (KnownShape sh, Storable a) => VS.Vector a -> Shaped sh (Primitive a) +sfromVector v + | Dict <- lemKnownMapJust (Proxy @sh) + = Shaped (mfromVector (ixCvtSX (cvtSShapeIxS (knownShape @sh))) v) |