diff options
Diffstat (limited to 'src/Data/Array/Nested/Ranked.hs')
| -rw-r--r-- | src/Data/Array/Nested/Ranked.hs | 82 |
1 files changed, 41 insertions, 41 deletions
diff --git a/src/Data/Array/Nested/Ranked.hs b/src/Data/Array/Nested/Ranked.hs index 5cda531..2fbfdd8 100644 --- a/src/Data/Array/Nested/Ranked.hs +++ b/src/Data/Array/Nested/Ranked.hs @@ -81,16 +81,19 @@ rlift2 :: forall n1 n2 n3 a. Elt a -> Ranked n1 a -> Ranked n2 a -> Ranked n3 a rlift2 sn3 f (Ranked arr1) (Ranked arr2) = Ranked (mlift2 (ssxFromSNat sn3) f arr1 arr2) -rsumOuter1P :: forall n a. - (Storable a, NumElt a) - => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a) -rsumOuter1P (Ranked arr) +rsumOuter1PrimP :: forall n a. + (Storable a, NumElt a) + => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a) +rsumOuter1PrimP (Ranked arr) | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) - = Ranked (msumOuter1P arr) + = Ranked (msumOuter1PrimP arr) -rsumOuter1 :: forall n a. (NumElt a, PrimElt a) - => Ranked (n + 1) a -> Ranked n a -rsumOuter1 = rfromPrimitive . rsumOuter1P . rtoPrimitive +rsumOuter1Prim :: forall n a. (NumElt a, PrimElt a) + => Ranked (n + 1) a -> Ranked n a +rsumOuter1Prim = rfromPrimitive . rsumOuter1PrimP . rtoPrimitive + +rsumAllPrimP :: (Storable a, NumElt a) => Ranked n (Primitive a) -> a +rsumAllPrimP (Ranked arr) = msumAllPrimP arr rsumAllPrim :: (PrimElt a, NumElt a) => Ranked n a -> a rsumAllPrim (Ranked arr) = msumAllPrim arr @@ -228,16 +231,14 @@ rzip = coerce mzip runzip :: Ranked n (a, b) -> (Ranked n a, Ranked n b) runzip = coerce munzip -rrerankP :: forall n1 n2 n a b. (Storable a, Storable b) - => SNat n -> IShR n2 - -> (Ranked n1 (Primitive a) -> Ranked n2 (Primitive b)) - -> Ranked (n + n1) (Primitive a) -> Ranked (n + n2) (Primitive b) -rrerankP sn sh2 f (Ranked arr) - | Refl <- lemReplicatePlusApp sn (Proxy @n1) (Proxy @(Nothing @Nat)) - , Refl <- lemReplicatePlusApp sn (Proxy @n2) (Proxy @(Nothing @Nat)) - = Ranked (mrerankP (ssxFromSNat sn) (shxFromShR sh2) - (\a -> let Ranked r = f (Ranked a) in r) - arr) +rrerankPrimP :: forall n1 n2 n a b. (Storable a, Storable b) + => IShR n2 + -> (Ranked n1 (Primitive a) -> Ranked n2 (Primitive b)) + -> Ranked n (Ranked n1 (Primitive a)) -> Ranked n (Ranked n2 (Primitive b)) +rrerankPrimP sh2 f (Ranked (M_Ranked arr)) + = Ranked (M_Ranked (mrerankPrimP (shxFromShR sh2) + (\a -> let Ranked r = f (Ranked a) in r) + arr)) -- | If there is a zero-sized dimension in the @n@-prefix of the shape of the -- input array, then there is no way to deduce the full shape of the output @@ -248,26 +249,28 @@ rrerankP sn sh2 f (Ranked arr) -- For example, if: -- -- @ --- arr :: Ranked 5 Int -- of shape [3, 0, 4, 2, 21] +-- arr :: Ranked 3 (Ranked 2 Int) -- outer array shape [3, 0, 4]; inner shape [2, 21] -- f :: Ranked 2 Int -> Ranked 3 Float -- @ -- -- then: -- -- @ --- rrerank _ _ _ f arr :: Ranked 6 Float +-- rrerank _ f arr :: Ranked 3 (Ranked 3 Float) -- @ -- --- and this result will have shape @[3, 0, 4, 0, 0, 0]@. Note that the --- "reranked" part (the last 3 entries) are zero; we don't know if @f@ intended --- to return an array with shape all-0 here (it probably didn't), but there is --- no better number to put here absent a subarray of the input to pass to @f@. -rrerank :: forall n1 n2 n a b. (PrimElt a, PrimElt b) - => SNat n -> IShR n2 - -> (Ranked n1 a -> Ranked n2 b) - -> Ranked (n + n1) a -> Ranked (n + n2) b -rrerank sn sh2 f (rtoPrimitive -> arr) = - rfromPrimitive $ rrerankP sn sh2 (rtoPrimitive . f . rfromPrimitive) arr +-- and the inner arrays of the result will have shape @[0, 0, 0]@. We don't +-- know if @f@ intended to return an array with all-zero shape here (it +-- probably didn't), but there is no better number to put here absent a +-- subarray of the input to pass to @f@. +rrerankPrim :: forall n1 n2 n a b. (PrimElt a, PrimElt b) + => IShR n2 + -> (Ranked n1 a -> Ranked n2 b) + -> Ranked n (Ranked n1 a) -> Ranked n (Ranked n2 b) +rrerankPrim sh2 f (Ranked (M_Ranked arr)) = + Ranked (M_Ranked (mrerankPrim (shxFromShR sh2) + (\a -> let Ranked r = f (Ranked a) in r) + arr)) rreplicate :: forall n m a. Elt a => IShR n -> Ranked m a -> Ranked (n + m) a @@ -275,14 +278,14 @@ rreplicate sh (Ranked arr) | Refl <- lemReplicatePlusApp (shrRank sh) (Proxy @m) (Proxy @(Nothing @Nat)) = Ranked (mreplicate (shxFromShR sh) arr) -rreplicateScalP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a) -rreplicateScalP sh x +rreplicatePrimP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a) +rreplicatePrimP sh x | Dict <- lemKnownReplicate (shrRank sh) - = Ranked (mreplicateScalP (shxFromShR sh) x) + = Ranked (mreplicatePrimP (shxFromShR sh) x) -rreplicateScal :: forall n a. PrimElt a +rreplicatePrim :: forall n a. PrimElt a => IShR n -> a -> Ranked n a -rreplicateScal sh x = rfromPrimitive (rreplicateScalP sh x) +rreplicatePrim sh x = rfromPrimitive (rreplicatePrimP sh x) rslice :: forall n a. Elt a => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a rslice i n (Ranked arr) @@ -290,12 +293,9 @@ rslice i n (Ranked arr) = Ranked (msliceN i n arr) rrev1 :: forall n a. Elt a => Ranked (n + 1) a -> Ranked (n + 1) a -rrev1 arr = - rlift (rrank arr) - (\(_ :: StaticShX sh') -> - case lemReplicateSucc @(Nothing @Nat) (Proxy @n) of - Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh')) - arr +rrev1 (Ranked arr) + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) + = Ranked (mrev1 arr) rreshape :: forall n n' a. Elt a => IShR n' -> Ranked n a -> Ranked n' a |
