diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-04-14 16:17:15 +0200 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-04-14 16:17:15 +0200 | 
| commit | 754526c1ed56d3eb10106af3e9981863ef8c9d0b (patch) | |
| tree | 56bd48c55eaa2fa61b37a6699ddef3f4ea11fa67 /src | |
| parent | 4070876a20afbb0c6bc11fb0a12ee17f8febc047 (diff) | |
fromVector
Diffstat (limited to 'src')
| -rw-r--r-- | src/Data/Array/Nested.hs | 6 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal.hs | 29 | 
2 files changed, 27 insertions, 8 deletions
| diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index 22f65a6..23d20eb 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -4,7 +4,7 @@ module Data.Array.Nested (    Ranked,    IxR(..),    rshape, rindex, rindexPartial, rgenerate, rsumOuter1, -  rtranspose, rappend, rscalar, +  rtranspose, rappend, rscalar, rfromVector,    -- ** Lifting orthotope operations to 'Ranked' arrays    rlift, @@ -13,7 +13,7 @@ module Data.Array.Nested (    IxS(..),    KnownShape(..), SShape(..),    sshape, sindex, sindexPartial, sgenerate, ssumOuter1, -  stranspose, sappend, sscalar, +  stranspose, sappend, sscalar, sfromVector,    -- ** Lifting orthotope operations to 'Shaped' arrays    slift, @@ -21,7 +21,7 @@ module Data.Array.Nested (    Mixed,    IxX(..),    KnownShapeX(..), StaticShapeX(..), -  mgenerate, mtranspose, mappend, +  mgenerate, mtranspose, mappend, mfromVector,    -- * Array elements    Elt(mshape, mindex, mindexPartial, mscalar, mlift, mlift2), diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index d9b2c86..c85b1fc 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -327,6 +327,12 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where    mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest <$> mvecsFreeze (X.ixAppend sh sh') vecs +-- | Check whether a given shape corresponds on the statically-known components of the shape. +checkBounds :: IxX sh' -> StaticShapeX sh' -> Bool +checkBounds IZX SZX = True +checkBounds (n ::@ sh') (n' :$@ ssh') = n == fromIntegral (fromSNat n') && checkBounds sh' ssh' +checkBounds (_ ::? sh') (() :$? ssh') = checkBounds sh' ssh' +  -- Public method. Turns out this doesn't have to be in the type class!  -- | Create an array given a size and a function that computes the element at a  -- given index. @@ -351,11 +357,6 @@ mgenerate sh f                    forM_ (tail (X.enumShape sh)) $ \idx ->                      mvecsWrite sh idx (f idx) vecs                    mvecsFreeze sh vecs -  where -    checkBounds :: IxX sh' -> StaticShapeX sh' -> Bool -    checkBounds IZX SZX = True -    checkBounds (n ::@ sh') (n' :$@ ssh') = n == fromIntegral (fromSNat n') && checkBounds sh' ssh' -    checkBounds (_ ::? sh') (() :$? ssh') = checkBounds sh' ssh'  mtranspose :: forall sh a. (KnownShapeX sh, Elt a) => [Int] -> Mixed sh a -> Mixed sh a  mtranspose perm = @@ -369,6 +370,13 @@ mappend = mlift2 go             => Proxy sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (X.AddMaybe n m : sh ++ sh') b          go Proxy | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh') = X.append +mfromVector :: forall sh a. (KnownShapeX sh, Storable a) => IxX sh -> VS.Vector a -> Mixed sh (Primitive a) +mfromVector sh v +  | not (checkBounds sh (knownShapeX @sh)) = +      error $ "mfromVector: Shape " ++ show sh ++ " not valid for shape type " ++ show (knownShapeX @sh) +  | otherwise = +      M_Primitive (X.fromVector sh v) +  mliftPrim :: (KnownShapeX sh, Storable a)            => (a -> a)            -> Mixed sh (Primitive a) -> Mixed sh (Primitive a) @@ -719,6 +727,11 @@ rappend | Dict <- lemKnownReplicate (Proxy @n) = coerce mappend  rscalar :: Elt a => a -> Ranked I0 a  rscalar x = Ranked (mscalar x) +rfromVector :: forall n a. (KnownINat n, Storable a) => IxR n -> VS.Vector a -> Ranked n (Primitive a) +rfromVector sh v +  | Dict <- lemKnownReplicate (Proxy @n) +  = Ranked (mfromVector (ixCvtRX sh) v) +  -- ====== API OF SHAPED ARRAYS ====== -- @@ -769,6 +782,7 @@ ixCvtSX IZS = IZX  ixCvtSX (n ::$ sh) = n ::@ ixCvtSX sh +-- | This does not touch the passed array, all information comes from 'KnownShape'.  sshape :: forall sh a. (KnownShape sh, Elt a) => Shaped sh a -> IxS sh  sshape _ = cvtSShapeIxS (knownShape @sh) @@ -815,3 +829,8 @@ sappend | Dict <- lemKnownMapJust (Proxy @sh) = coerce mappend  sscalar :: Elt a => a -> Shaped '[] a  sscalar x = Shaped (mscalar x) + +sfromVector :: forall sh a. (KnownShape sh, Storable a) => VS.Vector a -> Shaped sh (Primitive a) +sfromVector v +  | Dict <- lemKnownMapJust (Proxy @sh) +  = Shaped (mfromVector (ixCvtSX (cvtSShapeIxS (knownShape @sh))) v) | 
