aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Ranked.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Ranked.hs')
-rw-r--r--src/Data/Array/Nested/Ranked.hs464
1 files changed, 134 insertions, 330 deletions
diff --git a/src/Data/Array/Nested/Ranked.hs b/src/Data/Array/Nested/Ranked.hs
index fb5caa9..d687983 100644
--- a/src/Data/Array/Nested/Ranked.hs
+++ b/src/Data/Array/Nested/Ranked.hs
@@ -1,280 +1,82 @@
{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE DeriveGeneric #-}
-{-# LANGUAGE DerivingVia #-}
-{-# LANGUAGE FlexibleContexts #-}
-{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ImportQualifiedPost #-}
-{-# LANGUAGE InstanceSigs #-}
-{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE StandaloneDeriving #-}
-{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-module Data.Array.Nested.Ranked where
+module Data.Array.Nested.Ranked (
+ Ranked(Ranked),
+ rquotArray, rremArray, ratan2Array,
+ rshape, rrank,
+ module Data.Array.Nested.Ranked,
+ liftRanked1, liftRanked2,
+) where
import Prelude hiding (mappend, mconcat)
-import Control.DeepSeq (NFData(..))
-import Control.Monad.ST
import Data.Array.RankedS qualified as S
import Data.Bifunctor (first)
import Data.Coerce (coerce)
-import Data.Foldable (toList)
-import Data.Kind (Type)
import Data.List.NonEmpty (NonEmpty)
import Data.Proxy
import Data.Type.Equality
import Data.Vector.Storable qualified as VS
import Foreign.Storable (Storable)
-import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp)
-import GHC.Generics (Generic)
import GHC.TypeLits
import GHC.TypeNats qualified as TN
-import Data.Array.Mixed.Lemmas
-import Data.Array.Mixed.Permutation
-import Data.Array.Mixed.Types
-import Data.Array.Mixed.XArray (XArray(..))
-import Data.Array.Mixed.XArray qualified as X
+import Data.Array.Nested.Convert
+import Data.Array.Nested.Lemmas
import Data.Array.Nested.Mixed
import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Permutation
+import Data.Array.Nested.Ranked.Base
import Data.Array.Nested.Ranked.Shape
+import Data.Array.Nested.Types
import Data.Array.Strided.Arith
-
-
--- | A rank-typed array: the number of dimensions of the array (its /rank/) is
--- represented on the type level as a 'Nat'.
---
--- Valid elements of a ranked arrays are described by the 'Elt' type class.
--- Because 'Ranked' itself is also an instance of 'Elt', nested arrays are
--- supported (and are represented as a single, flattened, struct-of-arrays
--- array internally).
---
--- 'Ranked' is a newtype around a 'Mixed' of 'Nothing's.
-type Ranked :: Nat -> Type -> Type
-newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a)
-deriving instance Eq (Mixed (Replicate n Nothing) a) => Eq (Ranked n a)
-deriving instance Ord (Mixed (Replicate n Nothing) a) => Ord (Ranked n a)
-
-instance (Show a, Elt a) => Show (Ranked n a) where
- showsPrec d arr@(Ranked marr) =
- let sh = show (toList (rshape arr))
- in showsMixedArray ("rfromListLinear " ++ sh) ("rreplicate " ++ sh) d marr
-
-instance Elt a => NFData (Ranked n a) where
- rnf (Ranked arr) = rnf arr
-
--- just unwrap the newtype and defer to the general instance for nested arrays
-newtype instance Mixed sh (Ranked n a) = M_Ranked (Mixed sh (Mixed (Replicate n Nothing) a))
- deriving (Generic)
-
-deriving instance Eq (Mixed sh (Mixed (Replicate n Nothing) a)) => Eq (Mixed sh (Ranked n a))
-
-newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed (Replicate n Nothing) a))
-
--- 'Ranked' and 'Shaped' can already be used at the top level of an array nest;
--- these instances allow them to also be used as elements of arrays, thus
--- making them first-class in the API.
-instance Elt a => Elt (Ranked n a) where
- mshape (M_Ranked arr) = mshape arr
- mindex (M_Ranked arr) i = Ranked (mindex arr i)
-
- mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IIxX sh -> Mixed sh' (Ranked n a)
- mindexPartial (M_Ranked arr) i =
- coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $
- mindexPartial arr i
-
- mscalar (Ranked x) = M_Ranked (M_Nest ZSX x)
-
- mfromListOuter :: forall sh. NonEmpty (Mixed sh (Ranked n a)) -> Mixed (Nothing : sh) (Ranked n a)
- mfromListOuter l = M_Ranked (mfromListOuter (coerce l))
-
- mtoListOuter :: forall m sh. Mixed (m : sh) (Ranked n a) -> [Mixed sh (Ranked n a)]
- mtoListOuter (M_Ranked arr) =
- coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoListOuter arr)
-
- mlift :: forall sh1 sh2.
- StaticShX sh2
- -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
- -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a)
- mlift ssh2 f (M_Ranked arr) =
- coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $
- mlift ssh2 f arr
-
- mlift2 :: forall sh1 sh2 sh3.
- StaticShX sh3
- -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
- -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) -> Mixed sh3 (Ranked n a)
- mlift2 ssh3 f (M_Ranked arr1) (M_Ranked arr2) =
- coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $
- mlift2 ssh3 f arr1 arr2
-
- mliftL :: forall sh1 sh2.
- StaticShX sh2
- -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b))
- -> NonEmpty (Mixed sh1 (Ranked n a)) -> NonEmpty (Mixed sh2 (Ranked n a))
- mliftL ssh2 f l =
- coerce @(NonEmpty (Mixed sh2 (Mixed (Replicate n Nothing) a)))
- @(NonEmpty (Mixed sh2 (Ranked n a))) $
- mliftL ssh2 f (coerce l)
-
- mcastPartial ssh1 ssh2 psh' (M_Ranked arr) = M_Ranked (mcastPartial ssh1 ssh2 psh' arr)
-
- mtranspose perm (M_Ranked arr) = M_Ranked (mtranspose perm arr)
-
- mconcat l = M_Ranked (mconcat (coerce l))
-
- mrnf (M_Ranked arr) = mrnf arr
-
- type ShapeTree (Ranked n a) = (IShR n, ShapeTree a)
-
- mshapeTree (Ranked arr) = first shCvtXR' (mshapeTree arr)
-
- mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
-
- mshapeTreeEmpty _ (sh, t) = shrSize sh == 0 && mshapeTreeEmpty (Proxy @a) t
-
- mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
-
- marrayStrides (M_Ranked arr) = marrayStrides arr
-
- mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s ()
- mvecsWrite sh idx (Ranked arr) vecs =
- mvecsWrite sh idx arr
- (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
- vecs)
-
- mvecsWritePartial :: forall sh sh' s.
- IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a)
- -> MixedVecs s (sh ++ sh') (Ranked n a)
- -> ST s ()
- mvecsWritePartial sh idx arr vecs =
- mvecsWritePartial sh idx
- (coerce @(Mixed sh' (Ranked n a))
- @(Mixed sh' (Mixed (Replicate n Nothing) a))
- arr)
- (coerce @(MixedVecs s (sh ++ sh') (Ranked n a))
- @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a))
- vecs)
-
- mvecsFreeze :: forall sh s. IShX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a))
- mvecsFreeze sh vecs =
- coerce @(Mixed sh (Mixed (Replicate n Nothing) a))
- @(Mixed sh (Ranked n a))
- <$> mvecsFreeze sh
- (coerce @(MixedVecs s sh (Ranked n a))
- @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
- vecs)
-
-instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where
- memptyArrayUnsafe :: forall sh. IShX sh -> Mixed sh (Ranked n a)
- memptyArrayUnsafe i
- | Dict <- lemKnownReplicate (SNat @n)
- = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $
- memptyArrayUnsafe i
-
- mvecsUnsafeNew idx (Ranked arr)
- | Dict <- lemKnownReplicate (SNat @n)
- = MV_Ranked <$> mvecsUnsafeNew idx arr
-
- mvecsNewEmpty _
- | Dict <- lemKnownReplicate (SNat @n)
- = MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a))
-
-
-liftRanked1 :: forall n a b.
- (Mixed (Replicate n Nothing) a -> Mixed (Replicate n Nothing) b)
- -> Ranked n a -> Ranked n b
-liftRanked1 = coerce
-
-liftRanked2 :: forall n a b c.
- (Mixed (Replicate n Nothing) a -> Mixed (Replicate n Nothing) b -> Mixed (Replicate n Nothing) c)
- -> Ranked n a -> Ranked n b -> Ranked n c
-liftRanked2 = coerce
-
-instance (NumElt a, PrimElt a) => Num (Ranked n a) where
- (+) = liftRanked2 (+)
- (-) = liftRanked2 (-)
- (*) = liftRanked2 (*)
- negate = liftRanked1 negate
- abs = liftRanked1 abs
- signum = liftRanked1 signum
- fromInteger = error "Data.Array.Nested(Ranked).fromInteger: No singletons available, use explicit rreplicateScal"
-
-instance (FloatElt a, PrimElt a) => Fractional (Ranked n a) where
- fromRational _ = error "Data.Array.Nested(Ranked).fromRational: No singletons available, use explicit rreplicateScal"
- recip = liftRanked1 recip
- (/) = liftRanked2 (/)
-
-instance (FloatElt a, PrimElt a) => Floating (Ranked n a) where
- pi = error "Data.Array.Nested(Ranked).pi: No singletons available, use explicit rreplicateScal"
- exp = liftRanked1 exp
- log = liftRanked1 log
- sqrt = liftRanked1 sqrt
- (**) = liftRanked2 (**)
- logBase = liftRanked2 logBase
- sin = liftRanked1 sin
- cos = liftRanked1 cos
- tan = liftRanked1 tan
- asin = liftRanked1 asin
- acos = liftRanked1 acos
- atan = liftRanked1 atan
- sinh = liftRanked1 sinh
- cosh = liftRanked1 cosh
- tanh = liftRanked1 tanh
- asinh = liftRanked1 asinh
- acosh = liftRanked1 acosh
- atanh = liftRanked1 atanh
- log1p = liftRanked1 GHC.Float.log1p
- expm1 = liftRanked1 GHC.Float.expm1
- log1pexp = liftRanked1 GHC.Float.log1pexp
- log1mexp = liftRanked1 GHC.Float.log1mexp
-
-rquotArray, rremArray :: (IntElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a
-rquotArray = liftRanked2 mquotArray
-rremArray = liftRanked2 mremArray
-
-ratan2Array :: (FloatElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a
-ratan2Array = liftRanked2 matan2Array
+import Data.Array.XArray (XArray(..))
+import Data.Array.XArray qualified as X
remptyArray :: KnownElt a => Ranked 1 a
remptyArray = mtoRanked (memptyArray ZSX)
-rshape :: Elt a => Ranked n a -> IShR n
-rshape (Ranked arr) = shCvtXR' (mshape arr)
-
-rrank :: Elt a => Ranked n a -> SNat n
-rrank = shrRank . rshape
-
-- | The total number of elements in the array.
rsize :: Elt a => Ranked n a -> Int
rsize = shrSize . rshape
+{-# INLINEABLE rindex #-}
rindex :: Elt a => Ranked n a -> IIxR n -> a
-rindex (Ranked arr) idx = mindex arr (ixCvtRX idx)
+rindex (Ranked arr) idx = mindex arr (ixxFromIxR idx)
+{-# INLINEABLE rindexPartial #-}
rindexPartial :: forall n m a. Elt a => Ranked (n + m) a -> IIxR n -> Ranked m a
rindexPartial (Ranked arr) idx =
Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing)
(castWith (subst2 (lemReplicatePlusApp (ixrRank idx) (Proxy @m) (Proxy @Nothing))) arr)
- (ixCvtRX idx))
+ (ixxFromIxR idx))
-- | __WARNING__: All values returned from the function must have equal shape.
--- See the documentation of 'mgenerate' for more details.
+-- See the documentation of 'mgenerate' for more details; see also
+-- 'rgeneratePrim'.
rgenerate :: forall n a. KnownElt a => IShR n -> (IIxR n -> a) -> Ranked n a
rgenerate sh f
| sn@SNat <- shrRank sh
, Dict <- lemKnownReplicate sn
, Refl <- lemRankReplicate sn
- = Ranked (mgenerate (shCvtRX sh) (f . ixCvtXR))
+ = Ranked (mgenerate (shxFromShR sh) (f . ixrFromIxX))
+
+-- | See 'mgeneratePrim'.
+{-# INLINE rgeneratePrim #-}
+rgeneratePrim :: forall n a i. (PrimElt a, Num i)
+ => IShR n -> (IxR n i -> a) -> Ranked n a
+rgeneratePrim sh f =
+ let g i = f (ixrFromLinear sh i)
+ in rfromVector sh $ VS.generate (shrSize sh) g
-- | See the documentation of 'mlift'.
rlift :: forall n1 n2 a. Elt a
@@ -290,16 +92,19 @@ rlift2 :: forall n1 n2 n3 a. Elt a
-> Ranked n1 a -> Ranked n2 a -> Ranked n3 a
rlift2 sn3 f (Ranked arr1) (Ranked arr2) = Ranked (mlift2 (ssxFromSNat sn3) f arr1 arr2)
-rsumOuter1P :: forall n a.
- (Storable a, NumElt a)
- => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a)
-rsumOuter1P (Ranked arr)
- | Refl <- lemReplicateSucc @(Nothing @Nat) @n
- = Ranked (msumOuter1P arr)
+rsumOuter1PrimP :: forall n a.
+ (Storable a, NumElt a)
+ => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a)
+rsumOuter1PrimP (Ranked arr)
+ | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
+ = Ranked (msumOuter1PrimP arr)
-rsumOuter1 :: forall n a. (NumElt a, PrimElt a)
- => Ranked (n + 1) a -> Ranked n a
-rsumOuter1 = rfromPrimitive . rsumOuter1P . rtoPrimitive
+rsumOuter1Prim :: forall n a. (NumElt a, PrimElt a)
+ => Ranked (n + 1) a -> Ranked n a
+rsumOuter1Prim = rfromPrimitive . rsumOuter1PrimP . rtoPrimitive
+
+rsumAllPrimP :: (Storable a, NumElt a) => Ranked n (Primitive a) -> a
+rsumAllPrimP (Ranked arr) = msumAllPrimP arr
rsumAllPrim :: (PrimElt a, NumElt a) => Ranked n a -> a
rsumAllPrim (Ranked arr) = msumAllPrim arr
@@ -317,7 +122,7 @@ rtranspose perm arr
rconcat :: forall n a. Elt a => NonEmpty (Ranked (n + 1) a) -> Ranked (n + 1) a
rconcat
- | Refl <- lemReplicateSucc @(Nothing @Nat) @n
+ | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
= coerce mconcat
rappend :: forall n a. Elt a
@@ -325,7 +130,7 @@ rappend :: forall n a. Elt a
rappend arr1 arr2
| sn@SNat <- rrank arr1
, Dict <- lemKnownReplicate sn
- , Refl <- lemReplicateSucc @(Nothing @Nat) @n
+ , Refl <- lemReplicateSucc @(Nothing @Nat) (SNat @n)
= coerce (mappend @Nothing @Nothing @(Replicate n Nothing))
arr1 arr2
@@ -335,7 +140,7 @@ rscalar x = Ranked (mscalar x)
rfromVectorP :: forall n a. Storable a => IShR n -> VS.Vector a -> Ranked n (Primitive a)
rfromVectorP sh v
| Dict <- lemKnownReplicate (shrRank sh)
- = Ranked (mfromVectorP (shCvtRX sh) v)
+ = Ranked (mfromVectorP (shxFromShR sh) v)
rfromVector :: forall n a. PrimElt a => IShR n -> VS.Vector a -> Ranked n a
rfromVector sh v = rfromPrimitive (rfromVectorP sh v)
@@ -346,38 +151,63 @@ rtoVectorP = coerce mtoVectorP
rtoVector :: PrimElt a => Ranked n a -> VS.Vector a
rtoVector = coerce mtoVector
+-- | All arrays in the list, even subarrays inside @a@, must have the same
+-- shape; if they do not, a runtime error will be thrown. See the
+-- documentation of 'mgenerate' for more information about this restriction.
+--
+-- Because the length of the 'NonEmpty' list is unknown, its spine must be
+-- materialised in memory in order to compute its length. If its length is
+-- already known, use 'rfromListOuterN' to be able to stream the list.
+--
+-- If your array is 1-dimensional and contains scalars, use 'rfromList1Prim'.
rfromListOuter :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a
rfromListOuter l
- | Refl <- lemReplicateSucc @(Nothing @Nat) @n
+ | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
= Ranked (mfromListOuter (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a)))
+-- | See 'rfromListOuter'. If the list does not have the given length, a
+-- runtime error is thrown. 'rfromList1PrimN' is faster if applicable.
+rfromListOuterN :: forall n a. Elt a => Int -> NonEmpty (Ranked n a) -> Ranked (n + 1) a
+rfromListOuterN n l
+ | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
+ = Ranked (mfromListOuterN n (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a)))
+
+-- | Because the length of the 'NonEmpty' list is unknown, its spine must be
+-- materialised in memory in order to compute its length. If its length is
+-- already known, use 'rfromList1N' to be able to stream the list.
+--
+-- If the elements are scalars, 'rfromList1Prim' is faster.
rfromList1 :: Elt a => NonEmpty a -> Ranked 1 a
-rfromList1 l = Ranked (mfromList1 l)
+rfromList1 = coerce mfromList1
-rfromList1Prim :: PrimElt a => [a] -> Ranked 1 a
-rfromList1Prim l = Ranked (mfromList1Prim l)
+-- | If the elements are scalars, 'rfromList1PrimN' is faster. A runtime error
+-- is thrown if the list length does not match the given length.
+rfromList1N :: Elt a => Int -> NonEmpty a -> Ranked 1 a
+rfromList1N = coerce mfromList1N
-rtoListOuter :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a]
-rtoListOuter (Ranked arr)
- | Refl <- lemReplicateSucc @(Nothing @Nat) @n
- = coerce (mtoListOuter @a @Nothing @(Replicate n Nothing) arr)
+-- | If the elements are scalars, 'rfromListPrimLinear' is faster.
+rfromListLinear :: forall n a. Elt a => IShR n -> NonEmpty a -> Ranked n a
+rfromListLinear sh l = Ranked (mfromListLinear (shxFromShR sh) l)
-rtoList1 :: Elt a => Ranked 1 a -> [a]
-rtoList1 = map runScalar . rtoListOuter
+-- | Because the length of the list is unknown, its spine must be materialised
+-- in memory in order to compute its length. If its length is already known,
+-- use 'rfromList1PrimN' to be able to stream the list.
+rfromList1Prim :: PrimElt a => [a] -> Ranked 1 a
+rfromList1Prim = coerce mfromList1Prim
-rfromListPrim :: PrimElt a => [a] -> Ranked 1 a
-rfromListPrim l =
- let ssh = SUnknown () :!% ZKX
- xarr = X.fromList1 ssh l
- in Ranked $ fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr
+rfromList1PrimN :: PrimElt a => Int -> [a] -> Ranked 1 a
+rfromList1PrimN = coerce mfromList1PrimN
-rfromListPrimLinear :: PrimElt a => IShR n -> [a] -> Ranked n a
-rfromListPrimLinear sh l =
- let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
- in Ranked $ fromPrimitive $ M_Primitive (shCvtRX sh) (X.reshape (SUnknown () :!% ZKX) (shCvtRX sh) xarr)
+rfromListPrimLinear :: forall n a. PrimElt a => IShR n -> [a] -> Ranked n a
+rfromListPrimLinear sh l = Ranked (mfromListPrimLinear (shxFromShR sh) l)
-rfromListLinear :: forall n a. Elt a => IShR n -> NonEmpty a -> Ranked n a
-rfromListLinear sh l = rreshape sh (rfromList1 l)
+rtoList :: Elt a => Ranked 1 a -> [a]
+rtoList = map runScalar . rtoListOuter
+
+rtoListOuter :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a]
+rtoListOuter (Ranked arr)
+ | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
+ = coerce (mtoListOuter @a @Nothing @(Replicate n Nothing) arr)
rtoListLinear :: Elt a => Ranked n a -> [a]
rtoListLinear (Ranked arr) = mtoListLinear arr
@@ -388,9 +218,9 @@ rfromOrthotope sn arr
= let xarr = XArray arr
in Ranked (fromPrimitive (M_Primitive (X.shape (ssxFromSNat sn) xarr) xarr))
-rtoOrthotope :: PrimElt a => Ranked n a -> S.Array n a
+rtoOrthotope :: forall a n. PrimElt a => Ranked n a -> S.Array n a
rtoOrthotope (rtoPrimitive -> Ranked (M_Primitive sh (XArray arr)))
- | Refl <- lemRankReplicate (shrRank $ shCvtXR' sh)
+ | Refl <- lemRankReplicate (shrRank $ shrFromShX2 @n sh)
= arr
runScalar :: Elt a => Ranked 0 a -> a
@@ -406,22 +236,20 @@ runNest rarr@(Ranked (M_Ranked (M_Nest _ arr)))
| Refl <- lemReplicatePlusApp (rrank rarr) (Proxy @m) (Proxy @(Nothing @Nat))
= Ranked arr
-rzip :: Ranked n a -> Ranked n b -> Ranked n (a, b)
+rzip :: (Elt a, Elt b) => Ranked n a -> Ranked n b -> Ranked n (a, b)
rzip = coerce mzip
runzip :: Ranked n (a, b) -> (Ranked n a, Ranked n b)
runzip = coerce munzip
-rrerankP :: forall n1 n2 n a b. (Storable a, Storable b)
- => SNat n -> IShR n2
- -> (Ranked n1 (Primitive a) -> Ranked n2 (Primitive b))
- -> Ranked (n + n1) (Primitive a) -> Ranked (n + n2) (Primitive b)
-rrerankP sn sh2 f (Ranked arr)
- | Refl <- lemReplicatePlusApp sn (Proxy @n1) (Proxy @(Nothing @Nat))
- , Refl <- lemReplicatePlusApp sn (Proxy @n2) (Proxy @(Nothing @Nat))
- = Ranked (mrerankP (ssxFromSNat sn) (shCvtRX sh2)
- (\a -> let Ranked r = f (Ranked a) in r)
- arr)
+rrerankPrimP :: forall n1 n2 n a b. (Storable a, Storable b)
+ => IShR n2
+ -> (Ranked n1 (Primitive a) -> Ranked n2 (Primitive b))
+ -> Ranked n (Ranked n1 (Primitive a)) -> Ranked n (Ranked n2 (Primitive b))
+rrerankPrimP sh2 f (Ranked (M_Ranked arr))
+ = Ranked (M_Ranked (mrerankPrimP (shxFromShR sh2)
+ (\a -> let Ranked r = f (Ranked a) in r)
+ arr))
-- | If there is a zero-sized dimension in the @n@-prefix of the shape of the
-- input array, then there is no way to deduce the full shape of the output
@@ -432,63 +260,60 @@ rrerankP sn sh2 f (Ranked arr)
-- For example, if:
--
-- @
--- arr :: Ranked 5 Int -- of shape [3, 0, 4, 2, 21]
+-- arr :: Ranked 3 (Ranked 2 Int) -- outer array shape [3, 0, 4]; inner shape [2, 21]
-- f :: Ranked 2 Int -> Ranked 3 Float
-- @
--
-- then:
--
-- @
--- rrerank _ _ _ f arr :: Ranked 5 Float
+-- rrerank _ f arr :: Ranked 3 (Ranked 3 Float)
-- @
--
--- and this result will have shape @[3, 0, 4, 0, 0, 0]@. Note that the
--- "reranked" part (the last 3 entries) are zero; we don't know if @f@ intended
--- to return an array with shape all-0 here (it probably didn't), but there is
--- no better number to put here absent a subarray of the input to pass to @f@.
-rrerank :: forall n1 n2 n a b. (PrimElt a, PrimElt b)
- => SNat n -> IShR n2
- -> (Ranked n1 a -> Ranked n2 b)
- -> Ranked (n + n1) a -> Ranked (n + n2) b
-rrerank sn sh2 f (rtoPrimitive -> arr) =
- rfromPrimitive $ rrerankP sn sh2 (rtoPrimitive . f . rfromPrimitive) arr
+-- and the inner arrays of the result will have shape @[0, 0, 0]@. We don't
+-- know if @f@ intended to return an array with all-zero shape here (it
+-- probably didn't), but there is no better number to put here absent a
+-- subarray of the input to pass to @f@.
+rrerankPrim :: forall n1 n2 n a b. (PrimElt a, PrimElt b)
+ => IShR n2
+ -> (Ranked n1 a -> Ranked n2 b)
+ -> Ranked n (Ranked n1 a) -> Ranked n (Ranked n2 b)
+rrerankPrim sh2 f (Ranked (M_Ranked arr)) =
+ Ranked (M_Ranked (mrerankPrim (shxFromShR sh2)
+ (\a -> let Ranked r = f (Ranked a) in r)
+ arr))
rreplicate :: forall n m a. Elt a
=> IShR n -> Ranked m a -> Ranked (n + m) a
rreplicate sh (Ranked arr)
| Refl <- lemReplicatePlusApp (shrRank sh) (Proxy @m) (Proxy @(Nothing @Nat))
- = Ranked (mreplicate (shCvtRX sh) arr)
+ = Ranked (mreplicate (shxFromShR sh) arr)
-rreplicateScalP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a)
-rreplicateScalP sh x
+rreplicatePrimP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a)
+rreplicatePrimP sh x
| Dict <- lemKnownReplicate (shrRank sh)
- = Ranked (mreplicateScalP (shCvtRX sh) x)
+ = Ranked (mreplicatePrimP (shxFromShR sh) x)
-rreplicateScal :: forall n a. PrimElt a
+rreplicatePrim :: forall n a. PrimElt a
=> IShR n -> a -> Ranked n a
-rreplicateScal sh x = rfromPrimitive (rreplicateScalP sh x)
+rreplicatePrim sh x = rfromPrimitive (rreplicatePrimP sh x)
rslice :: forall n a. Elt a => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a
-rslice i n arr
- | Refl <- lemReplicateSucc @(Nothing @Nat) @n
- = rlift (rrank arr)
- (\_ -> X.sliceU i n)
- arr
+rslice i n (Ranked arr)
+ | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
+ = Ranked (msliceN i n arr)
rrev1 :: forall n a. Elt a => Ranked (n + 1) a -> Ranked (n + 1) a
-rrev1 arr =
- rlift (rrank arr)
- (\(_ :: StaticShX sh') ->
- case lemReplicateSucc @(Nothing @Nat) @n of
- Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh'))
- arr
+rrev1 (Ranked arr)
+ | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
+ = Ranked (mrev1 arr)
rreshape :: forall n n' a. Elt a
=> IShR n' -> Ranked n a -> Ranked n' a
rreshape sh' rarr@(Ranked arr)
| Dict <- lemKnownReplicate (rrank rarr)
, Dict <- lemKnownReplicate (shrRank sh')
- = Ranked (mreshape (shCvtRX sh') arr)
+ = Ranked (mreshape (shxFromShR sh') arr)
rflatten :: Elt a => Ranked n a -> Ranked 1 a
rflatten (Ranked arr) = mtoRanked (mflatten arr)
@@ -500,13 +325,13 @@ riota n = TN.withSomeSNat (fromIntegral n) $ mtoRanked . miota
rminIndexPrim :: (PrimElt a, NumElt a) => Ranked n a -> IIxR n
rminIndexPrim rarr@(Ranked arr)
| Refl <- lemRankReplicate (rrank (rtoPrimitive rarr))
- = ixCvtXR (mminIndexPrim arr)
+ = ixrFromIxX (mminIndexPrim arr)
-- | Throws if the array is empty.
rmaxIndexPrim :: (PrimElt a, NumElt a) => Ranked n a -> IIxR n
rmaxIndexPrim rarr@(Ranked arr)
| Refl <- lemRankReplicate (rrank (rtoPrimitive rarr))
- = ixCvtXR (mmaxIndexPrim arr)
+ = ixrFromIxX (mmaxIndexPrim arr)
rdot1Inner :: forall n a. (PrimElt a, NumElt a) => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked n a
rdot1Inner arr1 arr2
@@ -520,40 +345,19 @@ rdot :: (PrimElt a, NumElt a) => Ranked n a -> Ranked n a -> a
rdot = coerce mdot
rtoXArrayPrimP :: Ranked n (Primitive a) -> (IShR n, XArray (Replicate n Nothing) a)
-rtoXArrayPrimP (Ranked arr) = first shCvtXR' (mtoXArrayPrimP arr)
+rtoXArrayPrimP (Ranked arr) = first shrFromShX2 (mtoXArrayPrimP arr)
rtoXArrayPrim :: PrimElt a => Ranked n a -> (IShR n, XArray (Replicate n Nothing) a)
-rtoXArrayPrim (Ranked arr) = first shCvtXR' (mtoXArrayPrim arr)
+rtoXArrayPrim (Ranked arr) = first shrFromShX2 (mtoXArrayPrim arr)
rfromXArrayPrimP :: SNat n -> XArray (Replicate n Nothing) a -> Ranked n (Primitive a)
-rfromXArrayPrimP sn arr = Ranked (mfromXArrayPrimP (ssxFromShape (X.shape (ssxFromSNat sn) arr)) arr)
+rfromXArrayPrimP sn arr = Ranked (mfromXArrayPrimP (ssxFromShX (X.shape (ssxFromSNat sn) arr)) arr)
rfromXArrayPrim :: PrimElt a => SNat n -> XArray (Replicate n Nothing) a -> Ranked n a
-rfromXArrayPrim sn arr = Ranked (mfromXArrayPrim (ssxFromShape (X.shape (ssxFromSNat sn) arr)) arr)
+rfromXArrayPrim sn arr = Ranked (mfromXArrayPrim (ssxFromShX (X.shape (ssxFromSNat sn) arr)) arr)
rfromPrimitive :: PrimElt a => Ranked n (Primitive a) -> Ranked n a
rfromPrimitive (Ranked arr) = Ranked (fromPrimitive arr)
rtoPrimitive :: PrimElt a => Ranked n a -> Ranked n (Primitive a)
rtoPrimitive (Ranked arr) = Ranked (toPrimitive arr)
-
-mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a
-mtoRanked arr
- | Refl <- lemRankReplicate (shxRank (mshape arr))
- = Ranked (mcast (ssxFromShape (convSh (mshape arr))) arr)
- where
- convSh :: IShX sh' -> IShX (Replicate (Rank sh') Nothing)
- convSh ZSX = ZSX
- convSh (smn :$% (sh :: IShX sh'T))
- | Refl <- lemReplicateSucc @(Nothing @Nat) @(Rank sh'T)
- = SUnknown (fromSMayNat' smn) :$% convSh sh
-
-rtoMixed :: forall n a. Ranked n a -> Mixed (Replicate n Nothing) a
-rtoMixed (Ranked arr) = arr
-
--- | A more weakly-typed version of 'rtoMixed' that does a runtime shape
--- compatibility check.
-rcastToMixed :: (Rank sh ~ n, Elt a) => StaticShX sh -> Ranked n a -> Mixed sh a
-rcastToMixed sshx rarr@(Ranked arr)
- | Refl <- lemRankReplicate (rrank rarr)
- = mcast sshx arr