From 3d48baae00c066f43fa2205b22f0357f069888f2 Mon Sep 17 00:00:00 2001
From: Tom Smeding <tom@tomsmeding.com>
Date: Mon, 17 Jun 2024 12:51:19 +0200
Subject: Generalise more lifting functions

---
 src/Data/Array/Nested/Internal/Mixed.hs | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

(limited to 'src/Data')

diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs
index 647ea82..594383c 100644
--- a/src/Data/Array/Nested/Internal/Mixed.hs
+++ b/src/Data/Array/Nested/Internal/Mixed.hs
@@ -220,12 +220,14 @@ instance (Show a, Storable a) => Show (ShowViaPrimitive sh a) where
       . shows (coerce @[Primitive a] @[a] (mtoListLinear parr))
 
 
-mliftNumElt1 :: PrimElt a => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) a) -> Mixed sh a -> Mixed sh a
+mliftNumElt1 :: (PrimElt a, PrimElt b)
+             => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) b)
+             -> Mixed sh a -> Mixed sh b
 mliftNumElt1 f (toPrimitive -> M_Primitive sh (XArray arr)) = fromPrimitive $ M_Primitive sh (XArray (f (shxRank sh) arr))
 
-mliftNumElt2 :: PrimElt a
-             => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) a -> S.Array (Rank sh) a)
-             -> Mixed sh a -> Mixed sh a -> Mixed sh a
+mliftNumElt2 :: (PrimElt a, PrimElt b, PrimElt c)
+             => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) b -> S.Array (Rank sh) c)
+             -> Mixed sh a -> Mixed sh b -> Mixed sh c
 mliftNumElt2 f (toPrimitive -> M_Primitive sh1 (XArray arr1)) (toPrimitive -> M_Primitive sh2 (XArray arr2))
   | sh1 == sh2 = fromPrimitive $ M_Primitive sh1 (XArray (f (shxRank sh1) arr1 arr2))
   | otherwise = error $ "Data.Array.Nested: Shapes unequal in elementwise Num operation: " ++ show sh1 ++ " vs " ++ show sh2
-- 
cgit v1.2.3-70-g09d2