diff options
| -rw-r--r-- | src/Data/Array/Nested.hs | 3 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Mixed.hs | 15 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Ranked.hs | 9 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Shaped.hs | 9 | 
4 files changed, 24 insertions, 12 deletions
| diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index f987acc..b825691 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -8,6 +8,7 @@ module Data.Array.Nested (    ShR(.., ZSR, (:$:)), IShR,    rshape, rrank, rsize, rindex, rindexPartial, rgenerate, rsumOuter1, rsumAllPrim,    rtranspose, rappend, rconcat, rscalar, rfromVector, rtoVector, runScalar, +  remptyArray,    rrerank,    rreplicate, rreplicateScal, rfromListOuter, rfromList1, rfromList1Prim, rtoListOuter, rtoList1,    rfromListLinear, rfromListPrimLinear, rtoListLinear, @@ -29,6 +30,7 @@ module Data.Array.Nested (    sshape, srank, ssize, sindex, sindexPartial, sgenerate, ssumOuter1, ssumAllPrim,    stranspose, sappend, sscalar, sfromVector, stoVector, sunScalar,    -- TODO: sconcat? What should its type be? +  semptyArray,    srerank,    sreplicate, sreplicateScal, sfromListOuter, sfromList1, sfromList1Prim, stoListOuter, stoList1,    sfromListLinear, sfromListPrimLinear, stoListLinear, @@ -50,6 +52,7 @@ module Data.Array.Nested (    SMayNat(..),    mshape, mrank, msize, mindex, mindexPartial, mgenerate, msumOuter1, msumAllPrim,    mtranspose, mappend, mconcat, mscalar, mfromVector, mtoVector, munScalar, +  memptyArray,    mrerank,    mreplicate, mreplicateScal, mfromListOuter, mfromList1, mfromList1Prim, mtoListOuter, mtoList1,    mfromListLinear, mfromListPrimLinear, mtoListLinear, diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs index 0e4f5e6..8d239cf 100644 --- a/src/Data/Array/Nested/Internal/Mixed.hs +++ b/src/Data/Array/Nested/Internal/Mixed.hs @@ -366,7 +366,7 @@ class Elt a where  -- 'Data.Array.Nested.Shaped.sgenerate'.  class Elt a => KnownElt a where    -- | Create an empty array. The given shape must have size zero; this may or may not be checked. -  memptyArray :: IShX sh -> Mixed sh a +  memptyArrayUnsafe :: IShX sh -> Mixed sh a    -- | Create uninitialised vectors for this array type, given the shape of    -- this vector and an example for the contents. @@ -461,7 +461,7 @@ deriving via Primitive Float instance Elt Float  deriving via Primitive () instance Elt ()  instance Storable a => KnownElt (Primitive a) where -  memptyArray sh = M_Primitive sh (X.empty sh) +  memptyArrayUnsafe sh = M_Primitive sh (X.empty sh)    mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (shxSize sh)    mvecsNewEmpty _ = MV_Primitive <$> VSM.unsafeNew 0 @@ -517,7 +517,7 @@ instance (Elt a, Elt b) => Elt (a, b) where    mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b  instance (KnownElt a, KnownElt b) => KnownElt (a, b) where -  memptyArray sh = M_Tup2 (memptyArray sh) (memptyArray sh) +  memptyArrayUnsafe sh = M_Tup2 (memptyArrayUnsafe sh) (memptyArrayUnsafe sh)    mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y    mvecsNewEmpty _ = MV_Tup2 <$> mvecsNewEmpty (Proxy @a) <*> mvecsNewEmpty (Proxy @b) @@ -650,7 +650,7 @@ instance Elt a => Elt (Mixed sh' a) where    mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (shxAppend sh sh') vecs  instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where -  memptyArray sh = M_Nest sh (memptyArray (shxAppend sh (shxCompleteZeros (knownShX @sh')))) +  memptyArrayUnsafe sh = M_Nest sh (memptyArrayUnsafe (shxAppend sh (shxCompleteZeros (knownShX @sh'))))    mvecsUnsafeNew sh example      | shxSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a)) @@ -661,6 +661,9 @@ instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where    mvecsNewEmpty _ = MV_Nest (shxCompleteZeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a) +memptyArray :: KnownElt a => IShX sh -> Mixed (Just 0 : sh) a +memptyArray sh = memptyArrayUnsafe (SKnown SNat :$% sh) +  mrank :: Elt a => Mixed sh a -> SNat (Rank sh)  mrank = shxRank . mshape @@ -687,12 +690,12 @@ msize = shxSize . mshape  -- easily, hence the runtime check.  mgenerate :: forall sh a. KnownElt a => IShX sh -> (IIxX sh -> a) -> Mixed sh a  mgenerate sh f = case shxEnum sh of -  [] -> memptyArray sh +  [] -> memptyArrayUnsafe sh    firstidx : restidxs ->      let firstelem = f (ixxZero' sh)          shapetree = mshapeTree firstelem      in if mshapeTreeEmpty (Proxy @a) shapetree -         then memptyArray sh +         then memptyArrayUnsafe sh           else runST $ do                  vecs <- mvecsUnsafeNew sh firstelem                  mvecsWrite sh firstidx firstelem vecs diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs index 39a6018..9e8a7b2 100644 --- a/src/Data/Array/Nested/Internal/Ranked.hs +++ b/src/Data/Array/Nested/Internal/Ranked.hs @@ -165,11 +165,11 @@ instance Elt a => Elt (Ranked n a) where                      vecs)  instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where -  memptyArray :: forall sh. IShX sh -> Mixed sh (Ranked n a) -  memptyArray i +  memptyArrayUnsafe :: forall sh. IShX sh -> Mixed sh (Ranked n a) +  memptyArrayUnsafe i      | Dict <- lemKnownReplicate (SNat @n)      = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $ -        memptyArray i +        memptyArrayUnsafe i    mvecsUnsafeNew idx (Ranked arr)      | Dict <- lemKnownReplicate (SNat @n) @@ -229,6 +229,9 @@ instance (FloatElt a, NumElt a, PrimElt a, Num a) => Floating (Ranked n a) where    log1mexp = arithPromoteRanked GHC.Float.log1mexp +remptyArray :: KnownElt a => Ranked 1 a +remptyArray = mtoRanked (memptyArray ZSX) +  rshape :: Elt a => Ranked n a -> IShR n  rshape (Ranked arr) = shCvtXR' (mshape arr) diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs index e2f65c0..228d800 100644 --- a/src/Data/Array/Nested/Internal/Shaped.hs +++ b/src/Data/Array/Nested/Internal/Shaped.hs @@ -163,11 +163,11 @@ instance Elt a => Elt (Shaped sh a) where                      vecs)  instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where -  memptyArray :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a) -  memptyArray i +  memptyArrayUnsafe :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a) +  memptyArrayUnsafe i      | Dict <- lemKnownMapJust (Proxy @sh)      = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $ -        memptyArray i +        memptyArrayUnsafe i    mvecsUnsafeNew idx (Shaped arr)      | Dict <- lemKnownMapJust (Proxy @sh) @@ -239,6 +239,9 @@ instance (FloatElt a, NumElt a, PrimElt a, Floating a, KnownShS sh) => Floating    log1mexp = arithPromoteShaped GHC.Float.log1mexp +semptyArray :: KnownElt a => ShS sh -> Shaped (0 : sh) a +semptyArray sh = Shaped (memptyArray (shCvtSX sh)) +  sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh  sshape (Shaped arr) = shCvtXS' (mshape arr) | 
