From 3d48baae00c066f43fa2205b22f0357f069888f2 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Mon, 17 Jun 2024 12:51:19 +0200 Subject: Generalise more lifting functions --- src/Data/Array/Nested/Internal/Mixed.hs | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) (limited to 'src/Data') diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs index 647ea82..594383c 100644 --- a/src/Data/Array/Nested/Internal/Mixed.hs +++ b/src/Data/Array/Nested/Internal/Mixed.hs @@ -220,12 +220,14 @@ instance (Show a, Storable a) => Show (ShowViaPrimitive sh a) where . shows (coerce @[Primitive a] @[a] (mtoListLinear parr)) -mliftNumElt1 :: PrimElt a => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) a) -> Mixed sh a -> Mixed sh a +mliftNumElt1 :: (PrimElt a, PrimElt b) + => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) b) + -> Mixed sh a -> Mixed sh b mliftNumElt1 f (toPrimitive -> M_Primitive sh (XArray arr)) = fromPrimitive $ M_Primitive sh (XArray (f (shxRank sh) arr)) -mliftNumElt2 :: PrimElt a - => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) a -> S.Array (Rank sh) a) - -> Mixed sh a -> Mixed sh a -> Mixed sh a +mliftNumElt2 :: (PrimElt a, PrimElt b, PrimElt c) + => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) b -> S.Array (Rank sh) c) + -> Mixed sh a -> Mixed sh b -> Mixed sh c mliftNumElt2 f (toPrimitive -> M_Primitive sh1 (XArray arr1)) (toPrimitive -> M_Primitive sh2 (XArray arr2)) | sh1 == sh2 = fromPrimitive $ M_Primitive sh1 (XArray (f (shxRank sh1) arr1 arr2)) | otherwise = error $ "Data.Array.Nested: Shapes unequal in elementwise Num operation: " ++ show sh1 ++ " vs " ++ show sh2 -- cgit v1.2.3-70-g09d2