diff options
Diffstat (limited to 'src/Data/Array/Nested/Ranked.hs')
| -rw-r--r-- | src/Data/Array/Nested/Ranked.hs | 464 |
1 files changed, 134 insertions, 330 deletions
diff --git a/src/Data/Array/Nested/Ranked.hs b/src/Data/Array/Nested/Ranked.hs index fb5caa9..d687983 100644 --- a/src/Data/Array/Nested/Ranked.hs +++ b/src/Data/Array/Nested/Ranked.hs @@ -1,280 +1,82 @@ {-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE DerivingVia #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE ImportQualifiedPost #-} -{-# LANGUAGE InstanceSigs #-} -{-# LANGUAGE PolyKinds #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StandaloneKindSignatures #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE ViewPatterns #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -module Data.Array.Nested.Ranked where +module Data.Array.Nested.Ranked ( + Ranked(Ranked), + rquotArray, rremArray, ratan2Array, + rshape, rrank, + module Data.Array.Nested.Ranked, + liftRanked1, liftRanked2, +) where import Prelude hiding (mappend, mconcat) -import Control.DeepSeq (NFData(..)) -import Control.Monad.ST import Data.Array.RankedS qualified as S import Data.Bifunctor (first) import Data.Coerce (coerce) -import Data.Foldable (toList) -import Data.Kind (Type) import Data.List.NonEmpty (NonEmpty) import Data.Proxy import Data.Type.Equality import Data.Vector.Storable qualified as VS import Foreign.Storable (Storable) -import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) -import GHC.Generics (Generic) import GHC.TypeLits import GHC.TypeNats qualified as TN -import Data.Array.Mixed.Lemmas -import Data.Array.Mixed.Permutation -import Data.Array.Mixed.Types -import Data.Array.Mixed.XArray (XArray(..)) -import Data.Array.Mixed.XArray qualified as X +import Data.Array.Nested.Convert +import Data.Array.Nested.Lemmas import Data.Array.Nested.Mixed import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Permutation +import Data.Array.Nested.Ranked.Base import Data.Array.Nested.Ranked.Shape +import Data.Array.Nested.Types import Data.Array.Strided.Arith - - --- | A rank-typed array: the number of dimensions of the array (its /rank/) is --- represented on the type level as a 'Nat'. --- --- Valid elements of a ranked arrays are described by the 'Elt' type class. --- Because 'Ranked' itself is also an instance of 'Elt', nested arrays are --- supported (and are represented as a single, flattened, struct-of-arrays --- array internally). --- --- 'Ranked' is a newtype around a 'Mixed' of 'Nothing's. -type Ranked :: Nat -> Type -> Type -newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a) -deriving instance Eq (Mixed (Replicate n Nothing) a) => Eq (Ranked n a) -deriving instance Ord (Mixed (Replicate n Nothing) a) => Ord (Ranked n a) - -instance (Show a, Elt a) => Show (Ranked n a) where - showsPrec d arr@(Ranked marr) = - let sh = show (toList (rshape arr)) - in showsMixedArray ("rfromListLinear " ++ sh) ("rreplicate " ++ sh) d marr - -instance Elt a => NFData (Ranked n a) where - rnf (Ranked arr) = rnf arr - --- just unwrap the newtype and defer to the general instance for nested arrays -newtype instance Mixed sh (Ranked n a) = M_Ranked (Mixed sh (Mixed (Replicate n Nothing) a)) - deriving (Generic) - -deriving instance Eq (Mixed sh (Mixed (Replicate n Nothing) a)) => Eq (Mixed sh (Ranked n a)) - -newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed (Replicate n Nothing) a)) - --- 'Ranked' and 'Shaped' can already be used at the top level of an array nest; --- these instances allow them to also be used as elements of arrays, thus --- making them first-class in the API. -instance Elt a => Elt (Ranked n a) where - mshape (M_Ranked arr) = mshape arr - mindex (M_Ranked arr) i = Ranked (mindex arr i) - - mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IIxX sh -> Mixed sh' (Ranked n a) - mindexPartial (M_Ranked arr) i = - coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $ - mindexPartial arr i - - mscalar (Ranked x) = M_Ranked (M_Nest ZSX x) - - mfromListOuter :: forall sh. NonEmpty (Mixed sh (Ranked n a)) -> Mixed (Nothing : sh) (Ranked n a) - mfromListOuter l = M_Ranked (mfromListOuter (coerce l)) - - mtoListOuter :: forall m sh. Mixed (m : sh) (Ranked n a) -> [Mixed sh (Ranked n a)] - mtoListOuter (M_Ranked arr) = - coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoListOuter arr) - - mlift :: forall sh1 sh2. - StaticShX sh2 - -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) - -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) - mlift ssh2 f (M_Ranked arr) = - coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $ - mlift ssh2 f arr - - mlift2 :: forall sh1 sh2 sh3. - StaticShX sh3 - -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) - -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) -> Mixed sh3 (Ranked n a) - mlift2 ssh3 f (M_Ranked arr1) (M_Ranked arr2) = - coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $ - mlift2 ssh3 f arr1 arr2 - - mliftL :: forall sh1 sh2. - StaticShX sh2 - -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) - -> NonEmpty (Mixed sh1 (Ranked n a)) -> NonEmpty (Mixed sh2 (Ranked n a)) - mliftL ssh2 f l = - coerce @(NonEmpty (Mixed sh2 (Mixed (Replicate n Nothing) a))) - @(NonEmpty (Mixed sh2 (Ranked n a))) $ - mliftL ssh2 f (coerce l) - - mcastPartial ssh1 ssh2 psh' (M_Ranked arr) = M_Ranked (mcastPartial ssh1 ssh2 psh' arr) - - mtranspose perm (M_Ranked arr) = M_Ranked (mtranspose perm arr) - - mconcat l = M_Ranked (mconcat (coerce l)) - - mrnf (M_Ranked arr) = mrnf arr - - type ShapeTree (Ranked n a) = (IShR n, ShapeTree a) - - mshapeTree (Ranked arr) = first shCvtXR' (mshapeTree arr) - - mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 - - mshapeTreeEmpty _ (sh, t) = shrSize sh == 0 && mshapeTreeEmpty (Proxy @a) t - - mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" - - marrayStrides (M_Ranked arr) = marrayStrides arr - - mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s () - mvecsWrite sh idx (Ranked arr) vecs = - mvecsWrite sh idx arr - (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) - vecs) - - mvecsWritePartial :: forall sh sh' s. - IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a) - -> MixedVecs s (sh ++ sh') (Ranked n a) - -> ST s () - mvecsWritePartial sh idx arr vecs = - mvecsWritePartial sh idx - (coerce @(Mixed sh' (Ranked n a)) - @(Mixed sh' (Mixed (Replicate n Nothing) a)) - arr) - (coerce @(MixedVecs s (sh ++ sh') (Ranked n a)) - @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a)) - vecs) - - mvecsFreeze :: forall sh s. IShX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a)) - mvecsFreeze sh vecs = - coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) - @(Mixed sh (Ranked n a)) - <$> mvecsFreeze sh - (coerce @(MixedVecs s sh (Ranked n a)) - @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) - vecs) - -instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where - memptyArrayUnsafe :: forall sh. IShX sh -> Mixed sh (Ranked n a) - memptyArrayUnsafe i - | Dict <- lemKnownReplicate (SNat @n) - = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $ - memptyArrayUnsafe i - - mvecsUnsafeNew idx (Ranked arr) - | Dict <- lemKnownReplicate (SNat @n) - = MV_Ranked <$> mvecsUnsafeNew idx arr - - mvecsNewEmpty _ - | Dict <- lemKnownReplicate (SNat @n) - = MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a)) - - -liftRanked1 :: forall n a b. - (Mixed (Replicate n Nothing) a -> Mixed (Replicate n Nothing) b) - -> Ranked n a -> Ranked n b -liftRanked1 = coerce - -liftRanked2 :: forall n a b c. - (Mixed (Replicate n Nothing) a -> Mixed (Replicate n Nothing) b -> Mixed (Replicate n Nothing) c) - -> Ranked n a -> Ranked n b -> Ranked n c -liftRanked2 = coerce - -instance (NumElt a, PrimElt a) => Num (Ranked n a) where - (+) = liftRanked2 (+) - (-) = liftRanked2 (-) - (*) = liftRanked2 (*) - negate = liftRanked1 negate - abs = liftRanked1 abs - signum = liftRanked1 signum - fromInteger = error "Data.Array.Nested(Ranked).fromInteger: No singletons available, use explicit rreplicateScal" - -instance (FloatElt a, PrimElt a) => Fractional (Ranked n a) where - fromRational _ = error "Data.Array.Nested(Ranked).fromRational: No singletons available, use explicit rreplicateScal" - recip = liftRanked1 recip - (/) = liftRanked2 (/) - -instance (FloatElt a, PrimElt a) => Floating (Ranked n a) where - pi = error "Data.Array.Nested(Ranked).pi: No singletons available, use explicit rreplicateScal" - exp = liftRanked1 exp - log = liftRanked1 log - sqrt = liftRanked1 sqrt - (**) = liftRanked2 (**) - logBase = liftRanked2 logBase - sin = liftRanked1 sin - cos = liftRanked1 cos - tan = liftRanked1 tan - asin = liftRanked1 asin - acos = liftRanked1 acos - atan = liftRanked1 atan - sinh = liftRanked1 sinh - cosh = liftRanked1 cosh - tanh = liftRanked1 tanh - asinh = liftRanked1 asinh - acosh = liftRanked1 acosh - atanh = liftRanked1 atanh - log1p = liftRanked1 GHC.Float.log1p - expm1 = liftRanked1 GHC.Float.expm1 - log1pexp = liftRanked1 GHC.Float.log1pexp - log1mexp = liftRanked1 GHC.Float.log1mexp - -rquotArray, rremArray :: (IntElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a -rquotArray = liftRanked2 mquotArray -rremArray = liftRanked2 mremArray - -ratan2Array :: (FloatElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a -ratan2Array = liftRanked2 matan2Array +import Data.Array.XArray (XArray(..)) +import Data.Array.XArray qualified as X remptyArray :: KnownElt a => Ranked 1 a remptyArray = mtoRanked (memptyArray ZSX) -rshape :: Elt a => Ranked n a -> IShR n -rshape (Ranked arr) = shCvtXR' (mshape arr) - -rrank :: Elt a => Ranked n a -> SNat n -rrank = shrRank . rshape - -- | The total number of elements in the array. rsize :: Elt a => Ranked n a -> Int rsize = shrSize . rshape +{-# INLINEABLE rindex #-} rindex :: Elt a => Ranked n a -> IIxR n -> a -rindex (Ranked arr) idx = mindex arr (ixCvtRX idx) +rindex (Ranked arr) idx = mindex arr (ixxFromIxR idx) +{-# INLINEABLE rindexPartial #-} rindexPartial :: forall n m a. Elt a => Ranked (n + m) a -> IIxR n -> Ranked m a rindexPartial (Ranked arr) idx = Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing) (castWith (subst2 (lemReplicatePlusApp (ixrRank idx) (Proxy @m) (Proxy @Nothing))) arr) - (ixCvtRX idx)) + (ixxFromIxR idx)) -- | __WARNING__: All values returned from the function must have equal shape. --- See the documentation of 'mgenerate' for more details. +-- See the documentation of 'mgenerate' for more details; see also +-- 'rgeneratePrim'. rgenerate :: forall n a. KnownElt a => IShR n -> (IIxR n -> a) -> Ranked n a rgenerate sh f | sn@SNat <- shrRank sh , Dict <- lemKnownReplicate sn , Refl <- lemRankReplicate sn - = Ranked (mgenerate (shCvtRX sh) (f . ixCvtXR)) + = Ranked (mgenerate (shxFromShR sh) (f . ixrFromIxX)) + +-- | See 'mgeneratePrim'. +{-# INLINE rgeneratePrim #-} +rgeneratePrim :: forall n a i. (PrimElt a, Num i) + => IShR n -> (IxR n i -> a) -> Ranked n a +rgeneratePrim sh f = + let g i = f (ixrFromLinear sh i) + in rfromVector sh $ VS.generate (shrSize sh) g -- | See the documentation of 'mlift'. rlift :: forall n1 n2 a. Elt a @@ -290,16 +92,19 @@ rlift2 :: forall n1 n2 n3 a. Elt a -> Ranked n1 a -> Ranked n2 a -> Ranked n3 a rlift2 sn3 f (Ranked arr1) (Ranked arr2) = Ranked (mlift2 (ssxFromSNat sn3) f arr1 arr2) -rsumOuter1P :: forall n a. - (Storable a, NumElt a) - => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a) -rsumOuter1P (Ranked arr) - | Refl <- lemReplicateSucc @(Nothing @Nat) @n - = Ranked (msumOuter1P arr) +rsumOuter1PrimP :: forall n a. + (Storable a, NumElt a) + => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a) +rsumOuter1PrimP (Ranked arr) + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) + = Ranked (msumOuter1PrimP arr) -rsumOuter1 :: forall n a. (NumElt a, PrimElt a) - => Ranked (n + 1) a -> Ranked n a -rsumOuter1 = rfromPrimitive . rsumOuter1P . rtoPrimitive +rsumOuter1Prim :: forall n a. (NumElt a, PrimElt a) + => Ranked (n + 1) a -> Ranked n a +rsumOuter1Prim = rfromPrimitive . rsumOuter1PrimP . rtoPrimitive + +rsumAllPrimP :: (Storable a, NumElt a) => Ranked n (Primitive a) -> a +rsumAllPrimP (Ranked arr) = msumAllPrimP arr rsumAllPrim :: (PrimElt a, NumElt a) => Ranked n a -> a rsumAllPrim (Ranked arr) = msumAllPrim arr @@ -317,7 +122,7 @@ rtranspose perm arr rconcat :: forall n a. Elt a => NonEmpty (Ranked (n + 1) a) -> Ranked (n + 1) a rconcat - | Refl <- lemReplicateSucc @(Nothing @Nat) @n + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) = coerce mconcat rappend :: forall n a. Elt a @@ -325,7 +130,7 @@ rappend :: forall n a. Elt a rappend arr1 arr2 | sn@SNat <- rrank arr1 , Dict <- lemKnownReplicate sn - , Refl <- lemReplicateSucc @(Nothing @Nat) @n + , Refl <- lemReplicateSucc @(Nothing @Nat) (SNat @n) = coerce (mappend @Nothing @Nothing @(Replicate n Nothing)) arr1 arr2 @@ -335,7 +140,7 @@ rscalar x = Ranked (mscalar x) rfromVectorP :: forall n a. Storable a => IShR n -> VS.Vector a -> Ranked n (Primitive a) rfromVectorP sh v | Dict <- lemKnownReplicate (shrRank sh) - = Ranked (mfromVectorP (shCvtRX sh) v) + = Ranked (mfromVectorP (shxFromShR sh) v) rfromVector :: forall n a. PrimElt a => IShR n -> VS.Vector a -> Ranked n a rfromVector sh v = rfromPrimitive (rfromVectorP sh v) @@ -346,38 +151,63 @@ rtoVectorP = coerce mtoVectorP rtoVector :: PrimElt a => Ranked n a -> VS.Vector a rtoVector = coerce mtoVector +-- | All arrays in the list, even subarrays inside @a@, must have the same +-- shape; if they do not, a runtime error will be thrown. See the +-- documentation of 'mgenerate' for more information about this restriction. +-- +-- Because the length of the 'NonEmpty' list is unknown, its spine must be +-- materialised in memory in order to compute its length. If its length is +-- already known, use 'rfromListOuterN' to be able to stream the list. +-- +-- If your array is 1-dimensional and contains scalars, use 'rfromList1Prim'. rfromListOuter :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a rfromListOuter l - | Refl <- lemReplicateSucc @(Nothing @Nat) @n + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) = Ranked (mfromListOuter (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a))) +-- | See 'rfromListOuter'. If the list does not have the given length, a +-- runtime error is thrown. 'rfromList1PrimN' is faster if applicable. +rfromListOuterN :: forall n a. Elt a => Int -> NonEmpty (Ranked n a) -> Ranked (n + 1) a +rfromListOuterN n l + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) + = Ranked (mfromListOuterN n (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a))) + +-- | Because the length of the 'NonEmpty' list is unknown, its spine must be +-- materialised in memory in order to compute its length. If its length is +-- already known, use 'rfromList1N' to be able to stream the list. +-- +-- If the elements are scalars, 'rfromList1Prim' is faster. rfromList1 :: Elt a => NonEmpty a -> Ranked 1 a -rfromList1 l = Ranked (mfromList1 l) +rfromList1 = coerce mfromList1 -rfromList1Prim :: PrimElt a => [a] -> Ranked 1 a -rfromList1Prim l = Ranked (mfromList1Prim l) +-- | If the elements are scalars, 'rfromList1PrimN' is faster. A runtime error +-- is thrown if the list length does not match the given length. +rfromList1N :: Elt a => Int -> NonEmpty a -> Ranked 1 a +rfromList1N = coerce mfromList1N -rtoListOuter :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a] -rtoListOuter (Ranked arr) - | Refl <- lemReplicateSucc @(Nothing @Nat) @n - = coerce (mtoListOuter @a @Nothing @(Replicate n Nothing) arr) +-- | If the elements are scalars, 'rfromListPrimLinear' is faster. +rfromListLinear :: forall n a. Elt a => IShR n -> NonEmpty a -> Ranked n a +rfromListLinear sh l = Ranked (mfromListLinear (shxFromShR sh) l) -rtoList1 :: Elt a => Ranked 1 a -> [a] -rtoList1 = map runScalar . rtoListOuter +-- | Because the length of the list is unknown, its spine must be materialised +-- in memory in order to compute its length. If its length is already known, +-- use 'rfromList1PrimN' to be able to stream the list. +rfromList1Prim :: PrimElt a => [a] -> Ranked 1 a +rfromList1Prim = coerce mfromList1Prim -rfromListPrim :: PrimElt a => [a] -> Ranked 1 a -rfromListPrim l = - let ssh = SUnknown () :!% ZKX - xarr = X.fromList1 ssh l - in Ranked $ fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr +rfromList1PrimN :: PrimElt a => Int -> [a] -> Ranked 1 a +rfromList1PrimN = coerce mfromList1PrimN -rfromListPrimLinear :: PrimElt a => IShR n -> [a] -> Ranked n a -rfromListPrimLinear sh l = - let M_Primitive _ xarr = toPrimitive (mfromListPrim l) - in Ranked $ fromPrimitive $ M_Primitive (shCvtRX sh) (X.reshape (SUnknown () :!% ZKX) (shCvtRX sh) xarr) +rfromListPrimLinear :: forall n a. PrimElt a => IShR n -> [a] -> Ranked n a +rfromListPrimLinear sh l = Ranked (mfromListPrimLinear (shxFromShR sh) l) -rfromListLinear :: forall n a. Elt a => IShR n -> NonEmpty a -> Ranked n a -rfromListLinear sh l = rreshape sh (rfromList1 l) +rtoList :: Elt a => Ranked 1 a -> [a] +rtoList = map runScalar . rtoListOuter + +rtoListOuter :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a] +rtoListOuter (Ranked arr) + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) + = coerce (mtoListOuter @a @Nothing @(Replicate n Nothing) arr) rtoListLinear :: Elt a => Ranked n a -> [a] rtoListLinear (Ranked arr) = mtoListLinear arr @@ -388,9 +218,9 @@ rfromOrthotope sn arr = let xarr = XArray arr in Ranked (fromPrimitive (M_Primitive (X.shape (ssxFromSNat sn) xarr) xarr)) -rtoOrthotope :: PrimElt a => Ranked n a -> S.Array n a +rtoOrthotope :: forall a n. PrimElt a => Ranked n a -> S.Array n a rtoOrthotope (rtoPrimitive -> Ranked (M_Primitive sh (XArray arr))) - | Refl <- lemRankReplicate (shrRank $ shCvtXR' sh) + | Refl <- lemRankReplicate (shrRank $ shrFromShX2 @n sh) = arr runScalar :: Elt a => Ranked 0 a -> a @@ -406,22 +236,20 @@ runNest rarr@(Ranked (M_Ranked (M_Nest _ arr))) | Refl <- lemReplicatePlusApp (rrank rarr) (Proxy @m) (Proxy @(Nothing @Nat)) = Ranked arr -rzip :: Ranked n a -> Ranked n b -> Ranked n (a, b) +rzip :: (Elt a, Elt b) => Ranked n a -> Ranked n b -> Ranked n (a, b) rzip = coerce mzip runzip :: Ranked n (a, b) -> (Ranked n a, Ranked n b) runzip = coerce munzip -rrerankP :: forall n1 n2 n a b. (Storable a, Storable b) - => SNat n -> IShR n2 - -> (Ranked n1 (Primitive a) -> Ranked n2 (Primitive b)) - -> Ranked (n + n1) (Primitive a) -> Ranked (n + n2) (Primitive b) -rrerankP sn sh2 f (Ranked arr) - | Refl <- lemReplicatePlusApp sn (Proxy @n1) (Proxy @(Nothing @Nat)) - , Refl <- lemReplicatePlusApp sn (Proxy @n2) (Proxy @(Nothing @Nat)) - = Ranked (mrerankP (ssxFromSNat sn) (shCvtRX sh2) - (\a -> let Ranked r = f (Ranked a) in r) - arr) +rrerankPrimP :: forall n1 n2 n a b. (Storable a, Storable b) + => IShR n2 + -> (Ranked n1 (Primitive a) -> Ranked n2 (Primitive b)) + -> Ranked n (Ranked n1 (Primitive a)) -> Ranked n (Ranked n2 (Primitive b)) +rrerankPrimP sh2 f (Ranked (M_Ranked arr)) + = Ranked (M_Ranked (mrerankPrimP (shxFromShR sh2) + (\a -> let Ranked r = f (Ranked a) in r) + arr)) -- | If there is a zero-sized dimension in the @n@-prefix of the shape of the -- input array, then there is no way to deduce the full shape of the output @@ -432,63 +260,60 @@ rrerankP sn sh2 f (Ranked arr) -- For example, if: -- -- @ --- arr :: Ranked 5 Int -- of shape [3, 0, 4, 2, 21] +-- arr :: Ranked 3 (Ranked 2 Int) -- outer array shape [3, 0, 4]; inner shape [2, 21] -- f :: Ranked 2 Int -> Ranked 3 Float -- @ -- -- then: -- -- @ --- rrerank _ _ _ f arr :: Ranked 5 Float +-- rrerank _ f arr :: Ranked 3 (Ranked 3 Float) -- @ -- --- and this result will have shape @[3, 0, 4, 0, 0, 0]@. Note that the --- "reranked" part (the last 3 entries) are zero; we don't know if @f@ intended --- to return an array with shape all-0 here (it probably didn't), but there is --- no better number to put here absent a subarray of the input to pass to @f@. -rrerank :: forall n1 n2 n a b. (PrimElt a, PrimElt b) - => SNat n -> IShR n2 - -> (Ranked n1 a -> Ranked n2 b) - -> Ranked (n + n1) a -> Ranked (n + n2) b -rrerank sn sh2 f (rtoPrimitive -> arr) = - rfromPrimitive $ rrerankP sn sh2 (rtoPrimitive . f . rfromPrimitive) arr +-- and the inner arrays of the result will have shape @[0, 0, 0]@. We don't +-- know if @f@ intended to return an array with all-zero shape here (it +-- probably didn't), but there is no better number to put here absent a +-- subarray of the input to pass to @f@. +rrerankPrim :: forall n1 n2 n a b. (PrimElt a, PrimElt b) + => IShR n2 + -> (Ranked n1 a -> Ranked n2 b) + -> Ranked n (Ranked n1 a) -> Ranked n (Ranked n2 b) +rrerankPrim sh2 f (Ranked (M_Ranked arr)) = + Ranked (M_Ranked (mrerankPrim (shxFromShR sh2) + (\a -> let Ranked r = f (Ranked a) in r) + arr)) rreplicate :: forall n m a. Elt a => IShR n -> Ranked m a -> Ranked (n + m) a rreplicate sh (Ranked arr) | Refl <- lemReplicatePlusApp (shrRank sh) (Proxy @m) (Proxy @(Nothing @Nat)) - = Ranked (mreplicate (shCvtRX sh) arr) + = Ranked (mreplicate (shxFromShR sh) arr) -rreplicateScalP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a) -rreplicateScalP sh x +rreplicatePrimP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a) +rreplicatePrimP sh x | Dict <- lemKnownReplicate (shrRank sh) - = Ranked (mreplicateScalP (shCvtRX sh) x) + = Ranked (mreplicatePrimP (shxFromShR sh) x) -rreplicateScal :: forall n a. PrimElt a +rreplicatePrim :: forall n a. PrimElt a => IShR n -> a -> Ranked n a -rreplicateScal sh x = rfromPrimitive (rreplicateScalP sh x) +rreplicatePrim sh x = rfromPrimitive (rreplicatePrimP sh x) rslice :: forall n a. Elt a => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a -rslice i n arr - | Refl <- lemReplicateSucc @(Nothing @Nat) @n - = rlift (rrank arr) - (\_ -> X.sliceU i n) - arr +rslice i n (Ranked arr) + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) + = Ranked (msliceN i n arr) rrev1 :: forall n a. Elt a => Ranked (n + 1) a -> Ranked (n + 1) a -rrev1 arr = - rlift (rrank arr) - (\(_ :: StaticShX sh') -> - case lemReplicateSucc @(Nothing @Nat) @n of - Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh')) - arr +rrev1 (Ranked arr) + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) + = Ranked (mrev1 arr) rreshape :: forall n n' a. Elt a => IShR n' -> Ranked n a -> Ranked n' a rreshape sh' rarr@(Ranked arr) | Dict <- lemKnownReplicate (rrank rarr) , Dict <- lemKnownReplicate (shrRank sh') - = Ranked (mreshape (shCvtRX sh') arr) + = Ranked (mreshape (shxFromShR sh') arr) rflatten :: Elt a => Ranked n a -> Ranked 1 a rflatten (Ranked arr) = mtoRanked (mflatten arr) @@ -500,13 +325,13 @@ riota n = TN.withSomeSNat (fromIntegral n) $ mtoRanked . miota rminIndexPrim :: (PrimElt a, NumElt a) => Ranked n a -> IIxR n rminIndexPrim rarr@(Ranked arr) | Refl <- lemRankReplicate (rrank (rtoPrimitive rarr)) - = ixCvtXR (mminIndexPrim arr) + = ixrFromIxX (mminIndexPrim arr) -- | Throws if the array is empty. rmaxIndexPrim :: (PrimElt a, NumElt a) => Ranked n a -> IIxR n rmaxIndexPrim rarr@(Ranked arr) | Refl <- lemRankReplicate (rrank (rtoPrimitive rarr)) - = ixCvtXR (mmaxIndexPrim arr) + = ixrFromIxX (mmaxIndexPrim arr) rdot1Inner :: forall n a. (PrimElt a, NumElt a) => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked n a rdot1Inner arr1 arr2 @@ -520,40 +345,19 @@ rdot :: (PrimElt a, NumElt a) => Ranked n a -> Ranked n a -> a rdot = coerce mdot rtoXArrayPrimP :: Ranked n (Primitive a) -> (IShR n, XArray (Replicate n Nothing) a) -rtoXArrayPrimP (Ranked arr) = first shCvtXR' (mtoXArrayPrimP arr) +rtoXArrayPrimP (Ranked arr) = first shrFromShX2 (mtoXArrayPrimP arr) rtoXArrayPrim :: PrimElt a => Ranked n a -> (IShR n, XArray (Replicate n Nothing) a) -rtoXArrayPrim (Ranked arr) = first shCvtXR' (mtoXArrayPrim arr) +rtoXArrayPrim (Ranked arr) = first shrFromShX2 (mtoXArrayPrim arr) rfromXArrayPrimP :: SNat n -> XArray (Replicate n Nothing) a -> Ranked n (Primitive a) -rfromXArrayPrimP sn arr = Ranked (mfromXArrayPrimP (ssxFromShape (X.shape (ssxFromSNat sn) arr)) arr) +rfromXArrayPrimP sn arr = Ranked (mfromXArrayPrimP (ssxFromShX (X.shape (ssxFromSNat sn) arr)) arr) rfromXArrayPrim :: PrimElt a => SNat n -> XArray (Replicate n Nothing) a -> Ranked n a -rfromXArrayPrim sn arr = Ranked (mfromXArrayPrim (ssxFromShape (X.shape (ssxFromSNat sn) arr)) arr) +rfromXArrayPrim sn arr = Ranked (mfromXArrayPrim (ssxFromShX (X.shape (ssxFromSNat sn) arr)) arr) rfromPrimitive :: PrimElt a => Ranked n (Primitive a) -> Ranked n a rfromPrimitive (Ranked arr) = Ranked (fromPrimitive arr) rtoPrimitive :: PrimElt a => Ranked n a -> Ranked n (Primitive a) rtoPrimitive (Ranked arr) = Ranked (toPrimitive arr) - -mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a -mtoRanked arr - | Refl <- lemRankReplicate (shxRank (mshape arr)) - = Ranked (mcast (ssxFromShape (convSh (mshape arr))) arr) - where - convSh :: IShX sh' -> IShX (Replicate (Rank sh') Nothing) - convSh ZSX = ZSX - convSh (smn :$% (sh :: IShX sh'T)) - | Refl <- lemReplicateSucc @(Nothing @Nat) @(Rank sh'T) - = SUnknown (fromSMayNat' smn) :$% convSh sh - -rtoMixed :: forall n a. Ranked n a -> Mixed (Replicate n Nothing) a -rtoMixed (Ranked arr) = arr - --- | A more weakly-typed version of 'rtoMixed' that does a runtime shape --- compatibility check. -rcastToMixed :: (Rank sh ~ n, Elt a) => StaticShX sh -> Ranked n a -> Mixed sh a -rcastToMixed sshx rarr@(Ranked arr) - | Refl <- lemRankReplicate (rrank rarr) - = mcast sshx arr |
