aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-03-13 09:26:20 +0100
committerTom Smeding <tom@tomsmeding.com>2025-03-13 09:28:24 +0100
commita87c80b1fbaa826142605d0846479c94d6ee2bcc (patch)
tree902faed123c7b363170726ca1d9adb2211eca4ac
parent4ec47c712e6809bb7ed839055d1ac008cf500f50 (diff)
Add atan2
-rw-r--r--cbits/arith.c5
-rw-r--r--cbits/arith_lists.h1
-rw-r--r--src/Data/Array/Mixed/Internal/Arith.hs3
-rw-r--r--src/Data/Array/Nested.hs6
-rw-r--r--src/Data/Array/Nested/Internal/Mixed.hs3
-rw-r--r--src/Data/Array/Nested/Internal/Ranked.hs3
-rw-r--r--src/Data/Array/Nested/Internal/Shaped.hs3
7 files changed, 21 insertions, 3 deletions
diff --git a/cbits/arith.c b/cbits/arith.c
index db88588..c984255 100644
--- a/cbits/arith.c
+++ b/cbits/arith.c
@@ -46,6 +46,7 @@ typedef int64_t i64;
#define GEN_POW(x, y) _Generic((x), float: powf, double: pow)(x, y)
#define GEN_LOGBASE(x, y) _Generic((x), float: logf(y) / logf(x), double: log(y) / log(x))
+#define GEN_ATAN2(y, x) _Generic((x), float: atan2f(y, x), double: atan2(y, x))
#define GEN_EXP(x) _Generic((x), float: expf, double: exp)(x)
#define GEN_LOG(x) _Generic((x), float: logf, double: log)(x)
#define GEN_SQRT(x) _Generic((x), float: sqrtf, double: sqrt)(x)
@@ -456,6 +457,7 @@ enum fbinop_tag_t {
case FB_DIV: oxarop_op_fdiv_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \
case FB_POW: oxarop_op_pow_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \
case FB_LOGBASE: oxarop_op_logbase_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \
+ case FB_ATAN2: oxarop_op_atan2_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \
default: wrong_op("fbinary_sv_strided", tag); \
} \
} \
@@ -464,6 +466,7 @@ enum fbinop_tag_t {
case FB_DIV: oxarop_op_fdiv_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \
case FB_POW: oxarop_op_pow_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \
case FB_LOGBASE: oxarop_op_logbase_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \
+ case FB_ATAN2: oxarop_op_atan2_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \
default: wrong_op("fbinary_vs_strided", tag); \
} \
} \
@@ -472,6 +475,7 @@ enum fbinop_tag_t {
case FB_DIV: oxarop_op_fdiv_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \
case FB_POW: oxarop_op_pow_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \
case FB_LOGBASE: oxarop_op_logbase_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \
+ case FB_ATAN2: oxarop_op_atan2_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \
default: wrong_op("fbinary_vv_strided", tag); \
} \
}
@@ -597,6 +601,7 @@ INT_TYPES_XLIST
NONCOMM_OP_STRIDED(fdiv, /, typ) \
PREFIX_BINOP_STRIDED(pow, GEN_POW, typ) \
PREFIX_BINOP_STRIDED(logbase, GEN_LOGBASE, typ) \
+ PREFIX_BINOP_STRIDED(atan2, GEN_ATAN2, typ) \
UNARY_OP_STRIDED(recip, 1.0/, typ) \
UNARY_OP_STRIDED(exp, GEN_EXP, typ) \
UNARY_OP_STRIDED(log, GEN_LOG, typ) \
diff --git a/cbits/arith_lists.h b/cbits/arith_lists.h
index b651b62..432765c 100644
--- a/cbits/arith_lists.h
+++ b/cbits/arith_lists.h
@@ -8,6 +8,7 @@ LIST_IBINOP(IB_REM, 2, rem)
LIST_FBINOP(FB_DIV, 1, /)
LIST_FBINOP(FB_POW, 2, **)
LIST_FBINOP(FB_LOGBASE, 3, logBase)
+LIST_FBINOP(FB_ATAN2, 4, atan2)
LIST_UNOP(UO_NEG, 1,)
LIST_UNOP(UO_ABS, 2,)
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs
index 11cbba6..c940914 100644
--- a/src/Data/Array/Mixed/Internal/Arith.hs
+++ b/src/Data/Array/Mixed/Internal/Arith.hs
@@ -860,6 +860,7 @@ class NumElt a => FloatElt a where
floatEltExpm1 :: SNat n -> RS.Array n a -> RS.Array n a
floatEltLog1pexp :: SNat n -> RS.Array n a -> RS.Array n a
floatEltLog1mexp :: SNat n -> RS.Array n a -> RS.Array n a
+ floatEltAtan2 :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
instance FloatElt Float where
floatEltDiv = divVectorFloat
@@ -885,6 +886,7 @@ instance FloatElt Float where
floatEltExpm1 = expm1VectorFloat
floatEltLog1pexp = log1pexpVectorFloat
floatEltLog1mexp = log1mexpVectorFloat
+ floatEltAtan2 = atan2VectorFloat
instance FloatElt Double where
floatEltDiv = divVectorDouble
@@ -910,3 +912,4 @@ instance FloatElt Double where
floatEltExpm1 = expm1VectorDouble
floatEltLog1pexp = log1pexpVectorDouble
floatEltLog1mexp = log1mexpVectorDouble
+ floatEltAtan2 = atan2VectorDouble
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index 58c0c71..4e5acb4 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -24,7 +24,7 @@ module Data.Array.Nested (
-- ** Additional arithmetic operations
--
-- $integralRealFloat
- rquotArray, rremArray,
+ rquotArray, rremArray, ratan2Array,
-- * Shaped arrays
Shaped(Shaped),
@@ -50,7 +50,7 @@ module Data.Array.Nested (
-- ** Additional arithmetic operations
--
-- $integralRealFloat
- squotArray, sremArray,
+ squotArray, sremArray, satan2Array,
-- * Mixed arrays
Mixed,
@@ -79,7 +79,7 @@ module Data.Array.Nested (
-- ** Additional arithmetic operations
--
-- $integralRealFloat
- mquotArray, mremArray,
+ mquotArray, mremArray, matan2Array,
-- * Array elements
Elt,
diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs
index 7e1f100..80d581e 100644
--- a/src/Data/Array/Nested/Internal/Mixed.hs
+++ b/src/Data/Array/Nested/Internal/Mixed.hs
@@ -269,6 +269,9 @@ mquotArray, mremArray :: (IntElt a, PrimElt a) => Mixed sh a -> Mixed sh a -> Mi
mquotArray = mliftNumElt2 intEltQuot
mremArray = mliftNumElt2 intEltRem
+matan2Array :: (FloatElt a, PrimElt a) => Mixed sh a -> Mixed sh a -> Mixed sh a
+matan2Array = mliftNumElt2 floatEltAtan2
+
-- | Allowable element types in a mixed array, and by extension in a 'Ranked' or
-- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive'
diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs
index 4fb29e0..9493bc6 100644
--- a/src/Data/Array/Nested/Internal/Ranked.hs
+++ b/src/Data/Array/Nested/Internal/Ranked.hs
@@ -240,6 +240,9 @@ rquotArray, rremArray :: (IntElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ra
rquotArray = arithPromoteRanked2 mquotArray
rremArray = arithPromoteRanked2 mremArray
+ratan2Array :: (FloatElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a
+ratan2Array = arithPromoteRanked2 matan2Array
+
remptyArray :: KnownElt a => Ranked 1 a
remptyArray = mtoRanked (memptyArray ZSX)
diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs
index ed616cf..03631b0 100644
--- a/src/Data/Array/Nested/Internal/Shaped.hs
+++ b/src/Data/Array/Nested/Internal/Shaped.hs
@@ -238,6 +238,9 @@ squotArray, sremArray :: (IntElt a, PrimElt a) => Shaped sh a -> Shaped sh a ->
squotArray = arithPromoteShaped2 mquotArray
sremArray = arithPromoteShaped2 mremArray
+satan2Array :: (FloatElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped sh a
+satan2Array = arithPromoteShaped2 matan2Array
+
semptyArray :: KnownElt a => ShS sh -> Shaped (0 : sh) a
semptyArray sh = Shaped (memptyArray (shCvtSX sh))