aboutsummaryrefslogtreecommitdiff
path: root/src/Data
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data')
-rw-r--r--src/Data/Array/Nested.hs24
-rw-r--r--src/Data/Array/Nested/Mixed.hs90
-rw-r--r--src/Data/Array/Nested/Ranked.hs82
-rw-r--r--src/Data/Array/Nested/Ranked/Base.hs4
-rw-r--r--src/Data/Array/Nested/Shaped.hs58
-rw-r--r--src/Data/Array/Nested/Shaped/Base.hs4
-rw-r--r--src/Data/Array/Nested/Trace.hs2
7 files changed, 150 insertions, 114 deletions
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index 8fb3bd1..e849673 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -6,13 +6,13 @@ module Data.Array.Nested (
ListR(ZR, (:::)),
IxR(.., ZIR, (:.:)), IIxR,
ShR(.., ZSR, (:$:)), IShR,
- rshape, rrank, rsize, rindex, rindexPartial, rgenerate, rsumOuter1, rsumAllPrim,
+ rshape, rrank, rsize, rindex, rindexPartial, rgenerate, rsumOuter1Prim, rsumAllPrim,
rtranspose, rappend, rconcat, rscalar, rfromVector, rtoVector, runScalar,
remptyArray,
- rrerank,
- rreplicate, rreplicateScal,
- rfromList1, rfromList1N,
+ rrerankPrim,
+ rreplicate, rreplicatePrim,
rfromListOuter, rfromListOuterN,
+ rfromList1, rfromList1N,
rfromListLinear,
rfromList1Prim, rfromList1PrimN,
rfromListPrimLinear,
@@ -36,13 +36,13 @@ module Data.Array.Nested (
ListS(ZS, (::$)),
IxS(.., ZIS, (:.$)), IIxS,
ShS(.., ZSS, (:$$)), KnownShS(..),
- sshape, srank, ssize, sindex, sindexPartial, sgenerate, ssumOuter1, ssumAllPrim,
+ sshape, srank, ssize, sindex, sindexPartial, sgenerate, ssumOuter1Prim, ssumAllPrim,
stranspose, sappend, sscalar, sfromVector, stoVector, sunScalar,
-- TODO: sconcat? What should its type be?
semptyArray,
- srerank,
- sreplicate, sreplicateScal,
- sfromList1, sfromListOuter, sfromListLinear, sfromList1Prim, sfromListPrimLinear,
+ srerankPrim,
+ sreplicate, sreplicatePrim,
+ sfromListOuter, sfromList1, sfromListLinear, sfromList1Prim, sfromListPrimLinear,
stoList, stoListOuter, stoListLinear,
sslice, srev1, sreshape, sflatten, siota,
sminIndexPrim, smaxIndexPrim, sdot1Inner, sdot,
@@ -65,13 +65,13 @@ module Data.Array.Nested (
ShX(.., ZSX, (:$%)), KnownShX(..), IShX,
StaticShX(.., ZKX, (:!%)),
SMayNat(..),
- mshape, mrank, msize, mindex, mindexPartial, mgenerate, msumOuter1, msumAllPrim,
+ mshape, mrank, msize, mindex, mindexPartial, mgenerate, msumOuter1Prim, msumAllPrim,
mtranspose, mappend, mconcat, mscalar, mfromVector, mtoVector, munScalar,
memptyArray,
- mrerank,
- mreplicate, mreplicateScal,
- mfromList1, mfromList1N, mfromList1SN,
+ mrerankPrim,
+ mreplicate, mreplicatePrim,
mfromListOuter, mfromListOuterN, mfromListOuterSN,
+ mfromList1, mfromList1N, mfromList1SN,
mfromListLinear,
mfromList1Prim, mfromList1PrimN, mfromList1PrimSN,
mfromListPrimLinear,
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs
index a2787b8..ecc4479 100644
--- a/src/Data/Array/Nested/Mixed.hs
+++ b/src/Data/Array/Nested/Mixed.hs
@@ -698,6 +698,7 @@ instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where
mvecsNewEmpty _ = MV_Nest (shxCompleteZeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a)
+-- TODO: should we provide a function that's just memptyArrayUnsafe but with a size==0 check? That may save someone a transpose somewhere
memptyArray :: KnownElt a => IShX sh -> Mixed (Just 0 : sh) a
memptyArray sh = memptyArrayUnsafe (SKnown SNat :$% sh)
@@ -745,18 +746,21 @@ mgenerate sh f = case shxEnum sh of
mvecsWrite sh idx val vecs
mvecsFreeze sh vecs
-msumOuter1P :: forall sh n a. (Storable a, NumElt a)
- => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a)
-msumOuter1P (M_Primitive (n :$% sh) arr) =
+msumOuter1PrimP :: forall sh n a. (Storable a, NumElt a)
+ => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a)
+msumOuter1PrimP (M_Primitive (n :$% sh) arr) =
let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX
in M_Primitive sh (X.sumOuter nssh (ssxFromShX sh) arr)
-msumOuter1 :: forall sh n a. (NumElt a, PrimElt a)
- => Mixed (n : sh) a -> Mixed sh a
-msumOuter1 = fromPrimitive . msumOuter1P @sh @n @a . toPrimitive
+msumOuter1Prim :: forall sh n a. (NumElt a, PrimElt a)
+ => Mixed (n : sh) a -> Mixed sh a
+msumOuter1Prim = fromPrimitive . msumOuter1PrimP @sh @n @a . toPrimitive
+
+msumAllPrimP :: (Storable a, NumElt a) => Mixed sh (Primitive a) -> a
+msumAllPrimP (M_Primitive sh arr) = X.sumFull (ssxFromShX sh) arr
msumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a
-msumAllPrim (toPrimitive -> M_Primitive sh arr) = X.sumFull (ssxFromShX sh) arr
+msumAllPrim arr = msumAllPrimP (toPrimitive arr)
mappend :: forall n m sh a. Elt a
=> Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (AddMaybe n m : sh) a
@@ -883,24 +887,54 @@ mzip a b
munzip :: Mixed sh (a, b) -> (Mixed sh a, Mixed sh b)
munzip (M_Tup2 a b) = (a, b)
-mrerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b)
- => StaticShX sh -> IShX sh2
- -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b))
- -> Mixed (sh ++ sh1) (Primitive a) -> Mixed (sh ++ sh2) (Primitive b)
-mrerankP ssh sh2 f (M_Primitive sh arr) =
- let sh1 = shxDropSSX ssh sh
- in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) ssh sh) sh2)
- (X.rerank ssh (ssxFromShX sh1) (ssxFromShX sh2)
- (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r)
- arr)
+mrerankPrimP :: forall sh1 sh2 sh a b. (Storable a, Storable b)
+ => IShX sh2
+ -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b))
+ -> Mixed sh (Mixed sh1 (Primitive a)) -> Mixed sh (Mixed sh2 (Primitive b))
+mrerankPrimP sh2 f (M_Nest sh (M_Primitive shsh1 arr)) =
+ let sh1 = shxDropSh sh shsh1
+ in M_Nest sh $
+ M_Primitive (shxAppend sh sh2)
+ (X.rerank (ssxFromShX sh) (ssxFromShX sh1) (ssxFromShX sh2)
+ (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r)
+ arr)
--- | See the caveats at 'Data.Array.XArray.rerank'.
-mrerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b)
- => StaticShX sh -> IShX sh2
- -> (Mixed sh1 a -> Mixed sh2 b)
- -> Mixed (sh ++ sh1) a -> Mixed (sh ++ sh2) b
-mrerank ssh sh2 f (toPrimitive -> arr) =
- fromPrimitive $ mrerankP ssh sh2 (toPrimitive . f . fromPrimitive) arr
+-- | If the shape of the outer array (@sh@) is empty (i.e. contains a zero),
+-- then there is no way to deduce the full shape of the output array (more
+-- precisely, the @sh2@ part): that could only come from calling @f@, and there
+-- are no subarrays to call @f@ on. @orthotope@ errors out in this case; we
+-- choose to fill the shape with zeros wherever we cannot deduce what it should
+-- be.
+--
+-- For example, if:
+--
+-- @
+-- -- arr has shape [3, 0, 4] and the inner arrays have shape [2, 21]
+-- arr :: Mixed '[Just 3, Just 0, Just 4] (Mixed '[Just 2, Nothing] Int)
+-- f :: Mixed '[Just 2, Nothing] Int -> Mixed '[Just 5, Nothing, Just 17] Float
+-- @
+--
+-- then:
+--
+-- @
+-- mrerankPrim _ f arr :: Mixed '[Just 3, Just 0, Just 4] (Mixed '[Just 5, Nothing, Just 17] Float)
+-- @
+--
+-- and the inner arrays of the result will have shape @[5, 0, 17]@. Note the
+-- @0@ in this shape: we don't know if @f@ intended to return an array with
+-- shape 0 here (it probably didn't), but there is no better number to put here
+-- absent a subarray of the input to pass to @f@.
+--
+-- In this particular case the fact that @sh@ is empty was evident from the
+-- type-level information, but the same situation occurs when @sh@ consists of
+-- @Nothing@s, and some of those happen to be zero at runtime.
+mrerankPrim :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b)
+ => IShX sh2
+ -> (Mixed sh1 a -> Mixed sh2 b)
+ -> Mixed sh (Mixed sh1 a) -> Mixed sh (Mixed sh2 b)
+mrerankPrim sh2 f (M_Nest sh arr) =
+ let M_Nest sh' arr' = mrerankPrimP sh2 (toPrimitive . f . fromPrimitive) (M_Nest sh (toPrimitive arr))
+ in M_Nest sh' (fromPrimitive arr')
mreplicate :: forall sh sh' a. Elt a
=> IShX sh -> Mixed sh' a -> Mixed (sh ++ sh') a
@@ -912,12 +946,12 @@ mreplicate sh arr =
Refl -> X.replicate sh (ssxAppend ssh' sshT))
arr
-mreplicateScalP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a)
-mreplicateScalP sh x = M_Primitive sh (X.replicateScal sh x)
+mreplicatePrimP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a)
+mreplicatePrimP sh x = M_Primitive sh (X.replicateScal sh x)
-mreplicateScal :: forall sh a. PrimElt a
+mreplicatePrim :: forall sh a. PrimElt a
=> IShX sh -> a -> Mixed sh a
-mreplicateScal sh x = fromPrimitive (mreplicateScalP sh x)
+mreplicatePrim sh x = fromPrimitive (mreplicatePrimP sh x)
msliceN :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a
msliceN i n arr = mlift (ssxFromShX (mshape arr)) (\_ -> X.sliceU i n) arr
diff --git a/src/Data/Array/Nested/Ranked.hs b/src/Data/Array/Nested/Ranked.hs
index 5cda531..2fbfdd8 100644
--- a/src/Data/Array/Nested/Ranked.hs
+++ b/src/Data/Array/Nested/Ranked.hs
@@ -81,16 +81,19 @@ rlift2 :: forall n1 n2 n3 a. Elt a
-> Ranked n1 a -> Ranked n2 a -> Ranked n3 a
rlift2 sn3 f (Ranked arr1) (Ranked arr2) = Ranked (mlift2 (ssxFromSNat sn3) f arr1 arr2)
-rsumOuter1P :: forall n a.
- (Storable a, NumElt a)
- => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a)
-rsumOuter1P (Ranked arr)
+rsumOuter1PrimP :: forall n a.
+ (Storable a, NumElt a)
+ => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a)
+rsumOuter1PrimP (Ranked arr)
| Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
- = Ranked (msumOuter1P arr)
+ = Ranked (msumOuter1PrimP arr)
-rsumOuter1 :: forall n a. (NumElt a, PrimElt a)
- => Ranked (n + 1) a -> Ranked n a
-rsumOuter1 = rfromPrimitive . rsumOuter1P . rtoPrimitive
+rsumOuter1Prim :: forall n a. (NumElt a, PrimElt a)
+ => Ranked (n + 1) a -> Ranked n a
+rsumOuter1Prim = rfromPrimitive . rsumOuter1PrimP . rtoPrimitive
+
+rsumAllPrimP :: (Storable a, NumElt a) => Ranked n (Primitive a) -> a
+rsumAllPrimP (Ranked arr) = msumAllPrimP arr
rsumAllPrim :: (PrimElt a, NumElt a) => Ranked n a -> a
rsumAllPrim (Ranked arr) = msumAllPrim arr
@@ -228,16 +231,14 @@ rzip = coerce mzip
runzip :: Ranked n (a, b) -> (Ranked n a, Ranked n b)
runzip = coerce munzip
-rrerankP :: forall n1 n2 n a b. (Storable a, Storable b)
- => SNat n -> IShR n2
- -> (Ranked n1 (Primitive a) -> Ranked n2 (Primitive b))
- -> Ranked (n + n1) (Primitive a) -> Ranked (n + n2) (Primitive b)
-rrerankP sn sh2 f (Ranked arr)
- | Refl <- lemReplicatePlusApp sn (Proxy @n1) (Proxy @(Nothing @Nat))
- , Refl <- lemReplicatePlusApp sn (Proxy @n2) (Proxy @(Nothing @Nat))
- = Ranked (mrerankP (ssxFromSNat sn) (shxFromShR sh2)
- (\a -> let Ranked r = f (Ranked a) in r)
- arr)
+rrerankPrimP :: forall n1 n2 n a b. (Storable a, Storable b)
+ => IShR n2
+ -> (Ranked n1 (Primitive a) -> Ranked n2 (Primitive b))
+ -> Ranked n (Ranked n1 (Primitive a)) -> Ranked n (Ranked n2 (Primitive b))
+rrerankPrimP sh2 f (Ranked (M_Ranked arr))
+ = Ranked (M_Ranked (mrerankPrimP (shxFromShR sh2)
+ (\a -> let Ranked r = f (Ranked a) in r)
+ arr))
-- | If there is a zero-sized dimension in the @n@-prefix of the shape of the
-- input array, then there is no way to deduce the full shape of the output
@@ -248,26 +249,28 @@ rrerankP sn sh2 f (Ranked arr)
-- For example, if:
--
-- @
--- arr :: Ranked 5 Int -- of shape [3, 0, 4, 2, 21]
+-- arr :: Ranked 3 (Ranked 2 Int) -- outer array shape [3, 0, 4]; inner shape [2, 21]
-- f :: Ranked 2 Int -> Ranked 3 Float
-- @
--
-- then:
--
-- @
--- rrerank _ _ _ f arr :: Ranked 6 Float
+-- rrerank _ f arr :: Ranked 3 (Ranked 3 Float)
-- @
--
--- and this result will have shape @[3, 0, 4, 0, 0, 0]@. Note that the
--- "reranked" part (the last 3 entries) are zero; we don't know if @f@ intended
--- to return an array with shape all-0 here (it probably didn't), but there is
--- no better number to put here absent a subarray of the input to pass to @f@.
-rrerank :: forall n1 n2 n a b. (PrimElt a, PrimElt b)
- => SNat n -> IShR n2
- -> (Ranked n1 a -> Ranked n2 b)
- -> Ranked (n + n1) a -> Ranked (n + n2) b
-rrerank sn sh2 f (rtoPrimitive -> arr) =
- rfromPrimitive $ rrerankP sn sh2 (rtoPrimitive . f . rfromPrimitive) arr
+-- and the inner arrays of the result will have shape @[0, 0, 0]@. We don't
+-- know if @f@ intended to return an array with all-zero shape here (it
+-- probably didn't), but there is no better number to put here absent a
+-- subarray of the input to pass to @f@.
+rrerankPrim :: forall n1 n2 n a b. (PrimElt a, PrimElt b)
+ => IShR n2
+ -> (Ranked n1 a -> Ranked n2 b)
+ -> Ranked n (Ranked n1 a) -> Ranked n (Ranked n2 b)
+rrerankPrim sh2 f (Ranked (M_Ranked arr)) =
+ Ranked (M_Ranked (mrerankPrim (shxFromShR sh2)
+ (\a -> let Ranked r = f (Ranked a) in r)
+ arr))
rreplicate :: forall n m a. Elt a
=> IShR n -> Ranked m a -> Ranked (n + m) a
@@ -275,14 +278,14 @@ rreplicate sh (Ranked arr)
| Refl <- lemReplicatePlusApp (shrRank sh) (Proxy @m) (Proxy @(Nothing @Nat))
= Ranked (mreplicate (shxFromShR sh) arr)
-rreplicateScalP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a)
-rreplicateScalP sh x
+rreplicatePrimP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a)
+rreplicatePrimP sh x
| Dict <- lemKnownReplicate (shrRank sh)
- = Ranked (mreplicateScalP (shxFromShR sh) x)
+ = Ranked (mreplicatePrimP (shxFromShR sh) x)
-rreplicateScal :: forall n a. PrimElt a
+rreplicatePrim :: forall n a. PrimElt a
=> IShR n -> a -> Ranked n a
-rreplicateScal sh x = rfromPrimitive (rreplicateScalP sh x)
+rreplicatePrim sh x = rfromPrimitive (rreplicatePrimP sh x)
rslice :: forall n a. Elt a => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a
rslice i n (Ranked arr)
@@ -290,12 +293,9 @@ rslice i n (Ranked arr)
= Ranked (msliceN i n arr)
rrev1 :: forall n a. Elt a => Ranked (n + 1) a -> Ranked (n + 1) a
-rrev1 arr =
- rlift (rrank arr)
- (\(_ :: StaticShX sh') ->
- case lemReplicateSucc @(Nothing @Nat) (Proxy @n) of
- Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh'))
- arr
+rrev1 (Ranked arr)
+ | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
+ = Ranked (mrev1 arr)
rreshape :: forall n n' a. Elt a
=> IShR n' -> Ranked n a -> Ranked n' a
diff --git a/src/Data/Array/Nested/Ranked/Base.hs b/src/Data/Array/Nested/Ranked/Base.hs
index e8aba3f..9df508d 100644
--- a/src/Data/Array/Nested/Ranked/Base.hs
+++ b/src/Data/Array/Nested/Ranked/Base.hs
@@ -179,10 +179,10 @@ instance Elt a => Elt (Ranked n a) where
instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where
memptyArrayUnsafe :: forall sh. IShX sh -> Mixed sh (Ranked n a)
- memptyArrayUnsafe i
+ memptyArrayUnsafe sh
| Dict <- lemKnownReplicate (SNat @n)
= coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $
- memptyArrayUnsafe i
+ memptyArrayUnsafe sh
mvecsUnsafeNew idx (Ranked arr)
| Dict <- lemKnownReplicate (SNat @n)
diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs
index 4a3ed8d..e635f03 100644
--- a/src/Data/Array/Nested/Shaped.hs
+++ b/src/Data/Array/Nested/Shaped.hs
@@ -84,13 +84,16 @@ slift2 :: forall sh1 sh2 sh3 a. Elt a
-> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a
slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (ssxFromShX (shxFromShS sh3)) f arr1 arr2)
-ssumOuter1P :: forall sh n a. (Storable a, NumElt a)
- => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a)
-ssumOuter1P (Shaped arr) = Shaped (msumOuter1P arr)
+ssumOuter1PrimP :: forall sh n a. (Storable a, NumElt a)
+ => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a)
+ssumOuter1PrimP (Shaped arr) = Shaped (msumOuter1PrimP arr)
-ssumOuter1 :: forall sh n a. (NumElt a, PrimElt a)
- => Shaped (n : sh) a -> Shaped sh a
-ssumOuter1 = sfromPrimitive . ssumOuter1P . stoPrimitive
+ssumOuter1Prim :: forall sh n a. (NumElt a, PrimElt a)
+ => Shaped (n : sh) a -> Shaped sh a
+ssumOuter1Prim = sfromPrimitive . ssumOuter1PrimP . stoPrimitive
+
+ssumAllPrimP :: (PrimElt a, NumElt a) => Shaped n (Primitive a) -> a
+ssumAllPrimP (Shaped arr) = msumAllPrimP arr
ssumAllPrim :: (PrimElt a, NumElt a) => Shaped n a -> a
ssumAllPrim (Shaped arr) = msumAllPrim arr
@@ -191,36 +194,35 @@ szip = coerce mzip
sunzip :: Shaped sh (a, b) -> (Shaped sh a, Shaped sh b)
sunzip = coerce munzip
-srerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b)
- => ShS sh -> ShS sh2
- -> (Shaped sh1 (Primitive a) -> Shaped sh2 (Primitive b))
- -> Shaped (sh ++ sh1) (Primitive a) -> Shaped (sh ++ sh2) (Primitive b)
-srerankP sh sh2 f sarr@(Shaped arr)
- | Refl <- lemMapJustApp sh (Proxy @sh1)
- , Refl <- lemMapJustApp sh (Proxy @sh2)
- = Shaped (mrerankP (ssxFromShX (shxTakeSSX (Proxy @(MapJust sh1)) (ssxFromShX (shxFromShS sh)) (shxFromShS (sshape sarr))))
- (shxFromShS sh2)
- (\a -> let Shaped r = f (Shaped a) in r)
- arr)
+srerankPrimP :: forall sh1 sh2 sh a b. (Storable a, Storable b)
+ => ShS sh2
+ -> (Shaped sh1 (Primitive a) -> Shaped sh2 (Primitive b))
+ -> Shaped sh (Shaped sh1 (Primitive a)) -> Shaped sh (Shaped sh2 (Primitive b))
+srerankPrimP sh2 f (Shaped (M_Shaped arr))
+ = Shaped (M_Shaped (mrerankPrimP (shxFromShS sh2)
+ (\a -> let Shaped r = f (Shaped a) in r)
+ arr))
--- | See the caveats at 'Data.Array.XArray.rerank'.
-srerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b)
- => ShS sh -> ShS sh2
- -> (Shaped sh1 a -> Shaped sh2 b)
- -> Shaped (sh ++ sh1) a -> Shaped (sh ++ sh2) b
-srerank sh sh2 f (stoPrimitive -> arr) =
- sfromPrimitive $ srerankP sh sh2 (stoPrimitive . f . sfromPrimitive) arr
+-- | See the caveats at 'mrerankPrim'.
+srerankPrim :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b)
+ => ShS sh2
+ -> (Shaped sh1 a -> Shaped sh2 b)
+ -> Shaped sh (Shaped sh1 a) -> Shaped sh (Shaped sh2 b)
+srerankPrim sh2 f (Shaped (M_Shaped arr)) =
+ Shaped (M_Shaped (mrerankPrim (shxFromShS sh2)
+ (\a -> let Shaped r = f (Shaped a) in r)
+ arr))
sreplicate :: forall sh sh' a. Elt a => ShS sh -> Shaped sh' a -> Shaped (sh ++ sh') a
sreplicate sh (Shaped arr)
| Refl <- lemMapJustApp sh (Proxy @sh')
= Shaped (mreplicate (shxFromShS sh) arr)
-sreplicateScalP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a)
-sreplicateScalP sh x = Shaped (mreplicateScalP (shxFromShS sh) x)
+sreplicatePrimP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a)
+sreplicatePrimP sh x = Shaped (mreplicatePrimP (shxFromShS sh) x)
-sreplicateScal :: PrimElt a => ShS sh -> a -> Shaped sh a
-sreplicateScal sh x = sfromPrimitive (sreplicateScalP sh x)
+sreplicatePrim :: forall sh a. PrimElt a => ShS sh -> a -> Shaped sh a
+sreplicatePrim sh x = sfromPrimitive (sreplicatePrimP sh x)
sslice :: Elt a => SNat i -> SNat n -> Shaped (i + n + k : sh) a -> Shaped (n : sh) a
sslice i n@SNat arr =
diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs
index b313b2d..e70d84e 100644
--- a/src/Data/Array/Nested/Shaped/Base.hs
+++ b/src/Data/Array/Nested/Shaped/Base.hs
@@ -172,10 +172,10 @@ instance Elt a => Elt (Shaped sh a) where
instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where
memptyArrayUnsafe :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a)
- memptyArrayUnsafe i
+ memptyArrayUnsafe sh
| Dict <- lemKnownMapJust (Proxy @sh)
= coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $
- memptyArrayUnsafe i
+ memptyArrayUnsafe sh
mvecsUnsafeNew idx (Shaped arr)
| Dict <- lemKnownMapJust (Proxy @sh)
diff --git a/src/Data/Array/Nested/Trace.hs b/src/Data/Array/Nested/Trace.hs
index 6a2890f..37c31ea 100644
--- a/src/Data/Array/Nested/Trace.hs
+++ b/src/Data/Array/Nested/Trace.hs
@@ -69,4 +69,4 @@ import Data.Array.Nested.Trace.TH
$(concat <$> mapM convertFun
- ['rshape, 'rrank, 'rsize, 'rindex, 'rindexPartial, 'rgenerate, 'rsumOuter1, 'rsumAllPrim, 'rtranspose, 'rappend, 'rconcat, 'rscalar, 'rfromVector, 'rtoVector, 'runScalar, 'remptyArray, 'rrerank, 'rreplicate, 'rreplicateScal, 'rfromList1, 'rfromList1N, 'rfromListOuter, 'rfromListOuterN, 'rfromListLinear, 'rfromList1Prim, 'rfromList1PrimN, 'rfromListPrimLinear, 'rtoList, 'rtoListOuter, 'rtoListLinear, 'rslice, 'rrev1, 'rreshape, 'rflatten, 'riota, 'rminIndexPrim, 'rmaxIndexPrim, 'rdot1Inner, 'rdot, 'rnest, 'runNest, 'rzip, 'runzip, 'rlift, 'rlift2, 'rtoXArrayPrim, 'rfromXArrayPrim, 'rtoMixed, 'rcastToMixed, 'rcastToShaped, 'rfromOrthotope, 'rtoOrthotope, 'rquotArray, 'rremArray, 'ratan2Array, 'sshape, 'srank, 'ssize, 'sindex, 'sindexPartial, 'sgenerate, 'ssumOuter1, 'ssumAllPrim, 'stranspose, 'sappend, 'sscalar, 'sfromVector, 'stoVector, 'sunScalar, 'semptyArray, 'srerank, 'sreplicate, 'sreplicateScal, 'sfromList1, 'sfromListOuter, 'sfromListLinear, 'sfromList1Prim, 'sfromListPrimLinear, 'stoList, 'stoListOuter, 'stoListLinear, 'sslice, 'srev1, 'sreshape, 'sflatten, 'siota, 'sminIndexPrim, 'smaxIndexPrim, 'sdot1Inner, 'sdot, 'snest, 'sunNest, 'szip, 'sunzip, 'slift, 'slift2, 'stoXArrayPrim, 'sfromXArrayPrim, 'stoMixed, 'scastToMixed, 'stoRanked, 'sfromOrthotope, 'stoOrthotope, 'squotArray, 'sremArray, 'satan2Array, 'mshape, 'mrank, 'msize, 'mindex, 'mindexPartial, 'mgenerate, 'msumOuter1, 'msumAllPrim, 'mtranspose, 'mappend, 'mconcat, 'mscalar, 'mfromVector, 'mtoVector, 'munScalar, 'memptyArray, 'mrerank, 'mreplicate, 'mreplicateScal, 'mfromList1, 'mfromList1N, 'mfromList1SN, 'mfromListOuter, 'mfromListOuterN, 'mfromListOuterSN, 'mfromListLinear, 'mfromList1Prim, 'mfromList1PrimN, 'mfromList1PrimSN, 'mfromListPrimLinear, 'mtoList, 'mtoListOuter, 'mtoListLinear, 'msliceN, 'msliceSN, 'mrev1, 'mreshape, 'mflatten, 'miota, 'mminIndexPrim, 'mmaxIndexPrim, 'mdot1Inner, 'mdot, 'mnest, 'munNest, 'mzip, 'munzip, 'mlift, 'mlift2, 'mtoXArrayPrim, 'mfromXArrayPrim, 'mcast, 'mcastToShaped, 'mtoRanked, 'convert, 'mquotArray, 'mremArray, 'matan2Array])
+ ['rshape, 'rrank, 'rsize, 'rindex, 'rindexPartial, 'rgenerate, 'rsumOuter1Prim, 'rsumAllPrim, 'rtranspose, 'rappend, 'rconcat, 'rscalar, 'rfromVector, 'rtoVector, 'runScalar, 'remptyArray, 'rrerankPrim, 'rreplicate, 'rreplicatePrim, 'rfromListOuter, 'rfromListOuterN, 'rfromList1, 'rfromList1N, 'rfromListLinear, 'rfromList1Prim, 'rfromList1PrimN, 'rfromListPrimLinear, 'rtoList, 'rtoListOuter, 'rtoListLinear, 'rslice, 'rrev1, 'rreshape, 'rflatten, 'riota, 'rminIndexPrim, 'rmaxIndexPrim, 'rdot1Inner, 'rdot, 'rnest, 'runNest, 'rzip, 'runzip, 'rlift, 'rlift2, 'rtoXArrayPrim, 'rfromXArrayPrim, 'rtoMixed, 'rcastToMixed, 'rcastToShaped, 'rfromOrthotope, 'rtoOrthotope, 'rquotArray, 'rremArray, 'ratan2Array, 'sshape, 'srank, 'ssize, 'sindex, 'sindexPartial, 'sgenerate, 'ssumOuter1Prim, 'ssumAllPrim, 'stranspose, 'sappend, 'sscalar, 'sfromVector, 'stoVector, 'sunScalar, 'semptyArray, 'srerankPrim, 'sreplicate, 'sreplicatePrim, 'sfromListOuter, 'sfromList1, 'sfromListLinear, 'sfromList1Prim, 'sfromListPrimLinear, 'stoList, 'stoListOuter, 'stoListLinear, 'sslice, 'srev1, 'sreshape, 'sflatten, 'siota, 'sminIndexPrim, 'smaxIndexPrim, 'sdot1Inner, 'sdot, 'snest, 'sunNest, 'szip, 'sunzip, 'slift, 'slift2, 'stoXArrayPrim, 'sfromXArrayPrim, 'stoMixed, 'scastToMixed, 'stoRanked, 'sfromOrthotope, 'stoOrthotope, 'squotArray, 'sremArray, 'satan2Array, 'mshape, 'mrank, 'msize, 'mindex, 'mindexPartial, 'mgenerate, 'msumOuter1Prim, 'msumAllPrim, 'mtranspose, 'mappend, 'mconcat, 'mscalar, 'mfromVector, 'mtoVector, 'munScalar, 'memptyArray, 'mrerankPrim, 'mreplicate, 'mreplicatePrim, 'mfromListOuter, 'mfromListOuterN, 'mfromListOuterSN, 'mfromList1, 'mfromList1N, 'mfromList1SN, 'mfromListLinear, 'mfromList1Prim, 'mfromList1PrimN, 'mfromList1PrimSN, 'mfromListPrimLinear, 'mtoList, 'mtoListOuter, 'mtoListLinear, 'msliceN, 'msliceSN, 'mrev1, 'mreshape, 'mflatten, 'miota, 'mminIndexPrim, 'mmaxIndexPrim, 'mdot1Inner, 'mdot, 'mnest, 'munNest, 'mzip, 'munzip, 'mlift, 'mlift2, 'mtoXArrayPrim, 'mfromXArrayPrim, 'mcast, 'mcastToShaped, 'mtoRanked, 'convert, 'mquotArray, 'mremArray, 'matan2Array])