aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-05-20 17:21:21 +0200
committerTom Smeding <tom@tomsmeding.com>2024-05-20 17:21:35 +0200
commit52c0237fbdbc3c99ee6565ba18250360a330fb8b (patch)
treee3b7b11d81f557dfbff9df043198e9e7c50fb569 /src/Data/Array/Nested
parent16e52d87e9955628a016946c10515c39ce4ef1d0 (diff)
Rerank on primitive arrays
Diffstat (limited to 'src/Data/Array/Nested')
-rw-r--r--src/Data/Array/Nested/Internal.hs68
1 files changed, 67 insertions, 1 deletions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index 99c4a46..badb910 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -28,7 +28,6 @@
{-|
TODO:
-* Write `rerank`
* Write `rconst :: OR.Array n a -> Ranked n a`
-}
@@ -900,6 +899,24 @@ mtoList = map munScalar . mtoList1
munScalar :: Elt a => Mixed '[] a -> a
munScalar arr = mindex arr ZIX
+mrerankP :: forall sh1 sh2 sh a. Storable a
+ => StaticShX sh -> IShX sh2
+ -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a))
+ -> Mixed (sh ++ sh1) (Primitive a) -> Mixed (sh ++ sh2) (Primitive a)
+mrerankP ssh sh2 f (M_Primitive sh arr) =
+ let sh1 = shDropSSX sh ssh
+ in M_Primitive (X.shAppend (shTakeSSX (Proxy @sh1) sh ssh) sh2)
+ (X.rerank ssh (X.staticShapeFrom sh1) (X.staticShapeFrom sh2)
+ (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r)
+ arr)
+
+mrerank :: forall sh1 sh2 sh a. (Storable a, PrimElt a)
+ => StaticShX sh -> IShX sh2
+ -> (Mixed sh1 a -> Mixed sh2 a)
+ -> Mixed (sh ++ sh1) a -> Mixed (sh ++ sh2) a
+mrerank ssh sh2 f (toPrimitive -> arr) =
+ fromPrimitive $ mrerankP ssh sh2 (toPrimitive . f . fromPrimitive) arr
+
mreplicateP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a)
mreplicateP sh x = M_Primitive sh (X.replicate sh x)
@@ -1389,6 +1406,24 @@ rtoList1 = map runScalar . rtoList
runScalar :: Elt a => Ranked 0 a -> a
runScalar arr = rindex arr ZIR
+rrerankP :: forall n1 n2 n a. Storable a
+ => SNat n -> IShR n2
+ -> (Ranked n1 (Primitive a) -> Ranked n2 (Primitive a))
+ -> Ranked (n + n1) (Primitive a) -> Ranked (n + n2) (Primitive a)
+rrerankP sn sh2 f (Ranked arr)
+ | Refl <- lemReplicatePlusApp sn (Proxy @n1) (Proxy @(Nothing @Nat))
+ , Refl <- lemReplicatePlusApp sn (Proxy @n2) (Proxy @(Nothing @Nat))
+ = Ranked (mrerankP (ssxFromSNat sn) (shCvtRX sh2)
+ (\a -> let Ranked r = f (Ranked a) in r)
+ arr)
+
+rrerank :: forall n1 n2 n a. (Storable a, PrimElt a)
+ => SNat n -> IShR n2
+ -> (Ranked n1 a -> Ranked n2 a)
+ -> Ranked (n + n1) a -> Ranked (n + n2) a
+rrerank ssh sh2 f (rtoPrimitive -> arr) =
+ rfromPrimitive $ rrerankP ssh sh2 (rtoPrimitive . f . rfromPrimitive) arr
+
rreplicateP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a)
rreplicateP sh x
| Dict <- lemKnownReplicate (snatFromShR sh)
@@ -1438,6 +1473,12 @@ rcastToShaped (Ranked arr) targetsh
, Refl <- lemRankMapJust targetsh
= mcastToShaped arr targetsh
+rfromPrimitive :: PrimElt a => Ranked n (Primitive a) -> Ranked n a
+rfromPrimitive (Ranked arr) = Ranked (fromPrimitive arr)
+
+rtoPrimitive :: PrimElt a => Ranked n a -> Ranked n (Primitive a)
+rtoPrimitive (Ranked arr) = Ranked (toPrimitive arr)
+
-- ====== API OF SHAPED ARRAYS ====== --
@@ -1619,6 +1660,25 @@ stoList1 = map sunScalar . stoList
sunScalar :: Elt a => Shaped '[] a -> a
sunScalar arr = sindex arr ZIS
+srerankP :: forall sh1 sh2 sh a. Storable a
+ => ShS sh -> ShS sh2
+ -> (Shaped sh1 (Primitive a) -> Shaped sh2 (Primitive a))
+ -> Shaped (sh ++ sh1) (Primitive a) -> Shaped (sh ++ sh2) (Primitive a)
+srerankP sh sh2 f sarr@(Shaped arr)
+ | Refl <- lemCommMapJustApp sh (Proxy @sh1)
+ , Refl <- lemCommMapJustApp sh (Proxy @sh2)
+ = Shaped (mrerankP (X.staticShapeFrom (shTakeSSX (Proxy @(MapJust sh1)) (shCvtSX (sshape sarr)) (X.staticShapeFrom (shCvtSX sh))))
+ (shCvtSX sh2)
+ (\a -> let Shaped r = f (Shaped a) in r)
+ arr)
+
+srerank :: forall sh1 sh2 sh a. (Storable a, PrimElt a)
+ => StaticShX sh -> IShX sh2
+ -> (Mixed sh1 a -> Mixed sh2 a)
+ -> Mixed (sh ++ sh1) a -> Mixed (sh ++ sh2) a
+srerank ssh sh2 f (toPrimitive -> arr) =
+ fromPrimitive $ mrerankP ssh sh2 (toPrimitive . f . fromPrimitive) arr
+
sreplicateP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a)
sreplicateP sh x = Shaped (mreplicateP (shCvtSX sh) x)
@@ -1652,3 +1712,9 @@ stoRanked :: Elt a => Shaped sh a -> Ranked (X.Rank sh) a
stoRanked sarr@(Shaped arr)
| Refl <- lemRankMapJust (sshape sarr)
= mtoRanked arr
+
+sfromPrimitive :: PrimElt a => Shaped sh (Primitive a) -> Shaped sh a
+sfromPrimitive (Shaped arr) = Shaped (fromPrimitive arr)
+
+stoPrimitive :: PrimElt a => Shaped sh a -> Shaped sh (Primitive a)
+stoPrimitive (Shaped arr) = Shaped (toPrimitive arr)