diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-03-13 09:26:20 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-13 09:28:24 +0100 |
commit | a87c80b1fbaa826142605d0846479c94d6ee2bcc (patch) | |
tree | 902faed123c7b363170726ca1d9adb2211eca4ac | |
parent | 4ec47c712e6809bb7ed839055d1ac008cf500f50 (diff) |
Add atan2
-rw-r--r-- | cbits/arith.c | 5 | ||||
-rw-r--r-- | cbits/arith_lists.h | 1 | ||||
-rw-r--r-- | src/Data/Array/Mixed/Internal/Arith.hs | 3 | ||||
-rw-r--r-- | src/Data/Array/Nested.hs | 6 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal/Mixed.hs | 3 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal/Ranked.hs | 3 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal/Shaped.hs | 3 |
7 files changed, 21 insertions, 3 deletions
diff --git a/cbits/arith.c b/cbits/arith.c index db88588..c984255 100644 --- a/cbits/arith.c +++ b/cbits/arith.c @@ -46,6 +46,7 @@ typedef int64_t i64; #define GEN_POW(x, y) _Generic((x), float: powf, double: pow)(x, y) #define GEN_LOGBASE(x, y) _Generic((x), float: logf(y) / logf(x), double: log(y) / log(x)) +#define GEN_ATAN2(y, x) _Generic((x), float: atan2f(y, x), double: atan2(y, x)) #define GEN_EXP(x) _Generic((x), float: expf, double: exp)(x) #define GEN_LOG(x) _Generic((x), float: logf, double: log)(x) #define GEN_SQRT(x) _Generic((x), float: sqrtf, double: sqrt)(x) @@ -456,6 +457,7 @@ enum fbinop_tag_t { case FB_DIV: oxarop_op_fdiv_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \ case FB_POW: oxarop_op_pow_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \ case FB_LOGBASE: oxarop_op_logbase_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \ + case FB_ATAN2: oxarop_op_atan2_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \ default: wrong_op("fbinary_sv_strided", tag); \ } \ } \ @@ -464,6 +466,7 @@ enum fbinop_tag_t { case FB_DIV: oxarop_op_fdiv_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \ case FB_POW: oxarop_op_pow_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \ case FB_LOGBASE: oxarop_op_logbase_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \ + case FB_ATAN2: oxarop_op_atan2_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \ default: wrong_op("fbinary_vs_strided", tag); \ } \ } \ @@ -472,6 +475,7 @@ enum fbinop_tag_t { case FB_DIV: oxarop_op_fdiv_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \ case FB_POW: oxarop_op_pow_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \ case FB_LOGBASE: oxarop_op_logbase_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \ + case FB_ATAN2: oxarop_op_atan2_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \ default: wrong_op("fbinary_vv_strided", tag); \ } \ } @@ -597,6 +601,7 @@ INT_TYPES_XLIST NONCOMM_OP_STRIDED(fdiv, /, typ) \ PREFIX_BINOP_STRIDED(pow, GEN_POW, typ) \ PREFIX_BINOP_STRIDED(logbase, GEN_LOGBASE, typ) \ + PREFIX_BINOP_STRIDED(atan2, GEN_ATAN2, typ) \ UNARY_OP_STRIDED(recip, 1.0/, typ) \ UNARY_OP_STRIDED(exp, GEN_EXP, typ) \ UNARY_OP_STRIDED(log, GEN_LOG, typ) \ diff --git a/cbits/arith_lists.h b/cbits/arith_lists.h index b651b62..432765c 100644 --- a/cbits/arith_lists.h +++ b/cbits/arith_lists.h @@ -8,6 +8,7 @@ LIST_IBINOP(IB_REM, 2, rem) LIST_FBINOP(FB_DIV, 1, /) LIST_FBINOP(FB_POW, 2, **) LIST_FBINOP(FB_LOGBASE, 3, logBase) +LIST_FBINOP(FB_ATAN2, 4, atan2) LIST_UNOP(UO_NEG, 1,) LIST_UNOP(UO_ABS, 2,) diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs index 11cbba6..c940914 100644 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -860,6 +860,7 @@ class NumElt a => FloatElt a where floatEltExpm1 :: SNat n -> RS.Array n a -> RS.Array n a floatEltLog1pexp :: SNat n -> RS.Array n a -> RS.Array n a floatEltLog1mexp :: SNat n -> RS.Array n a -> RS.Array n a + floatEltAtan2 :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a instance FloatElt Float where floatEltDiv = divVectorFloat @@ -885,6 +886,7 @@ instance FloatElt Float where floatEltExpm1 = expm1VectorFloat floatEltLog1pexp = log1pexpVectorFloat floatEltLog1mexp = log1mexpVectorFloat + floatEltAtan2 = atan2VectorFloat instance FloatElt Double where floatEltDiv = divVectorDouble @@ -910,3 +912,4 @@ instance FloatElt Double where floatEltExpm1 = expm1VectorDouble floatEltLog1pexp = log1pexpVectorDouble floatEltLog1mexp = log1mexpVectorDouble + floatEltAtan2 = atan2VectorDouble diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index 58c0c71..4e5acb4 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -24,7 +24,7 @@ module Data.Array.Nested ( -- ** Additional arithmetic operations -- -- $integralRealFloat - rquotArray, rremArray, + rquotArray, rremArray, ratan2Array, -- * Shaped arrays Shaped(Shaped), @@ -50,7 +50,7 @@ module Data.Array.Nested ( -- ** Additional arithmetic operations -- -- $integralRealFloat - squotArray, sremArray, + squotArray, sremArray, satan2Array, -- * Mixed arrays Mixed, @@ -79,7 +79,7 @@ module Data.Array.Nested ( -- ** Additional arithmetic operations -- -- $integralRealFloat - mquotArray, mremArray, + mquotArray, mremArray, matan2Array, -- * Array elements Elt, diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs index 7e1f100..80d581e 100644 --- a/src/Data/Array/Nested/Internal/Mixed.hs +++ b/src/Data/Array/Nested/Internal/Mixed.hs @@ -269,6 +269,9 @@ mquotArray, mremArray :: (IntElt a, PrimElt a) => Mixed sh a -> Mixed sh a -> Mi mquotArray = mliftNumElt2 intEltQuot mremArray = mliftNumElt2 intEltRem +matan2Array :: (FloatElt a, PrimElt a) => Mixed sh a -> Mixed sh a -> Mixed sh a +matan2Array = mliftNumElt2 floatEltAtan2 + -- | Allowable element types in a mixed array, and by extension in a 'Ranked' or -- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive' diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs index 4fb29e0..9493bc6 100644 --- a/src/Data/Array/Nested/Internal/Ranked.hs +++ b/src/Data/Array/Nested/Internal/Ranked.hs @@ -240,6 +240,9 @@ rquotArray, rremArray :: (IntElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ra rquotArray = arithPromoteRanked2 mquotArray rremArray = arithPromoteRanked2 mremArray +ratan2Array :: (FloatElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a +ratan2Array = arithPromoteRanked2 matan2Array + remptyArray :: KnownElt a => Ranked 1 a remptyArray = mtoRanked (memptyArray ZSX) diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs index ed616cf..03631b0 100644 --- a/src/Data/Array/Nested/Internal/Shaped.hs +++ b/src/Data/Array/Nested/Internal/Shaped.hs @@ -238,6 +238,9 @@ squotArray, sremArray :: (IntElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> squotArray = arithPromoteShaped2 mquotArray sremArray = arithPromoteShaped2 mremArray +satan2Array :: (FloatElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped sh a +satan2Array = arithPromoteShaped2 matan2Array + semptyArray :: KnownElt a => ShS sh -> Shaped (0 : sh) a semptyArray sh = Shaped (memptyArray (shCvtSX sh)) |