diff options
Diffstat (limited to 'src/Data/Array/Nested/Ranked.hs')
| -rw-r--r-- | src/Data/Array/Nested/Ranked.hs | 182 |
1 files changed, 108 insertions, 74 deletions
diff --git a/src/Data/Array/Nested/Ranked.hs b/src/Data/Array/Nested/Ranked.hs index e5b8970..d687983 100644 --- a/src/Data/Array/Nested/Ranked.hs +++ b/src/Data/Array/Nested/Ranked.hs @@ -49,9 +49,11 @@ remptyArray = mtoRanked (memptyArray ZSX) rsize :: Elt a => Ranked n a -> Int rsize = shrSize . rshape +{-# INLINEABLE rindex #-} rindex :: Elt a => Ranked n a -> IIxR n -> a rindex (Ranked arr) idx = mindex arr (ixxFromIxR idx) +{-# INLINEABLE rindexPartial #-} rindexPartial :: forall n m a. Elt a => Ranked (n + m) a -> IIxR n -> Ranked m a rindexPartial (Ranked arr) idx = Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing) @@ -59,7 +61,8 @@ rindexPartial (Ranked arr) idx = (ixxFromIxR idx)) -- | __WARNING__: All values returned from the function must have equal shape. --- See the documentation of 'mgenerate' for more details. +-- See the documentation of 'mgenerate' for more details; see also +-- 'rgeneratePrim'. rgenerate :: forall n a. KnownElt a => IShR n -> (IIxR n -> a) -> Ranked n a rgenerate sh f | sn@SNat <- shrRank sh @@ -67,6 +70,14 @@ rgenerate sh f , Refl <- lemRankReplicate sn = Ranked (mgenerate (shxFromShR sh) (f . ixrFromIxX)) +-- | See 'mgeneratePrim'. +{-# INLINE rgeneratePrim #-} +rgeneratePrim :: forall n a i. (PrimElt a, Num i) + => IShR n -> (IxR n i -> a) -> Ranked n a +rgeneratePrim sh f = + let g i = f (ixrFromLinear sh i) + in rfromVector sh $ VS.generate (shrSize sh) g + -- | See the documentation of 'mlift'. rlift :: forall n1 n2 a. Elt a => SNat n2 @@ -81,16 +92,19 @@ rlift2 :: forall n1 n2 n3 a. Elt a -> Ranked n1 a -> Ranked n2 a -> Ranked n3 a rlift2 sn3 f (Ranked arr1) (Ranked arr2) = Ranked (mlift2 (ssxFromSNat sn3) f arr1 arr2) -rsumOuter1P :: forall n a. - (Storable a, NumElt a) - => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a) -rsumOuter1P (Ranked arr) - | Refl <- lemReplicateSucc @(Nothing @Nat) @n - = Ranked (msumOuter1P arr) +rsumOuter1PrimP :: forall n a. + (Storable a, NumElt a) + => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a) +rsumOuter1PrimP (Ranked arr) + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) + = Ranked (msumOuter1PrimP arr) + +rsumOuter1Prim :: forall n a. (NumElt a, PrimElt a) + => Ranked (n + 1) a -> Ranked n a +rsumOuter1Prim = rfromPrimitive . rsumOuter1PrimP . rtoPrimitive -rsumOuter1 :: forall n a. (NumElt a, PrimElt a) - => Ranked (n + 1) a -> Ranked n a -rsumOuter1 = rfromPrimitive . rsumOuter1P . rtoPrimitive +rsumAllPrimP :: (Storable a, NumElt a) => Ranked n (Primitive a) -> a +rsumAllPrimP (Ranked arr) = msumAllPrimP arr rsumAllPrim :: (PrimElt a, NumElt a) => Ranked n a -> a rsumAllPrim (Ranked arr) = msumAllPrim arr @@ -108,7 +122,7 @@ rtranspose perm arr rconcat :: forall n a. Elt a => NonEmpty (Ranked (n + 1) a) -> Ranked (n + 1) a rconcat - | Refl <- lemReplicateSucc @(Nothing @Nat) @n + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) = coerce mconcat rappend :: forall n a. Elt a @@ -116,7 +130,7 @@ rappend :: forall n a. Elt a rappend arr1 arr2 | sn@SNat <- rrank arr1 , Dict <- lemKnownReplicate sn - , Refl <- lemReplicateSucc @(Nothing @Nat) @n + , Refl <- lemReplicateSucc @(Nothing @Nat) (SNat @n) = coerce (mappend @Nothing @Nothing @(Replicate n Nothing)) arr1 arr2 @@ -137,38 +151,63 @@ rtoVectorP = coerce mtoVectorP rtoVector :: PrimElt a => Ranked n a -> VS.Vector a rtoVector = coerce mtoVector +-- | All arrays in the list, even subarrays inside @a@, must have the same +-- shape; if they do not, a runtime error will be thrown. See the +-- documentation of 'mgenerate' for more information about this restriction. +-- +-- Because the length of the 'NonEmpty' list is unknown, its spine must be +-- materialised in memory in order to compute its length. If its length is +-- already known, use 'rfromListOuterN' to be able to stream the list. +-- +-- If your array is 1-dimensional and contains scalars, use 'rfromList1Prim'. rfromListOuter :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a rfromListOuter l - | Refl <- lemReplicateSucc @(Nothing @Nat) @n + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) = Ranked (mfromListOuter (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a))) +-- | See 'rfromListOuter'. If the list does not have the given length, a +-- runtime error is thrown. 'rfromList1PrimN' is faster if applicable. +rfromListOuterN :: forall n a. Elt a => Int -> NonEmpty (Ranked n a) -> Ranked (n + 1) a +rfromListOuterN n l + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) + = Ranked (mfromListOuterN n (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a))) + +-- | Because the length of the 'NonEmpty' list is unknown, its spine must be +-- materialised in memory in order to compute its length. If its length is +-- already known, use 'rfromList1N' to be able to stream the list. +-- +-- If the elements are scalars, 'rfromList1Prim' is faster. rfromList1 :: Elt a => NonEmpty a -> Ranked 1 a -rfromList1 l = Ranked (mfromList1 l) +rfromList1 = coerce mfromList1 -rfromList1Prim :: PrimElt a => [a] -> Ranked 1 a -rfromList1Prim l = Ranked (mfromList1Prim l) +-- | If the elements are scalars, 'rfromList1PrimN' is faster. A runtime error +-- is thrown if the list length does not match the given length. +rfromList1N :: Elt a => Int -> NonEmpty a -> Ranked 1 a +rfromList1N = coerce mfromList1N -rtoListOuter :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a] -rtoListOuter (Ranked arr) - | Refl <- lemReplicateSucc @(Nothing @Nat) @n - = coerce (mtoListOuter @a @Nothing @(Replicate n Nothing) arr) +-- | If the elements are scalars, 'rfromListPrimLinear' is faster. +rfromListLinear :: forall n a. Elt a => IShR n -> NonEmpty a -> Ranked n a +rfromListLinear sh l = Ranked (mfromListLinear (shxFromShR sh) l) -rtoList1 :: Elt a => Ranked 1 a -> [a] -rtoList1 = map runScalar . rtoListOuter +-- | Because the length of the list is unknown, its spine must be materialised +-- in memory in order to compute its length. If its length is already known, +-- use 'rfromList1PrimN' to be able to stream the list. +rfromList1Prim :: PrimElt a => [a] -> Ranked 1 a +rfromList1Prim = coerce mfromList1Prim -rfromListPrim :: PrimElt a => [a] -> Ranked 1 a -rfromListPrim l = - let ssh = SUnknown () :!% ZKX - xarr = X.fromList1 ssh l - in Ranked $ fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr +rfromList1PrimN :: PrimElt a => Int -> [a] -> Ranked 1 a +rfromList1PrimN = coerce mfromList1PrimN -rfromListPrimLinear :: PrimElt a => IShR n -> [a] -> Ranked n a -rfromListPrimLinear sh l = - let M_Primitive _ xarr = toPrimitive (mfromListPrim l) - in Ranked $ fromPrimitive $ M_Primitive (shxFromShR sh) (X.reshape (SUnknown () :!% ZKX) (shxFromShR sh) xarr) +rfromListPrimLinear :: forall n a. PrimElt a => IShR n -> [a] -> Ranked n a +rfromListPrimLinear sh l = Ranked (mfromListPrimLinear (shxFromShR sh) l) -rfromListLinear :: forall n a. Elt a => IShR n -> NonEmpty a -> Ranked n a -rfromListLinear sh l = rreshape sh (rfromList1 l) +rtoList :: Elt a => Ranked 1 a -> [a] +rtoList = map runScalar . rtoListOuter + +rtoListOuter :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a] +rtoListOuter (Ranked arr) + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) + = coerce (mtoListOuter @a @Nothing @(Replicate n Nothing) arr) rtoListLinear :: Elt a => Ranked n a -> [a] rtoListLinear (Ranked arr) = mtoListLinear arr @@ -179,9 +218,9 @@ rfromOrthotope sn arr = let xarr = XArray arr in Ranked (fromPrimitive (M_Primitive (X.shape (ssxFromSNat sn) xarr) xarr)) -rtoOrthotope :: PrimElt a => Ranked n a -> S.Array n a +rtoOrthotope :: forall a n. PrimElt a => Ranked n a -> S.Array n a rtoOrthotope (rtoPrimitive -> Ranked (M_Primitive sh (XArray arr))) - | Refl <- lemRankReplicate (shrRank $ shrFromShX2 sh) + | Refl <- lemRankReplicate (shrRank $ shrFromShX2 @n sh) = arr runScalar :: Elt a => Ranked 0 a -> a @@ -197,22 +236,20 @@ runNest rarr@(Ranked (M_Ranked (M_Nest _ arr))) | Refl <- lemReplicatePlusApp (rrank rarr) (Proxy @m) (Proxy @(Nothing @Nat)) = Ranked arr -rzip :: Ranked n a -> Ranked n b -> Ranked n (a, b) +rzip :: (Elt a, Elt b) => Ranked n a -> Ranked n b -> Ranked n (a, b) rzip = coerce mzip runzip :: Ranked n (a, b) -> (Ranked n a, Ranked n b) runzip = coerce munzip -rrerankP :: forall n1 n2 n a b. (Storable a, Storable b) - => SNat n -> IShR n2 - -> (Ranked n1 (Primitive a) -> Ranked n2 (Primitive b)) - -> Ranked (n + n1) (Primitive a) -> Ranked (n + n2) (Primitive b) -rrerankP sn sh2 f (Ranked arr) - | Refl <- lemReplicatePlusApp sn (Proxy @n1) (Proxy @(Nothing @Nat)) - , Refl <- lemReplicatePlusApp sn (Proxy @n2) (Proxy @(Nothing @Nat)) - = Ranked (mrerankP (ssxFromSNat sn) (shxFromShR sh2) - (\a -> let Ranked r = f (Ranked a) in r) - arr) +rrerankPrimP :: forall n1 n2 n a b. (Storable a, Storable b) + => IShR n2 + -> (Ranked n1 (Primitive a) -> Ranked n2 (Primitive b)) + -> Ranked n (Ranked n1 (Primitive a)) -> Ranked n (Ranked n2 (Primitive b)) +rrerankPrimP sh2 f (Ranked (M_Ranked arr)) + = Ranked (M_Ranked (mrerankPrimP (shxFromShR sh2) + (\a -> let Ranked r = f (Ranked a) in r) + arr)) -- | If there is a zero-sized dimension in the @n@-prefix of the shape of the -- input array, then there is no way to deduce the full shape of the output @@ -223,26 +260,28 @@ rrerankP sn sh2 f (Ranked arr) -- For example, if: -- -- @ --- arr :: Ranked 5 Int -- of shape [3, 0, 4, 2, 21] +-- arr :: Ranked 3 (Ranked 2 Int) -- outer array shape [3, 0, 4]; inner shape [2, 21] -- f :: Ranked 2 Int -> Ranked 3 Float -- @ -- -- then: -- -- @ --- rrerank _ _ _ f arr :: Ranked 5 Float +-- rrerank _ f arr :: Ranked 3 (Ranked 3 Float) -- @ -- --- and this result will have shape @[3, 0, 4, 0, 0, 0]@. Note that the --- "reranked" part (the last 3 entries) are zero; we don't know if @f@ intended --- to return an array with shape all-0 here (it probably didn't), but there is --- no better number to put here absent a subarray of the input to pass to @f@. -rrerank :: forall n1 n2 n a b. (PrimElt a, PrimElt b) - => SNat n -> IShR n2 - -> (Ranked n1 a -> Ranked n2 b) - -> Ranked (n + n1) a -> Ranked (n + n2) b -rrerank sn sh2 f (rtoPrimitive -> arr) = - rfromPrimitive $ rrerankP sn sh2 (rtoPrimitive . f . rfromPrimitive) arr +-- and the inner arrays of the result will have shape @[0, 0, 0]@. We don't +-- know if @f@ intended to return an array with all-zero shape here (it +-- probably didn't), but there is no better number to put here absent a +-- subarray of the input to pass to @f@. +rrerankPrim :: forall n1 n2 n a b. (PrimElt a, PrimElt b) + => IShR n2 + -> (Ranked n1 a -> Ranked n2 b) + -> Ranked n (Ranked n1 a) -> Ranked n (Ranked n2 b) +rrerankPrim sh2 f (Ranked (M_Ranked arr)) = + Ranked (M_Ranked (mrerankPrim (shxFromShR sh2) + (\a -> let Ranked r = f (Ranked a) in r) + arr)) rreplicate :: forall n m a. Elt a => IShR n -> Ranked m a -> Ranked (n + m) a @@ -250,29 +289,24 @@ rreplicate sh (Ranked arr) | Refl <- lemReplicatePlusApp (shrRank sh) (Proxy @m) (Proxy @(Nothing @Nat)) = Ranked (mreplicate (shxFromShR sh) arr) -rreplicateScalP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a) -rreplicateScalP sh x +rreplicatePrimP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a) +rreplicatePrimP sh x | Dict <- lemKnownReplicate (shrRank sh) - = Ranked (mreplicateScalP (shxFromShR sh) x) + = Ranked (mreplicatePrimP (shxFromShR sh) x) -rreplicateScal :: forall n a. PrimElt a +rreplicatePrim :: forall n a. PrimElt a => IShR n -> a -> Ranked n a -rreplicateScal sh x = rfromPrimitive (rreplicateScalP sh x) +rreplicatePrim sh x = rfromPrimitive (rreplicatePrimP sh x) rslice :: forall n a. Elt a => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a -rslice i n arr - | Refl <- lemReplicateSucc @(Nothing @Nat) @n - = rlift (rrank arr) - (\_ -> X.sliceU i n) - arr +rslice i n (Ranked arr) + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) + = Ranked (msliceN i n arr) rrev1 :: forall n a. Elt a => Ranked (n + 1) a -> Ranked (n + 1) a -rrev1 arr = - rlift (rrank arr) - (\(_ :: StaticShX sh') -> - case lemReplicateSucc @(Nothing @Nat) @n of - Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh')) - arr +rrev1 (Ranked arr) + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) + = Ranked (mrev1 arr) rreshape :: forall n n' a. Elt a => IShR n' -> Ranked n a -> Ranked n' a |
