From 52c0237fbdbc3c99ee6565ba18250360a330fb8b Mon Sep 17 00:00:00 2001
From: Tom Smeding <tom@tomsmeding.com>
Date: Mon, 20 May 2024 17:21:21 +0200
Subject: Rerank on primitive arrays

---
 src/Data/Array/Mixed.hs           | 10 ++++++
 src/Data/Array/Nested.hs          |  3 ++
 src/Data/Array/Nested/Internal.hs | 68 ++++++++++++++++++++++++++++++++++++++-
 3 files changed, 80 insertions(+), 1 deletion(-)

diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
index 33c0dd6..2f23903 100644
--- a/src/Data/Array/Mixed.hs
+++ b/src/Data/Array/Mixed.hs
@@ -320,12 +320,22 @@ listxDrop long (_ ::% short) = case long of _ ::% long' -> listxDrop long' short
 ixDrop :: forall sh sh' i. IxX (sh ++ sh') i -> IxX sh i -> IxX sh' i
 ixDrop = coerce (listxDrop @(Const i) @(Const i))
 
+shDropSSX :: forall sh sh' i. ShX (sh ++ sh') i -> StaticShX sh -> ShX sh' i
+shDropSSX = coerce (listxDrop @(SMayNat i SNat) @(SMayNat () SNat))
+
 shDropIx :: forall sh sh' i j. ShX (sh ++ sh') i -> IxX sh j -> ShX sh' i
 shDropIx = coerce (listxDrop @(SMayNat i SNat) @(Const j))
 
 shDropSh :: forall sh sh' i. ShX (sh ++ sh') i -> ShX sh i -> ShX sh' i
 shDropSh = coerce (listxDrop @(SMayNat i SNat) @(SMayNat i SNat))
 
+shTakeSSX :: forall sh sh' i. Proxy sh' -> ShX (sh ++ sh') i -> StaticShX sh -> ShX sh i
+shTakeSSX _ = flip go
+  where
+    go :: StaticShX sh1 -> ShX (sh1 ++ sh') i -> ShX sh1 i
+    go ZKX _ = ZSX
+    go (_ :!% ssh1) (n :$% sh) = n :$% go ssh1 sh
+
 ssxDropIx :: forall sh sh' i. StaticShX (sh ++ sh') -> IxX sh i -> StaticShX sh'
 ssxDropIx = coerce (listxDrop @(SMayNat () SNat) @(Const i))
 
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index 51754d0..2208349 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -8,6 +8,7 @@ module Data.Array.Nested (
   ShR(.., ZSR, (:$:)),
   rshape, rindex, rindexPartial, rgenerate, rsumOuter1,
   rtranspose, rappend, rscalar, rfromVector, rtoVector, runScalar,
+  rrerank,
   rreplicate, rfromList, rfromList1, rtoList, rtoList1,
   rslice, rrev1, rreshape,
   -- ** Lifting orthotope operations to 'Ranked' arrays
@@ -23,6 +24,7 @@ module Data.Array.Nested (
   ShS(.., ZSS, (:$$)), KnownShS(..),
   sshape, sindex, sindexPartial, sgenerate, ssumOuter1,
   stranspose, sappend, sscalar, sfromVector, stoVector, sunScalar,
+  srerank,
   sreplicate, sfromList, sfromList1, stoList, stoList1,
   sslice, srev1, sreshape,
   -- ** Lifting orthotope operations to 'Shaped' arrays
@@ -36,6 +38,7 @@ module Data.Array.Nested (
   IxX(..), IIxX,
   KnownShX(..), StaticShX(..),
   mgenerate, mtranspose, mappend, mfromVector, mtoVector, munScalar,
+  mrerank,
   mreplicate, mfromList, mtoList, mslice, mrev1, mreshape,
   -- ** Conversions
   masXArrayPrim, mfromXArrayPrim,
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index 99c4a46..badb910 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -28,7 +28,6 @@
 
 {-|
 TODO:
-* Write `rerank`
 * Write `rconst :: OR.Array n a -> Ranked n a`
 
 -}
@@ -900,6 +899,24 @@ mtoList = map munScalar . mtoList1
 munScalar :: Elt a => Mixed '[] a -> a
 munScalar arr = mindex arr ZIX
 
+mrerankP :: forall sh1 sh2 sh a. Storable a
+         => StaticShX sh -> IShX sh2
+         -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a))
+         -> Mixed (sh ++ sh1) (Primitive a) -> Mixed (sh ++ sh2) (Primitive a)
+mrerankP ssh sh2 f (M_Primitive sh arr) =
+  let sh1 = shDropSSX sh ssh
+  in M_Primitive (X.shAppend (shTakeSSX (Proxy @sh1) sh ssh) sh2)
+                 (X.rerank ssh (X.staticShapeFrom sh1) (X.staticShapeFrom sh2)
+                           (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r)
+                           arr)
+
+mrerank :: forall sh1 sh2 sh a. (Storable a, PrimElt a)
+        => StaticShX sh -> IShX sh2
+        -> (Mixed sh1 a -> Mixed sh2 a)
+        -> Mixed (sh ++ sh1) a -> Mixed (sh ++ sh2) a
+mrerank ssh sh2 f (toPrimitive -> arr) =
+  fromPrimitive $ mrerankP ssh sh2 (toPrimitive . f . fromPrimitive) arr
+
 mreplicateP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a)
 mreplicateP sh x = M_Primitive sh (X.replicate sh x)
 
@@ -1389,6 +1406,24 @@ rtoList1 = map runScalar . rtoList
 runScalar :: Elt a => Ranked 0 a -> a
 runScalar arr = rindex arr ZIR
 
+rrerankP :: forall n1 n2 n a. Storable a
+         => SNat n -> IShR n2
+         -> (Ranked n1 (Primitive a) -> Ranked n2 (Primitive a))
+         -> Ranked (n + n1) (Primitive a) -> Ranked (n + n2) (Primitive a)
+rrerankP sn sh2 f (Ranked arr)
+  | Refl <- lemReplicatePlusApp sn (Proxy @n1) (Proxy @(Nothing @Nat))
+  , Refl <- lemReplicatePlusApp sn (Proxy @n2) (Proxy @(Nothing @Nat))
+  = Ranked (mrerankP (ssxFromSNat sn) (shCvtRX sh2)
+                     (\a -> let Ranked r = f (Ranked a) in r)
+                     arr)
+
+rrerank :: forall n1 n2 n a. (Storable a, PrimElt a)
+         => SNat n -> IShR n2
+         -> (Ranked n1 a -> Ranked n2 a)
+         -> Ranked (n + n1) a -> Ranked (n + n2) a
+rrerank ssh sh2 f (rtoPrimitive -> arr) =
+  rfromPrimitive $ rrerankP ssh sh2 (rtoPrimitive . f . rfromPrimitive) arr
+
 rreplicateP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a)
 rreplicateP sh x
   | Dict <- lemKnownReplicate (snatFromShR sh)
@@ -1438,6 +1473,12 @@ rcastToShaped (Ranked arr) targetsh
   , Refl <- lemRankMapJust targetsh
   = mcastToShaped arr targetsh
 
+rfromPrimitive :: PrimElt a => Ranked n (Primitive a) -> Ranked n a
+rfromPrimitive (Ranked arr) = Ranked (fromPrimitive arr)
+
+rtoPrimitive :: PrimElt a => Ranked n a -> Ranked n (Primitive a)
+rtoPrimitive (Ranked arr) = Ranked (toPrimitive arr)
+
 
 -- ====== API OF SHAPED ARRAYS ====== --
 
@@ -1619,6 +1660,25 @@ stoList1 = map sunScalar . stoList
 sunScalar :: Elt a => Shaped '[] a -> a
 sunScalar arr = sindex arr ZIS
 
+srerankP :: forall sh1 sh2 sh a. Storable a
+         => ShS sh -> ShS sh2
+         -> (Shaped sh1 (Primitive a) -> Shaped sh2 (Primitive a))
+         -> Shaped (sh ++ sh1) (Primitive a) -> Shaped (sh ++ sh2) (Primitive a)
+srerankP sh sh2 f sarr@(Shaped arr)
+  | Refl <- lemCommMapJustApp sh (Proxy @sh1)
+  , Refl <- lemCommMapJustApp sh (Proxy @sh2)
+  = Shaped (mrerankP (X.staticShapeFrom (shTakeSSX (Proxy @(MapJust sh1)) (shCvtSX (sshape sarr)) (X.staticShapeFrom (shCvtSX sh))))
+                     (shCvtSX sh2)
+                     (\a -> let Shaped r = f (Shaped a) in r)
+                     arr)
+
+srerank :: forall sh1 sh2 sh a. (Storable a, PrimElt a)
+        => StaticShX sh -> IShX sh2
+        -> (Mixed sh1 a -> Mixed sh2 a)
+        -> Mixed (sh ++ sh1) a -> Mixed (sh ++ sh2) a
+srerank ssh sh2 f (toPrimitive -> arr) =
+  fromPrimitive $ mrerankP ssh sh2 (toPrimitive . f . fromPrimitive) arr
+
 sreplicateP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a)
 sreplicateP sh x = Shaped (mreplicateP (shCvtSX sh) x)
 
@@ -1652,3 +1712,9 @@ stoRanked :: Elt a => Shaped sh a -> Ranked (X.Rank sh) a
 stoRanked sarr@(Shaped arr)
   | Refl <- lemRankMapJust (sshape sarr)
   = mtoRanked arr
+
+sfromPrimitive :: PrimElt a => Shaped sh (Primitive a) -> Shaped sh a
+sfromPrimitive (Shaped arr) = Shaped (fromPrimitive arr)
+
+stoPrimitive :: PrimElt a => Shaped sh a -> Shaped sh (Primitive a)
+stoPrimitive (Shaped arr) = Shaped (toPrimitive arr)
-- 
cgit v1.2.3-70-g09d2