From dc9db431fea6bee2aa6533b9df7dee44c002f252 Mon Sep 17 00:00:00 2001
From: Tom Smeding <tom@tomsmeding.com>
Date: Mon, 20 May 2024 23:00:57 +0200
Subject: Some missing operations (rlift2, slift2)

---
 src/Data/Array/Nested.hs          |  5 +++--
 src/Data/Array/Nested/Internal.hs | 14 ++++++++++++++
 2 files changed, 17 insertions(+), 2 deletions(-)

diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index 65f2b18..8d9601e 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -12,7 +12,7 @@ module Data.Array.Nested (
   rreplicate, rfromListOuter, rfromList1, rtoListOuter, rtoList1,
   rslice, rrev1, rreshape,
   -- ** Lifting orthotope operations to 'Ranked' arrays
-  rlift,
+  rlift, rlift2,
   -- ** Conversions
   rasXArrayPrim, rfromXArrayPrim,
   rcastToShaped,
@@ -29,7 +29,7 @@ module Data.Array.Nested (
   sreplicate, sfromListOuter, sfromList1, stoListOuter, stoList1,
   sslice, srev1, sreshape,
   -- ** Lifting orthotope operations to 'Shaped' arrays
-  slift,
+  slift, slift2,
   -- ** Conversions
   sasXArrayPrim, sfromXArrayPrim,
   stoRanked,
@@ -38,6 +38,7 @@ module Data.Array.Nested (
   Mixed,
   IxX(..), IIxX,
   KnownShX(..), StaticShX(..),
+  -- TODO: missing msumOuter1?
   mshape, mindex, mindexPartial, mgenerate,
   mtranspose, mappend, mscalar, mfromVector, mtoVector, munScalar,
   mrerank,
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index ab67dcc..e402f1e 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -1348,6 +1348,13 @@ rlift :: forall n1 n2 a. Elt a
       -> Ranked n1 a -> Ranked n2 a
 rlift sn2 f (Ranked arr) = Ranked (mlift (ssxFromSNat sn2) f arr)
 
+-- | See the documentation of 'mlift2'.
+rlift2 :: forall n1 n2 n3 a. Elt a
+       => SNat n3
+       -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b -> XArray (Replicate n3 Nothing ++ sh') b)
+       -> Ranked n1 a -> Ranked n2 a -> Ranked n3 a
+rlift2 sn3 f (Ranked arr1) (Ranked arr2) = Ranked (mlift2 (ssxFromSNat sn3) f arr1 arr2)
+
 rsumOuter1P :: forall n a.
                (Storable a, Num a)
             => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a)
@@ -1593,6 +1600,13 @@ slift :: forall sh1 sh2 a. Elt a
       -> Shaped sh1 a -> Shaped sh2 a
 slift sh2 f (Shaped arr) = Shaped (mlift (X.staticShapeFrom (shCvtSX sh2)) f arr)
 
+-- | See the documentation of 'mlift'.
+slift2 :: forall sh1 sh2 sh3 a. Elt a
+       => ShS sh3
+       -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b -> XArray (MapJust sh3 ++ sh') b)
+       -> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a
+slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (X.staticShapeFrom (shCvtSX sh3)) f arr1 arr2)
+
 ssumOuter1P :: forall sh n a. (Storable a, Num a)
             => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a)
 ssumOuter1P (Shaped (M_Primitive (SKnown sn :$% sh) arr)) =
-- 
cgit v1.2.3-70-g09d2