diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-03-13 09:26:20 +0100 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-13 09:28:24 +0100 | 
| commit | a87c80b1fbaa826142605d0846479c94d6ee2bcc (patch) | |
| tree | 902faed123c7b363170726ca1d9adb2211eca4ac | |
| parent | 4ec47c712e6809bb7ed839055d1ac008cf500f50 (diff) | |
Add atan2
| -rw-r--r-- | cbits/arith.c | 5 | ||||
| -rw-r--r-- | cbits/arith_lists.h | 1 | ||||
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith.hs | 3 | ||||
| -rw-r--r-- | src/Data/Array/Nested.hs | 6 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Mixed.hs | 3 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Ranked.hs | 3 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Shaped.hs | 3 | 
7 files changed, 21 insertions, 3 deletions
| diff --git a/cbits/arith.c b/cbits/arith.c index db88588..c984255 100644 --- a/cbits/arith.c +++ b/cbits/arith.c @@ -46,6 +46,7 @@ typedef int64_t i64;  #define GEN_POW(x, y) _Generic((x), float: powf, double: pow)(x, y)  #define GEN_LOGBASE(x, y) _Generic((x), float: logf(y) / logf(x), double: log(y) / log(x)) +#define GEN_ATAN2(y, x) _Generic((x), float: atan2f(y, x), double: atan2(y, x))  #define GEN_EXP(x) _Generic((x), float: expf, double: exp)(x)  #define GEN_LOG(x) _Generic((x), float: logf, double: log)(x)  #define GEN_SQRT(x) _Generic((x), float: sqrtf, double: sqrt)(x) @@ -456,6 +457,7 @@ enum fbinop_tag_t {        case FB_DIV: oxarop_op_fdiv_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \        case FB_POW: oxarop_op_pow_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \        case FB_LOGBASE: oxarop_op_logbase_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \ +      case FB_ATAN2: oxarop_op_atan2_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \        default: wrong_op("fbinary_sv_strided", tag); \      } \    } \ @@ -464,6 +466,7 @@ enum fbinop_tag_t {        case FB_DIV: oxarop_op_fdiv_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \        case FB_POW: oxarop_op_pow_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \        case FB_LOGBASE: oxarop_op_logbase_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \ +      case FB_ATAN2: oxarop_op_atan2_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \        default: wrong_op("fbinary_vs_strided", tag); \      } \    } \ @@ -472,6 +475,7 @@ enum fbinop_tag_t {        case FB_DIV: oxarop_op_fdiv_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \        case FB_POW: oxarop_op_pow_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \        case FB_LOGBASE: oxarop_op_logbase_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \ +      case FB_ATAN2: oxarop_op_atan2_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \        default: wrong_op("fbinary_vv_strided", tag); \      } \    } @@ -597,6 +601,7 @@ INT_TYPES_XLIST    NONCOMM_OP_STRIDED(fdiv, /, typ) \    PREFIX_BINOP_STRIDED(pow, GEN_POW, typ) \    PREFIX_BINOP_STRIDED(logbase, GEN_LOGBASE, typ) \ +  PREFIX_BINOP_STRIDED(atan2, GEN_ATAN2, typ) \    UNARY_OP_STRIDED(recip, 1.0/, typ) \    UNARY_OP_STRIDED(exp, GEN_EXP, typ) \    UNARY_OP_STRIDED(log, GEN_LOG, typ) \ diff --git a/cbits/arith_lists.h b/cbits/arith_lists.h index b651b62..432765c 100644 --- a/cbits/arith_lists.h +++ b/cbits/arith_lists.h @@ -8,6 +8,7 @@ LIST_IBINOP(IB_REM, 2, rem)  LIST_FBINOP(FB_DIV, 1, /)  LIST_FBINOP(FB_POW, 2, **)  LIST_FBINOP(FB_LOGBASE, 3, logBase) +LIST_FBINOP(FB_ATAN2, 4, atan2)  LIST_UNOP(UO_NEG, 1,)  LIST_UNOP(UO_ABS, 2,) diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs index 11cbba6..c940914 100644 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -860,6 +860,7 @@ class NumElt a => FloatElt a where    floatEltExpm1 :: SNat n -> RS.Array n a -> RS.Array n a    floatEltLog1pexp :: SNat n -> RS.Array n a -> RS.Array n a    floatEltLog1mexp :: SNat n -> RS.Array n a -> RS.Array n a +  floatEltAtan2 :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a  instance FloatElt Float where    floatEltDiv = divVectorFloat @@ -885,6 +886,7 @@ instance FloatElt Float where    floatEltExpm1 = expm1VectorFloat    floatEltLog1pexp = log1pexpVectorFloat    floatEltLog1mexp = log1mexpVectorFloat +  floatEltAtan2 = atan2VectorFloat  instance FloatElt Double where    floatEltDiv = divVectorDouble @@ -910,3 +912,4 @@ instance FloatElt Double where    floatEltExpm1 = expm1VectorDouble    floatEltLog1pexp = log1pexpVectorDouble    floatEltLog1mexp = log1mexpVectorDouble +  floatEltAtan2 = atan2VectorDouble diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index 58c0c71..4e5acb4 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -24,7 +24,7 @@ module Data.Array.Nested (    -- ** Additional arithmetic operations    --    -- $integralRealFloat -  rquotArray, rremArray, +  rquotArray, rremArray, ratan2Array,    -- * Shaped arrays    Shaped(Shaped), @@ -50,7 +50,7 @@ module Data.Array.Nested (    -- ** Additional arithmetic operations    --    -- $integralRealFloat -  squotArray, sremArray, +  squotArray, sremArray, satan2Array,    -- * Mixed arrays    Mixed, @@ -79,7 +79,7 @@ module Data.Array.Nested (    -- ** Additional arithmetic operations    --    -- $integralRealFloat -  mquotArray, mremArray, +  mquotArray, mremArray, matan2Array,    -- * Array elements    Elt, diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs index 7e1f100..80d581e 100644 --- a/src/Data/Array/Nested/Internal/Mixed.hs +++ b/src/Data/Array/Nested/Internal/Mixed.hs @@ -269,6 +269,9 @@ mquotArray, mremArray :: (IntElt a, PrimElt a) => Mixed sh a -> Mixed sh a -> Mi  mquotArray = mliftNumElt2 intEltQuot  mremArray = mliftNumElt2 intEltRem +matan2Array :: (FloatElt a, PrimElt a) => Mixed sh a -> Mixed sh a -> Mixed sh a +matan2Array = mliftNumElt2 floatEltAtan2 +  -- | Allowable element types in a mixed array, and by extension in a 'Ranked' or  -- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive' diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs index 4fb29e0..9493bc6 100644 --- a/src/Data/Array/Nested/Internal/Ranked.hs +++ b/src/Data/Array/Nested/Internal/Ranked.hs @@ -240,6 +240,9 @@ rquotArray, rremArray :: (IntElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ra  rquotArray = arithPromoteRanked2 mquotArray  rremArray = arithPromoteRanked2 mremArray +ratan2Array :: (FloatElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a +ratan2Array = arithPromoteRanked2 matan2Array +  remptyArray :: KnownElt a => Ranked 1 a  remptyArray = mtoRanked (memptyArray ZSX) diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs index ed616cf..03631b0 100644 --- a/src/Data/Array/Nested/Internal/Shaped.hs +++ b/src/Data/Array/Nested/Internal/Shaped.hs @@ -238,6 +238,9 @@ squotArray, sremArray :: (IntElt a, PrimElt a) => Shaped sh a -> Shaped sh a ->  squotArray = arithPromoteShaped2 mquotArray  sremArray = arithPromoteShaped2 mremArray +satan2Array :: (FloatElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped sh a +satan2Array = arithPromoteShaped2 matan2Array +  semptyArray :: KnownElt a => ShS sh -> Shaped (0 : sh) a  semptyArray sh = Shaped (memptyArray (shCvtSX sh)) | 
