aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-06-17 12:51:19 +0200
committerTom Smeding <tom@tomsmeding.com>2024-06-17 12:51:19 +0200
commit3d48baae00c066f43fa2205b22f0357f069888f2 (patch)
tree0435abbf829eac178abfd24934e572594529d4f1
parent2ca90987058d14c79cd983ab14ee57949bae2871 (diff)
Generalise more lifting functions
-rw-r--r--src/Data/Array/Nested/Internal/Mixed.hs10
1 files changed, 6 insertions, 4 deletions
diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs
index 647ea82..594383c 100644
--- a/src/Data/Array/Nested/Internal/Mixed.hs
+++ b/src/Data/Array/Nested/Internal/Mixed.hs
@@ -220,12 +220,14 @@ instance (Show a, Storable a) => Show (ShowViaPrimitive sh a) where
. shows (coerce @[Primitive a] @[a] (mtoListLinear parr))
-mliftNumElt1 :: PrimElt a => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) a) -> Mixed sh a -> Mixed sh a
+mliftNumElt1 :: (PrimElt a, PrimElt b)
+ => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) b)
+ -> Mixed sh a -> Mixed sh b
mliftNumElt1 f (toPrimitive -> M_Primitive sh (XArray arr)) = fromPrimitive $ M_Primitive sh (XArray (f (shxRank sh) arr))
-mliftNumElt2 :: PrimElt a
- => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) a -> S.Array (Rank sh) a)
- -> Mixed sh a -> Mixed sh a -> Mixed sh a
+mliftNumElt2 :: (PrimElt a, PrimElt b, PrimElt c)
+ => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) b -> S.Array (Rank sh) c)
+ -> Mixed sh a -> Mixed sh b -> Mixed sh c
mliftNumElt2 f (toPrimitive -> M_Primitive sh1 (XArray arr1)) (toPrimitive -> M_Primitive sh2 (XArray arr2))
| sh1 == sh2 = fromPrimitive $ M_Primitive sh1 (XArray (f (shxRank sh1) arr1 arr2))
| otherwise = error $ "Data.Array.Nested: Shapes unequal in elementwise Num operation: " ++ show sh1 ++ " vs " ++ show sh2