aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Ranked.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Ranked.hs')
-rw-r--r--src/Data/Array/Nested/Ranked.hs82
1 files changed, 41 insertions, 41 deletions
diff --git a/src/Data/Array/Nested/Ranked.hs b/src/Data/Array/Nested/Ranked.hs
index 5cda531..2fbfdd8 100644
--- a/src/Data/Array/Nested/Ranked.hs
+++ b/src/Data/Array/Nested/Ranked.hs
@@ -81,16 +81,19 @@ rlift2 :: forall n1 n2 n3 a. Elt a
-> Ranked n1 a -> Ranked n2 a -> Ranked n3 a
rlift2 sn3 f (Ranked arr1) (Ranked arr2) = Ranked (mlift2 (ssxFromSNat sn3) f arr1 arr2)
-rsumOuter1P :: forall n a.
- (Storable a, NumElt a)
- => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a)
-rsumOuter1P (Ranked arr)
+rsumOuter1PrimP :: forall n a.
+ (Storable a, NumElt a)
+ => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a)
+rsumOuter1PrimP (Ranked arr)
| Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
- = Ranked (msumOuter1P arr)
+ = Ranked (msumOuter1PrimP arr)
-rsumOuter1 :: forall n a. (NumElt a, PrimElt a)
- => Ranked (n + 1) a -> Ranked n a
-rsumOuter1 = rfromPrimitive . rsumOuter1P . rtoPrimitive
+rsumOuter1Prim :: forall n a. (NumElt a, PrimElt a)
+ => Ranked (n + 1) a -> Ranked n a
+rsumOuter1Prim = rfromPrimitive . rsumOuter1PrimP . rtoPrimitive
+
+rsumAllPrimP :: (Storable a, NumElt a) => Ranked n (Primitive a) -> a
+rsumAllPrimP (Ranked arr) = msumAllPrimP arr
rsumAllPrim :: (PrimElt a, NumElt a) => Ranked n a -> a
rsumAllPrim (Ranked arr) = msumAllPrim arr
@@ -228,16 +231,14 @@ rzip = coerce mzip
runzip :: Ranked n (a, b) -> (Ranked n a, Ranked n b)
runzip = coerce munzip
-rrerankP :: forall n1 n2 n a b. (Storable a, Storable b)
- => SNat n -> IShR n2
- -> (Ranked n1 (Primitive a) -> Ranked n2 (Primitive b))
- -> Ranked (n + n1) (Primitive a) -> Ranked (n + n2) (Primitive b)
-rrerankP sn sh2 f (Ranked arr)
- | Refl <- lemReplicatePlusApp sn (Proxy @n1) (Proxy @(Nothing @Nat))
- , Refl <- lemReplicatePlusApp sn (Proxy @n2) (Proxy @(Nothing @Nat))
- = Ranked (mrerankP (ssxFromSNat sn) (shxFromShR sh2)
- (\a -> let Ranked r = f (Ranked a) in r)
- arr)
+rrerankPrimP :: forall n1 n2 n a b. (Storable a, Storable b)
+ => IShR n2
+ -> (Ranked n1 (Primitive a) -> Ranked n2 (Primitive b))
+ -> Ranked n (Ranked n1 (Primitive a)) -> Ranked n (Ranked n2 (Primitive b))
+rrerankPrimP sh2 f (Ranked (M_Ranked arr))
+ = Ranked (M_Ranked (mrerankPrimP (shxFromShR sh2)
+ (\a -> let Ranked r = f (Ranked a) in r)
+ arr))
-- | If there is a zero-sized dimension in the @n@-prefix of the shape of the
-- input array, then there is no way to deduce the full shape of the output
@@ -248,26 +249,28 @@ rrerankP sn sh2 f (Ranked arr)
-- For example, if:
--
-- @
--- arr :: Ranked 5 Int -- of shape [3, 0, 4, 2, 21]
+-- arr :: Ranked 3 (Ranked 2 Int) -- outer array shape [3, 0, 4]; inner shape [2, 21]
-- f :: Ranked 2 Int -> Ranked 3 Float
-- @
--
-- then:
--
-- @
--- rrerank _ _ _ f arr :: Ranked 6 Float
+-- rrerank _ f arr :: Ranked 3 (Ranked 3 Float)
-- @
--
--- and this result will have shape @[3, 0, 4, 0, 0, 0]@. Note that the
--- "reranked" part (the last 3 entries) are zero; we don't know if @f@ intended
--- to return an array with shape all-0 here (it probably didn't), but there is
--- no better number to put here absent a subarray of the input to pass to @f@.
-rrerank :: forall n1 n2 n a b. (PrimElt a, PrimElt b)
- => SNat n -> IShR n2
- -> (Ranked n1 a -> Ranked n2 b)
- -> Ranked (n + n1) a -> Ranked (n + n2) b
-rrerank sn sh2 f (rtoPrimitive -> arr) =
- rfromPrimitive $ rrerankP sn sh2 (rtoPrimitive . f . rfromPrimitive) arr
+-- and the inner arrays of the result will have shape @[0, 0, 0]@. We don't
+-- know if @f@ intended to return an array with all-zero shape here (it
+-- probably didn't), but there is no better number to put here absent a
+-- subarray of the input to pass to @f@.
+rrerankPrim :: forall n1 n2 n a b. (PrimElt a, PrimElt b)
+ => IShR n2
+ -> (Ranked n1 a -> Ranked n2 b)
+ -> Ranked n (Ranked n1 a) -> Ranked n (Ranked n2 b)
+rrerankPrim sh2 f (Ranked (M_Ranked arr)) =
+ Ranked (M_Ranked (mrerankPrim (shxFromShR sh2)
+ (\a -> let Ranked r = f (Ranked a) in r)
+ arr))
rreplicate :: forall n m a. Elt a
=> IShR n -> Ranked m a -> Ranked (n + m) a
@@ -275,14 +278,14 @@ rreplicate sh (Ranked arr)
| Refl <- lemReplicatePlusApp (shrRank sh) (Proxy @m) (Proxy @(Nothing @Nat))
= Ranked (mreplicate (shxFromShR sh) arr)
-rreplicateScalP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a)
-rreplicateScalP sh x
+rreplicatePrimP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a)
+rreplicatePrimP sh x
| Dict <- lemKnownReplicate (shrRank sh)
- = Ranked (mreplicateScalP (shxFromShR sh) x)
+ = Ranked (mreplicatePrimP (shxFromShR sh) x)
-rreplicateScal :: forall n a. PrimElt a
+rreplicatePrim :: forall n a. PrimElt a
=> IShR n -> a -> Ranked n a
-rreplicateScal sh x = rfromPrimitive (rreplicateScalP sh x)
+rreplicatePrim sh x = rfromPrimitive (rreplicatePrimP sh x)
rslice :: forall n a. Elt a => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a
rslice i n (Ranked arr)
@@ -290,12 +293,9 @@ rslice i n (Ranked arr)
= Ranked (msliceN i n arr)
rrev1 :: forall n a. Elt a => Ranked (n + 1) a -> Ranked (n + 1) a
-rrev1 arr =
- rlift (rrank arr)
- (\(_ :: StaticShX sh') ->
- case lemReplicateSucc @(Nothing @Nat) (Proxy @n) of
- Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh'))
- arr
+rrev1 (Ranked arr)
+ | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
+ = Ranked (mrev1 arr)
rreshape :: forall n n' a. Elt a
=> IShR n' -> Ranked n a -> Ranked n' a