diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-04-14 16:17:15 +0200 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-04-14 16:17:15 +0200 | 
| commit | 754526c1ed56d3eb10106af3e9981863ef8c9d0b (patch) | |
| tree | 56bd48c55eaa2fa61b37a6699ddef3f4ea11fa67 /src/Data/Array/Nested | |
| parent | 4070876a20afbb0c6bc11fb0a12ee17f8febc047 (diff) | |
fromVector
Diffstat (limited to 'src/Data/Array/Nested')
| -rw-r--r-- | src/Data/Array/Nested/Internal.hs | 29 | 
1 files changed, 24 insertions, 5 deletions
| diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index d9b2c86..c85b1fc 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -327,6 +327,12 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where    mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest <$> mvecsFreeze (X.ixAppend sh sh') vecs +-- | Check whether a given shape corresponds on the statically-known components of the shape. +checkBounds :: IxX sh' -> StaticShapeX sh' -> Bool +checkBounds IZX SZX = True +checkBounds (n ::@ sh') (n' :$@ ssh') = n == fromIntegral (fromSNat n') && checkBounds sh' ssh' +checkBounds (_ ::? sh') (() :$? ssh') = checkBounds sh' ssh' +  -- Public method. Turns out this doesn't have to be in the type class!  -- | Create an array given a size and a function that computes the element at a  -- given index. @@ -351,11 +357,6 @@ mgenerate sh f                    forM_ (tail (X.enumShape sh)) $ \idx ->                      mvecsWrite sh idx (f idx) vecs                    mvecsFreeze sh vecs -  where -    checkBounds :: IxX sh' -> StaticShapeX sh' -> Bool -    checkBounds IZX SZX = True -    checkBounds (n ::@ sh') (n' :$@ ssh') = n == fromIntegral (fromSNat n') && checkBounds sh' ssh' -    checkBounds (_ ::? sh') (() :$? ssh') = checkBounds sh' ssh'  mtranspose :: forall sh a. (KnownShapeX sh, Elt a) => [Int] -> Mixed sh a -> Mixed sh a  mtranspose perm = @@ -369,6 +370,13 @@ mappend = mlift2 go             => Proxy sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (X.AddMaybe n m : sh ++ sh') b          go Proxy | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh') = X.append +mfromVector :: forall sh a. (KnownShapeX sh, Storable a) => IxX sh -> VS.Vector a -> Mixed sh (Primitive a) +mfromVector sh v +  | not (checkBounds sh (knownShapeX @sh)) = +      error $ "mfromVector: Shape " ++ show sh ++ " not valid for shape type " ++ show (knownShapeX @sh) +  | otherwise = +      M_Primitive (X.fromVector sh v) +  mliftPrim :: (KnownShapeX sh, Storable a)            => (a -> a)            -> Mixed sh (Primitive a) -> Mixed sh (Primitive a) @@ -719,6 +727,11 @@ rappend | Dict <- lemKnownReplicate (Proxy @n) = coerce mappend  rscalar :: Elt a => a -> Ranked I0 a  rscalar x = Ranked (mscalar x) +rfromVector :: forall n a. (KnownINat n, Storable a) => IxR n -> VS.Vector a -> Ranked n (Primitive a) +rfromVector sh v +  | Dict <- lemKnownReplicate (Proxy @n) +  = Ranked (mfromVector (ixCvtRX sh) v) +  -- ====== API OF SHAPED ARRAYS ====== -- @@ -769,6 +782,7 @@ ixCvtSX IZS = IZX  ixCvtSX (n ::$ sh) = n ::@ ixCvtSX sh +-- | This does not touch the passed array, all information comes from 'KnownShape'.  sshape :: forall sh a. (KnownShape sh, Elt a) => Shaped sh a -> IxS sh  sshape _ = cvtSShapeIxS (knownShape @sh) @@ -815,3 +829,8 @@ sappend | Dict <- lemKnownMapJust (Proxy @sh) = coerce mappend  sscalar :: Elt a => a -> Shaped '[] a  sscalar x = Shaped (mscalar x) + +sfromVector :: forall sh a. (KnownShape sh, Storable a) => VS.Vector a -> Shaped sh (Primitive a) +sfromVector v +  | Dict <- lemKnownMapJust (Proxy @sh) +  = Shaped (mfromVector (ixCvtSX (cvtSShapeIxS (knownShape @sh))) v) | 
