{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Nested.Ranked (
  module Data.Array.Nested.Ranked.Base,
  module Data.Array.Nested.Ranked,
) where

import Prelude hiding (mappend, mconcat)

import Data.Array.RankedS qualified as S
import Data.Bifunctor (first)
import Data.Coerce (coerce)
import Data.List.NonEmpty (NonEmpty)
import Data.Proxy
import Data.Type.Equality
import Data.Vector.Storable qualified as VS
import Foreign.Storable (Storable)
import GHC.TypeLits
import GHC.TypeNats qualified as TN

import Data.Array.Mixed.Lemmas
import Data.Array.Mixed.Permutation
import Data.Array.Mixed.Types
import Data.Array.XArray (XArray(..))
import Data.Array.XArray qualified as X
import Data.Array.Nested.Convert
import Data.Array.Nested.Mixed
import Data.Array.Nested.Mixed.Shape
import Data.Array.Nested.Ranked.Base
import Data.Array.Nested.Ranked.Shape
import Data.Array.Strided.Arith


-- | An empty rank-1 array, obtained from 'memptyArray' via 'mtoRanked'.
remptyArray :: KnownElt a => Ranked 1 a
remptyArray = mtoRanked (memptyArray ZSX)

-- | The total number of elements in the array.
rsize :: Elt a => Ranked n a -> Int
rsize = shrSize . rshape

-- | Look up a single element, using an index with one component per
-- dimension of the array.
rindex :: Elt a => Ranked n a -> IIxR n -> a
rindex (Ranked arr) idx = mindex arr (ixCvtRX idx)

-- | Index only the first @n@ dimensions, returning the rank-@m@ subarray
-- found at that position.
rindexPartial :: forall n m a. Elt a => Ranked (n + m) a -> IIxR n -> Ranked m a
rindexPartial (Ranked arr) idx =
  Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing)
            -- Split @Replicate (n + m)@ into @Replicate n ++ Replicate m@ so
            -- that 'mindexPartial' sees the prefix/suffix structure.
            (castWith (subst2 (lemReplicatePlusApp (ixrRank idx) (Proxy @m) (Proxy @Nothing))) arr)
            (ixCvtRX idx))

-- | __WARNING__: All values returned from the function must have equal shape.
-- See the documentation of 'mgenerate' for more details.
rgenerate :: forall n a. KnownElt a => IShR n -> (IIxR n -> a) -> Ranked n a
rgenerate sh f
  | sn@SNat <- shrRank sh
  , Dict <- lemKnownReplicate sn
  , Refl <- lemRankReplicate sn
  = Ranked (mgenerate (shCvtRX sh) (f . ixCvtXR))

-- | See the documentation of 'mlift'.
rlift :: forall n1 n2 a. Elt a
      => SNat n2
      -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b)
      -> Ranked n1 a -> Ranked n2 a
rlift sn2 f (Ranked arr) = Ranked (mlift (ssxFromSNat sn2) f arr)

-- | See the documentation of 'mlift2'.
rlift2 :: forall n1 n2 n3 a. Elt a
       => SNat n3
       -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b -> XArray (Replicate n3 Nothing ++ sh') b)
       -> Ranked n1 a -> Ranked n2 a -> Ranked n3 a
rlift2 sn3 f (Ranked arr1) (Ranked arr2) = Ranked (mlift2 (ssxFromSNat sn3) f arr1 arr2)

-- | Sum away one dimension, taking rank @n + 1@ to rank @n@; per the name,
-- presumably the outermost one (see 'msumOuter1P').
rsumOuter1P :: forall n a.
               (Storable a, NumElt a)
            => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a)
rsumOuter1P (Ranked arr)
  | Refl <- lemReplicateSucc @(Nothing @Nat) @n
  = Ranked (msumOuter1P arr)

-- | Like 'rsumOuter1P', but for any 'PrimElt' element type; converts through
-- 'rtoPrimitive' and 'rfromPrimitive'.
rsumOuter1 :: forall n a. (NumElt a, PrimElt a)
           => Ranked (n + 1) a -> Ranked n a
rsumOuter1 = rfromPrimitive . rsumOuter1P . rtoPrimitive

-- | Reduce all elements of the array to a single value via 'msumAllPrim'.
rsumAllPrim :: (PrimElt a, NumElt a) => Ranked n a -> a
rsumAllPrim (Ranked arr) = msumAllPrim arr

-- | Transpose the array according to @perm@. Calls 'error' if the
-- permutation is longer than the rank of the array.
rtranspose :: forall n a. Elt a => PermR -> Ranked n a -> Ranked n a
rtranspose perm arr
  | sn@SNat <- rrank arr
  , Dict <- lemKnownReplicate sn
  , length perm <= fromIntegral (natVal (Proxy @n))
  = rlift sn
          (\ssh' -> X.transposeUntyped (natSing @n) ssh' perm)
          arr
  | otherwise
  = error "Data.Array.Nested.rtranspose: Permutation longer than rank of array"

-- | Concatenate a non-empty list of arrays of rank @n + 1@ using 'mconcat'.
rconcat :: forall n a. Elt a => NonEmpty (Ranked (n + 1) a) -> Ranked (n + 1) a
rconcat
  | Refl <- lemReplicateSucc @(Nothing @Nat) @n
  = coerce mconcat

-- | Join two arrays of rank @n + 1@ using 'mappend'.
rappend :: forall n a. Elt a
        => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked (n + 1) a
rappend arr1 arr2
  | sn@SNat <- rrank arr1
  , Dict <- lemKnownReplicate sn
  , Refl <- lemReplicateSucc @(Nothing @Nat) @n
  = coerce (mappend @Nothing @Nothing @(Replicate n Nothing))
      arr1 arr2

-- | Wrap a single element as a rank-0 array.
rscalar :: Elt a => a -> Ranked 0 a
rscalar x = Ranked (mscalar x)

-- | Build an array of shape @sh@ from a flat storable vector
-- ('Primitive' element version; see 'mfromVectorP').
rfromVectorP :: forall n a. Storable a => IShR n -> VS.Vector a -> Ranked n (Primitive a)
rfromVectorP sh v
  | Dict <- lemKnownReplicate (shrRank sh)
  = Ranked (mfromVectorP (shCvtRX sh) v)

-- | Build an array of shape @sh@ from a flat storable vector.
rfromVector :: forall n a. PrimElt a => IShR n -> VS.Vector a -> Ranked n a
rfromVector sh v = rfromPrimitive (rfromVectorP sh v)

-- | The elements of the array as a storable vector (see 'mtoVectorP').
rtoVectorP :: Storable a => Ranked n (Primitive a) -> VS.Vector a
rtoVectorP = coerce mtoVectorP

-- | The elements of the array as a storable vector (see 'mtoVector').
rtoVector :: PrimElt a => Ranked n a -> VS.Vector a
rtoVector = coerce mtoVector

-- | Stack a non-empty list of rank-@n@ arrays into one array of rank @n + 1@.
rfromListOuter :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a
rfromListOuter l
  | Refl <- lemReplicateSucc @(Nothing @Nat) @n
  = Ranked (mfromListOuter (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a)))

-- | A rank-1 array holding the elements of a non-empty list.
rfromList1 :: Elt a => NonEmpty a -> Ranked 1 a
rfromList1 l = Ranked (mfromList1 l)

-- | A rank-1 array holding the elements of a (possibly empty) list of
-- primitive elements.
rfromList1Prim :: PrimElt a => [a] -> Ranked 1 a
rfromList1Prim l = Ranked (mfromList1Prim l)

-- | Split an array of rank @n + 1@ into the list of its rank-@n@ subarrays.
rtoListOuter :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a]
rtoListOuter (Ranked arr)
  | Refl <- lemReplicateSucc @(Nothing @Nat) @n
  = coerce (mtoListOuter @a @Nothing @(Replicate n Nothing) arr)

-- | The elements of a rank-1 array, as a list.
rtoList1 :: Elt a => Ranked 1 a -> [a]
rtoList1 = map runScalar . rtoListOuter

-- | Build a rank-1 array from a list of primitive elements, going directly
-- through 'X.fromList1'. Has the same type as 'rfromList1Prim'.
rfromListPrim :: PrimElt a => [a] -> Ranked 1 a
rfromListPrim l =
  let ssh = SUnknown () :!% ZKX
      xarr = X.fromList1 ssh l
  in Ranked $ fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr

-- | Build an array of shape @sh@ from a flat list of primitive elements, by
-- reshaping the list's rank-1 array with 'X.reshape'.
rfromListPrimLinear :: PrimElt a => IShR n -> [a] -> Ranked n a
rfromListPrimLinear sh l =
  let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
  in Ranked $ fromPrimitive $ M_Primitive (shCvtRX sh) (X.reshape (SUnknown () :!% ZKX) (shCvtRX sh) xarr)

-- | Build an array of shape @sh@ from a flat non-empty list, by reshaping a
-- rank-1 array with 'rreshape'.
rfromListLinear :: forall n a. Elt a => IShR n -> NonEmpty a -> Ranked n a
rfromListLinear sh l = rreshape sh (rfromList1 l)

-- | All elements of the array as a flat list (see 'mtoListLinear').
rtoListLinear :: Elt a => Ranked n a -> [a]
rtoListLinear (Ranked arr) = mtoListLinear arr

-- | Convert an @orthotope@ 'S.Array' of rank @n@ into a ranked array.
rfromOrthotope :: PrimElt a => SNat n -> S.Array n a -> Ranked n a
rfromOrthotope sn arr
  | Refl <- lemRankReplicate sn
  = let xarr = XArray arr
    in Ranked (fromPrimitive (M_Primitive (X.shape (ssxFromSNat sn) xarr) xarr))

-- | Convert a ranked array into an @orthotope@ 'S.Array' of the same rank.
rtoOrthotope :: PrimElt a => Ranked n a -> S.Array n a
rtoOrthotope (rtoPrimitive -> Ranked (M_Primitive sh (XArray arr)))
  | Refl <- lemRankReplicate (shrRank $ shCvtXR' sh)
  = arr

-- | Extract the single element of a rank-0 array.
runScalar :: Elt a => Ranked 0 a -> a
runScalar arr = rindex arr ZIR

-- | View the first @n@ dimensions as an outer array whose elements are the
-- rank-@m@ subarrays.
rnest :: forall n m a. Elt a => SNat n -> Ranked (n + m) a -> Ranked n (Ranked m a)
rnest n arr
  | Refl <- lemReplicatePlusApp n (Proxy @m) (Proxy @(Nothing @Nat))
  = coerce (mnest (ssxFromSNat n) (coerce arr))

-- | Merge a nested array back into a single array of rank @n + m@.
runNest :: forall n m a. Elt a => Ranked n (Ranked m a) -> Ranked (n + m) a
runNest rarr@(Ranked (M_Ranked (M_Nest _ arr)))
  | Refl <- lemReplicatePlusApp (rrank rarr) (Proxy @m) (Proxy @(Nothing @Nat))
  = Ranked arr

-- | Pair up two arrays of the same rank (see 'mzip').
rzip :: Ranked n a -> Ranked n b -> Ranked n (a, b)
rzip = coerce mzip

-- | Split an array of pairs into two arrays (see 'munzip').
runzip :: Ranked n (a, b) -> (Ranked n a, Ranked n b)
runzip = coerce munzip

-- | 'Primitive'-element version of 'rrerank'; see there.
rrerankP :: forall n1 n2 n a b. (Storable a, Storable b)
         => SNat n -> IShR n2
         -> (Ranked n1 (Primitive a) -> Ranked n2 (Primitive b))
         -> Ranked (n + n1) (Primitive a) -> Ranked (n + n2) (Primitive b)
rrerankP sn sh2 f (Ranked arr)
  | Refl <- lemReplicatePlusApp sn (Proxy @n1) (Proxy @(Nothing @Nat))
  , Refl <- lemReplicatePlusApp sn (Proxy @n2) (Proxy @(Nothing @Nat))
  = Ranked (mrerankP (ssxFromSNat sn) (shCvtRX sh2)
                     (\a -> let Ranked r = f (Ranked a) in r)
                     arr)

-- | Rerank: apply @f@ to each rank-@n1@ subarray found under the outer @n@
-- dimensions of the input, producing an array of rank @n + n2@.
--
-- If there is a zero-sized dimension in the @n@-prefix of the shape of the
-- input array, then there is no way to deduce the full shape of the output
-- array (more precisely, the @n2@ part): that could only come from calling
-- @f@, and there are no subarrays to call @f@ on. @orthotope@ errors out in
-- this case; we choose to fill the @n2@ part of the output shape with zeros.
--
-- For example, if:
--
-- @
-- arr :: Ranked 5 Int   -- of shape [3, 0, 4, 2, 21]
-- f :: Ranked 2 Int -> Ranked 3 Float
-- @
--
-- then (here @n = 3@, @n1 = 2@ and @n2 = 3@):
--
-- @
-- rrerank _ _ f arr :: Ranked 6 Float
-- @
--
-- and this result will have shape @[3, 0, 4, 0, 0, 0]@. Note that the
-- "reranked" part (the last 3 entries) is zero; we don't know if @f@ intended
-- to return an array with shape all-0 here (it probably didn't), but there is
-- no better number to put here absent a subarray of the input to pass to @f@.
--
-- NOTE(review): this function also receives the output shape @sh2 :: IShR n2@,
-- which is forwarded to 'mrerankP'. Confirm whether the zero-filling described
-- above still applies, or whether @sh2@ now supplies the @n2@ part instead.
rrerank :: forall n1 n2 n a b. (PrimElt a, PrimElt b)
        => SNat n -> IShR n2 -> (Ranked n1 a -> Ranked n2 b)
        -> Ranked (n + n1) a -> Ranked (n + n2) b
rrerank sn sh2 f (rtoPrimitive -> arr) =
  rfromPrimitive $ rrerankP sn sh2 (rtoPrimitive . f . rfromPrimitive) arr

-- | Replicate a rank-@m@ array over the outer shape @sh@, giving an array of
-- rank @n + m@.
rreplicate :: forall n m a. Elt a => IShR n -> Ranked m a -> Ranked (n + m) a
rreplicate sh (Ranked arr)
  | Refl <- lemReplicatePlusApp (shrRank sh) (Proxy @m) (Proxy @(Nothing @Nat))
  = Ranked (mreplicate (shCvtRX sh) arr)

-- | An array of shape @sh@ whose elements are all @x@ ('Primitive' version).
rreplicateScalP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a)
rreplicateScalP sh x
  | Dict <- lemKnownReplicate (shrRank sh)
  = Ranked (mreplicateScalP (shCvtRX sh) x)

-- | An array of shape @sh@ whose elements are all @x@.
rreplicateScal :: forall n a. PrimElt a => IShR n -> a -> Ranked n a
rreplicateScal sh x = rfromPrimitive (rreplicateScalP sh x)

-- | @rslice i len arr@ slices the outer dimension of @arr@ with 'X.sliceU';
-- presumably @len@ elements starting at position @i@ (TODO confirm against
-- 'X.sliceU'). The length parameter is called @len@ rather than @n@ so it
-- cannot be confused with the type-level rank @n@ used in the guard.
rslice :: forall n a. Elt a => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a
rslice i len arr
  | Refl <- lemReplicateSucc @(Nothing @Nat) @n
  = rlift (rrank arr)
          (\_ -> X.sliceU i len)
          arr

-- | Reverse the array along its first dimension, using 'X.rev1'.
rrev1 :: forall n a. Elt a => Ranked (n + 1) a -> Ranked (n + 1) a
rrev1 arr =
  rlift (rrank arr)
        (\(_ :: StaticShX sh') ->
           case lemReplicateSucc @(Nothing @Nat) @n of
             Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh'))
        arr

-- | Give the array the new shape @sh'@ (see 'mreshape').
rreshape :: forall n n' a. Elt a
         => IShR n' -> Ranked n a -> Ranked n' a
rreshape sh' rarr@(Ranked arr)
  | Dict <- lemKnownReplicate (rrank rarr)
  , Dict <- lemKnownReplicate (shrRank sh')
  = Ranked (mreshape (shCvtRX sh') arr)

-- | Flatten the array into a rank-1 array (see 'mflatten').
rflatten :: Elt a => Ranked n a -> Ranked 1 a
rflatten (Ranked arr) = mtoRanked (mflatten arr)

-- | A rank-1 array of length @n@ built with 'miota'.
--
-- NOTE(review): @n@ is converted to a 'Natural' with 'fromIntegral', so a
-- negative @n@ fails with an arithmetic underflow exception.
riota :: (Enum a, PrimElt a) => Int -> Ranked 1 a
riota n = TN.withSomeSNat (fromIntegral n) $ mtoRanked . miota

-- | The index of a minimum element. Throws if the array is empty.
rminIndexPrim :: (PrimElt a, NumElt a) => Ranked n a -> IIxR n
rminIndexPrim rarr@(Ranked arr)
  | Refl <- lemRankReplicate (rrank (rtoPrimitive rarr))
  = ixCvtXR (mminIndexPrim arr)

-- | The index of a maximum element. Throws if the array is empty.
rmaxIndexPrim :: (PrimElt a, NumElt a) => Ranked n a -> IIxR n
rmaxIndexPrim rarr@(Ranked arr)
  | Refl <- lemRankReplicate (rrank (rtoPrimitive rarr))
  = ixCvtXR (mmaxIndexPrim arr)

-- | Dot product along one dimension, taking rank @n + 1@ to rank @n@; per the
-- name, presumably the innermost one (see 'mdot1Inner').
rdot1Inner :: forall n a. (PrimElt a, NumElt a)
           => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked n a
rdot1Inner arr1 arr2
  | SNat <- rrank arr1
  , Refl <- lemReplicatePlusApp (SNat @n) (Proxy @1) (Proxy @(Nothing @Nat))
  = coerce (mdot1Inner (Proxy @(Nothing @Nat))) arr1 arr2

-- | This has a temporary, suboptimal implementation in terms of 'mflatten'.
-- Prefer 'rdot1Inner' if applicable.
-- Unwrap both newtypes and defer the work to 'mdot'.
rdot :: (PrimElt a, NumElt a) => Ranked n a -> Ranked n a -> a
rdot (Ranked xs) (Ranked ys) = mdot xs ys

-- | Expose the shape and the underlying 'XArray' of an array of
-- 'Primitive' elements.
rtoXArrayPrimP :: Ranked n (Primitive a) -> (IShR n, XArray (Replicate n Nothing) a)
rtoXArrayPrimP (Ranked arr) =
  let (shx, xarr) = mtoXArrayPrimP arr
  in (shCvtXR' shx, xarr)

-- | Like 'rtoXArrayPrimP', but for any 'PrimElt' element type.
rtoXArrayPrim :: PrimElt a => Ranked n a -> (IShR n, XArray (Replicate n Nothing) a)
rtoXArrayPrim (Ranked arr) =
  let (shx, xarr) = mtoXArrayPrim arr
  in (shCvtXR' shx, xarr)

-- | Wrap an 'XArray' as a ranked array of 'Primitive' elements. The static
-- shape information is recomputed from the array itself via 'X.shape'.
rfromXArrayPrimP :: SNat n -> XArray (Replicate n Nothing) a -> Ranked n (Primitive a)
rfromXArrayPrimP sn arr = Ranked (mfromXArrayPrimP ssh arr)
  where
    ssh = ssxFromShape (X.shape (ssxFromSNat sn) arr)

-- | Like 'rfromXArrayPrimP', but for any 'PrimElt' element type.
rfromXArrayPrim :: PrimElt a => SNat n -> XArray (Replicate n Nothing) a -> Ranked n a
rfromXArrayPrim sn arr = Ranked (mfromXArrayPrim ssh arr)
  where
    ssh = ssxFromShape (X.shape (ssxFromSNat sn) arr)

-- | Remove the 'Primitive' wrapper from the element type (see 'fromPrimitive').
rfromPrimitive :: PrimElt a => Ranked n (Primitive a) -> Ranked n a
rfromPrimitive (Ranked prim) = Ranked $ fromPrimitive prim

-- | Add the 'Primitive' wrapper to the element type (see 'toPrimitive').
rtoPrimitive :: PrimElt a => Ranked n a -> Ranked n (Primitive a)
rtoPrimitive (Ranked arr) = Ranked $ toPrimitive arr