aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <t.j.smeding@uu.nl>2024-11-21 13:49:25 +0100
committerTom Smeding <t.j.smeding@uu.nl>2024-11-21 13:49:52 +0100
commit7e5ccb3402f97c1c7cff158147aeb863d429f885 (patch)
tree8adaa90ef98d0316638eb6e128007aa86ac06d83
parent1b69f540b0c1fa8d45b80f452cab8e7ac02dffd9 (diff)
[rsm]emptyArraysingletons
-rw-r--r--src/Data/Array/Nested.hs3
-rw-r--r--src/Data/Array/Nested/Internal/Mixed.hs15
-rw-r--r--src/Data/Array/Nested/Internal/Ranked.hs9
-rw-r--r--src/Data/Array/Nested/Internal/Shaped.hs9
4 files changed, 24 insertions, 12 deletions
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index f987acc..b825691 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -8,6 +8,7 @@ module Data.Array.Nested (
ShR(.., ZSR, (:$:)), IShR,
rshape, rrank, rsize, rindex, rindexPartial, rgenerate, rsumOuter1, rsumAllPrim,
rtranspose, rappend, rconcat, rscalar, rfromVector, rtoVector, runScalar,
+ remptyArray,
rrerank,
rreplicate, rreplicateScal, rfromListOuter, rfromList1, rfromList1Prim, rtoListOuter, rtoList1,
rfromListLinear, rfromListPrimLinear, rtoListLinear,
@@ -29,6 +30,7 @@ module Data.Array.Nested (
sshape, srank, ssize, sindex, sindexPartial, sgenerate, ssumOuter1, ssumAllPrim,
stranspose, sappend, sscalar, sfromVector, stoVector, sunScalar,
-- TODO: sconcat? What should its type be?
+ semptyArray,
srerank,
sreplicate, sreplicateScal, sfromListOuter, sfromList1, sfromList1Prim, stoListOuter, stoList1,
sfromListLinear, sfromListPrimLinear, stoListLinear,
@@ -50,6 +52,7 @@ module Data.Array.Nested (
SMayNat(..),
mshape, mrank, msize, mindex, mindexPartial, mgenerate, msumOuter1, msumAllPrim,
mtranspose, mappend, mconcat, mscalar, mfromVector, mtoVector, munScalar,
+ memptyArray,
mrerank,
mreplicate, mreplicateScal, mfromListOuter, mfromList1, mfromList1Prim, mtoListOuter, mtoList1,
mfromListLinear, mfromListPrimLinear, mtoListLinear,
diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs
index 0e4f5e6..8d239cf 100644
--- a/src/Data/Array/Nested/Internal/Mixed.hs
+++ b/src/Data/Array/Nested/Internal/Mixed.hs
@@ -366,7 +366,7 @@ class Elt a where
-- 'Data.Array.Nested.Shaped.sgenerate'.
class Elt a => KnownElt a where
-- | Create an empty array. The given shape must have size zero; this may or may not be checked.
- memptyArray :: IShX sh -> Mixed sh a
+ memptyArrayUnsafe :: IShX sh -> Mixed sh a
-- | Create uninitialised vectors for this array type, given the shape of
-- this vector and an example for the contents.
@@ -461,7 +461,7 @@ deriving via Primitive Float instance Elt Float
deriving via Primitive () instance Elt ()
instance Storable a => KnownElt (Primitive a) where
- memptyArray sh = M_Primitive sh (X.empty sh)
+ memptyArrayUnsafe sh = M_Primitive sh (X.empty sh)
mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (shxSize sh)
mvecsNewEmpty _ = MV_Primitive <$> VSM.unsafeNew 0
@@ -517,7 +517,7 @@ instance (Elt a, Elt b) => Elt (a, b) where
mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b
instance (KnownElt a, KnownElt b) => KnownElt (a, b) where
- memptyArray sh = M_Tup2 (memptyArray sh) (memptyArray sh)
+ memptyArrayUnsafe sh = M_Tup2 (memptyArrayUnsafe sh) (memptyArrayUnsafe sh)
mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y
mvecsNewEmpty _ = MV_Tup2 <$> mvecsNewEmpty (Proxy @a) <*> mvecsNewEmpty (Proxy @b)
@@ -650,7 +650,7 @@ instance Elt a => Elt (Mixed sh' a) where
mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (shxAppend sh sh') vecs
instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where
- memptyArray sh = M_Nest sh (memptyArray (shxAppend sh (shxCompleteZeros (knownShX @sh'))))
+ memptyArrayUnsafe sh = M_Nest sh (memptyArrayUnsafe (shxAppend sh (shxCompleteZeros (knownShX @sh'))))
mvecsUnsafeNew sh example
| shxSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a))
@@ -661,6 +661,9 @@ instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where
mvecsNewEmpty _ = MV_Nest (shxCompleteZeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a)
+memptyArray :: KnownElt a => IShX sh -> Mixed (Just 0 : sh) a
+memptyArray sh = memptyArrayUnsafe (SKnown SNat :$% sh)
+
mrank :: Elt a => Mixed sh a -> SNat (Rank sh)
mrank = shxRank . mshape
@@ -687,12 +690,12 @@ msize = shxSize . mshape
-- easily, hence the runtime check.
mgenerate :: forall sh a. KnownElt a => IShX sh -> (IIxX sh -> a) -> Mixed sh a
mgenerate sh f = case shxEnum sh of
- [] -> memptyArray sh
+ [] -> memptyArrayUnsafe sh
firstidx : restidxs ->
let firstelem = f (ixxZero' sh)
shapetree = mshapeTree firstelem
in if mshapeTreeEmpty (Proxy @a) shapetree
- then memptyArray sh
+ then memptyArrayUnsafe sh
else runST $ do
vecs <- mvecsUnsafeNew sh firstelem
mvecsWrite sh firstidx firstelem vecs
diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs
index 39a6018..9e8a7b2 100644
--- a/src/Data/Array/Nested/Internal/Ranked.hs
+++ b/src/Data/Array/Nested/Internal/Ranked.hs
@@ -165,11 +165,11 @@ instance Elt a => Elt (Ranked n a) where
vecs)
instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where
- memptyArray :: forall sh. IShX sh -> Mixed sh (Ranked n a)
- memptyArray i
+ memptyArrayUnsafe :: forall sh. IShX sh -> Mixed sh (Ranked n a)
+ memptyArrayUnsafe i
| Dict <- lemKnownReplicate (SNat @n)
= coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $
- memptyArray i
+ memptyArrayUnsafe i
mvecsUnsafeNew idx (Ranked arr)
| Dict <- lemKnownReplicate (SNat @n)
@@ -229,6 +229,9 @@ instance (FloatElt a, NumElt a, PrimElt a, Num a) => Floating (Ranked n a) where
log1mexp = arithPromoteRanked GHC.Float.log1mexp
+remptyArray :: KnownElt a => Ranked 1 a
+remptyArray = mtoRanked (memptyArray ZSX)
+
rshape :: Elt a => Ranked n a -> IShR n
rshape (Ranked arr) = shCvtXR' (mshape arr)
diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs
index e2f65c0..228d800 100644
--- a/src/Data/Array/Nested/Internal/Shaped.hs
+++ b/src/Data/Array/Nested/Internal/Shaped.hs
@@ -163,11 +163,11 @@ instance Elt a => Elt (Shaped sh a) where
vecs)
instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where
- memptyArray :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a)
- memptyArray i
+ memptyArrayUnsafe :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a)
+ memptyArrayUnsafe i
| Dict <- lemKnownMapJust (Proxy @sh)
= coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $
- memptyArray i
+ memptyArrayUnsafe i
mvecsUnsafeNew idx (Shaped arr)
| Dict <- lemKnownMapJust (Proxy @sh)
@@ -239,6 +239,9 @@ instance (FloatElt a, NumElt a, PrimElt a, Floating a, KnownShS sh) => Floating
log1mexp = arithPromoteShaped GHC.Float.log1mexp
+semptyArray :: KnownElt a => ShS sh -> Shaped (0 : sh) a
+semptyArray sh = Shaped (memptyArray (shCvtSX sh))
+
sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh
sshape (Shaped arr) = shCvtXS' (mshape arr)