aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Ranked/Base.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Ranked/Base.hs')
-rw-r--r--src/Data/Array/Nested/Ranked/Base.hs61
1 files changed, 32 insertions, 29 deletions
diff --git a/src/Data/Array/Nested/Ranked/Base.hs b/src/Data/Array/Nested/Ranked/Base.hs
index 11a8ffb..beedbcf 100644
--- a/src/Data/Array/Nested/Ranked/Base.hs
+++ b/src/Data/Array/Nested/Ranked/Base.hs
@@ -26,16 +26,11 @@ import Data.Coerce (coerce)
import Data.Kind (Type)
import Data.List.NonEmpty (NonEmpty)
import Data.Proxy
-import Data.Type.Equality
import Foreign.Storable (Storable)
import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp)
import GHC.Generics (Generic)
import GHC.TypeLits
-#ifndef OXAR_DEFAULT_SHOW_INSTANCES
-import Data.Foldable (toList)
-#endif
-
import Data.Array.Nested.Lemmas
import Data.Array.Nested.Mixed
import Data.Array.Nested.Mixed.Shape
@@ -65,7 +60,7 @@ deriving instance Ord (Mixed (Replicate n Nothing) a) => Ord (Ranked n a)
#ifndef OXAR_DEFAULT_SHOW_INSTANCES
instance (Show a, Elt a) => Show (Ranked n a) where
showsPrec d arr@(Ranked marr) =
- let sh = show (toList (rshape arr))
+ let sh = show (shrToList (rshape arr))
in showsMixedArray ("rfromListLinear " ++ sh) ("rreplicate " ++ sh) d marr
#endif
@@ -87,9 +82,12 @@ newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed
-- these instances allow them to also be used as elements of arrays, thus
-- making them first-class in the API.
instance Elt a => Elt (Ranked n a) where
+ {-# INLINE mshape #-}
mshape (M_Ranked arr) = mshape arr
+ {-# INLINE mindex #-}
mindex (M_Ranked arr) i = Ranked (mindex arr i)
+ {-# INLINE mindexPartial #-}
mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IIxX sh -> Mixed sh' (Ranked n a)
mindexPartial (M_Ranked arr) i =
coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $
@@ -104,6 +102,7 @@ instance Elt a => Elt (Ranked n a) where
mtoListOuter (M_Ranked arr) =
coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoListOuter arr)
+ {-# INLINE mlift #-}
mlift :: forall sh1 sh2.
StaticShX sh2
-> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
@@ -112,6 +111,7 @@ instance Elt a => Elt (Ranked n a) where
coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $
mlift ssh2 f arr
+ {-# INLINE mlift2 #-}
mlift2 :: forall sh1 sh2 sh3.
StaticShX sh3
-> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
@@ -120,6 +120,7 @@ instance Elt a => Elt (Ranked n a) where
coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $
mlift2 ssh3 f arr1 arr2
+ {-# INLINE mliftL #-}
mliftL :: forall sh1 sh2.
StaticShX sh2
-> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b))
@@ -139,7 +140,7 @@ instance Elt a => Elt (Ranked n a) where
type ShapeTree (Ranked n a) = (IShR n, ShapeTree a)
- mshapeTree (Ranked arr) = first shrFromShX2 (mshapeTree arr)
+ mshapeTree (Ranked arr) = first coerce (mshapeTree arr)
mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
@@ -149,18 +150,19 @@ instance Elt a => Elt (Ranked n a) where
marrayStrides (M_Ranked arr) = marrayStrides arr
- mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s ()
- mvecsWrite sh idx (Ranked arr) vecs =
- mvecsWrite sh idx arr
+ mvecsWriteLinear :: forall sh s. Int -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s ()
+ mvecsWriteLinear idx (Ranked arr) vecs =
+ mvecsWriteLinear idx arr
(coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
vecs)
- mvecsWritePartial :: forall sh sh' s.
- IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a)
- -> MixedVecs s (sh ++ sh') (Ranked n a)
- -> ST s ()
- mvecsWritePartial sh idx arr vecs =
- mvecsWritePartial sh idx
+ mvecsWritePartialLinear
+ :: forall sh sh' s.
+ Proxy sh -> Int -> Mixed sh' (Ranked n a)
+ -> MixedVecs s (sh ++ sh') (Ranked n a)
+ -> ST s ()
+ mvecsWritePartialLinear proxy idx arr vecs =
+ mvecsWritePartialLinear proxy idx
(coerce @(Mixed sh' (Ranked n a))
@(Mixed sh' (Mixed (Replicate n Nothing) a))
arr)
@@ -176,6 +178,14 @@ instance Elt a => Elt (Ranked n a) where
(coerce @(MixedVecs s sh (Ranked n a))
@(MixedVecs s sh (Mixed (Replicate n Nothing) a))
vecs)
+ mvecsUnsafeFreeze :: forall sh s. IShX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a))
+ mvecsUnsafeFreeze sh vecs =
+ coerce @(Mixed sh (Mixed (Replicate n Nothing) a))
+ @(Mixed sh (Ranked n a))
+ <$> mvecsUnsafeFreeze sh
+ (coerce @(MixedVecs s sh (Ranked n a))
+ @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
+ vecs)
instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where
memptyArrayUnsafe :: forall sh. IShX sh -> Mixed sh (Ranked n a)
@@ -188,6 +198,10 @@ instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where
| Dict <- lemKnownReplicate (SNat @n)
= MV_Ranked <$> mvecsUnsafeNew idx arr
+ mvecsReplicate idx (Ranked arr)
+ | Dict <- lemKnownReplicate (SNat @n)
+ = MV_Ranked <$> mvecsReplicate idx arr
+
mvecsNewEmpty _
| Dict <- lemKnownReplicate (SNat @n)
= MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a))
@@ -249,20 +263,9 @@ ratan2Array :: (FloatElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a
ratan2Array = liftRanked2 matan2Array
+{-# INLINE rshape #-}
rshape :: Elt a => Ranked n a -> IShR n
-rshape (Ranked arr) = shrFromShX2 (mshape arr)
+rshape (Ranked arr) = coerce (mshape arr)
rrank :: Elt a => Ranked n a -> SNat n
rrank = shrRank . rshape
-
--- Needed already here, but re-exported in Data.Array.Nested.Convert.
-shrFromShX :: forall sh. IShX sh -> IShR (Rank sh)
-shrFromShX ZSX = ZSR
-shrFromShX (n :$% idx) = fromSMayNat' n :$: shrFromShX idx
-
--- Needed already here, but re-exported in Data.Array.Nested.Convert.
--- | Convenience wrapper around 'shrFromShX' that applies 'lemRankReplicate'.
-shrFromShX2 :: forall n. IShX (Replicate n Nothing) -> IShR n
-shrFromShX2 sh
- | Refl <- lemRankReplicate (Proxy @n)
- = shrFromShX sh