From ad3f44c8b170298e63b8b57ee02cb88fbbd210fc Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Sat, 15 Nov 2025 11:12:57 +0100 Subject: arith: Support Int8 and Int16 --- cbits/arith.c | 6 +++++- ops/Data/Array/Strided/Arith/Internal.hs | 30 ++++++++++++++++++++++++++ ops/Data/Array/Strided/Arith/Internal/Lists.hs | 4 +++- 3 files changed, 38 insertions(+), 2 deletions(-) diff --git a/cbits/arith.c b/cbits/arith.c index f19b01e..1066463 100644 --- a/cbits/arith.c +++ b/cbits/arith.c @@ -20,6 +20,8 @@ // Shorter names, due to CPP used both in function names and in C types. +typedef int8_t i8; +typedef int16_t i16; typedef int32_t i32; typedef int64_t i64; @@ -248,6 +250,8 @@ void oxarrays_stats_print_all(void) { #define GEN_ABS(x) \ _Generic((x), \ + i8: abs, \ + i16: abs, \ int: abs, \ long: labs, \ long long: llabs, \ @@ -738,7 +742,7 @@ enum redop_tag_t { * Generate all the functions * *****************************************************************************/ -#define INT_TYPES_XLIST X(i32) X(i64) +#define INT_TYPES_XLIST X(i8) X(i16) X(i32) X(i64) #define FLOAT_TYPES_XLIST X(double) X(float) #define NUM_TYPES_XLIST INT_TYPES_XLIST FLOAT_TYPES_XLIST diff --git a/ops/Data/Array/Strided/Arith/Internal.hs b/ops/Data/Array/Strided/Arith/Internal.hs index 5802573..6364802 100644 --- a/ops/Data/Array/Strided/Arith/Internal.hs +++ b/ops/Data/Array/Strided/Arith/Internal.hs @@ -714,6 +714,36 @@ class NumElt a where numEltMaxIndex :: SNat n -> Array n a -> [Int] numEltDotprodInner :: SNat n -> Array (n + 1) a -> Array (n + 1) a -> Array n a +instance NumElt Int8 where + numEltAdd = addVectorInt8 + numEltSub = subVectorInt8 + numEltMul = mulVectorInt8 + numEltNeg = negVectorInt8 + numEltAbs = absVectorInt8 + numEltSignum = signumVectorInt8 + numEltSum1Inner = sum1VectorInt8 + numEltProduct1Inner = product1VectorInt8 + numEltSumFull = sumFullVectorInt8 + numEltProductFull = productFullVectorInt8 + numEltMinIndex _ = minindexVectorInt8 + numEltMaxIndex _ = maxindexVectorInt8 + numEltDotprodInner = dotprodinnerVectorInt8 + +instance NumElt Int16 where + numEltAdd = addVectorInt16 + numEltSub = subVectorInt16 + numEltMul = mulVectorInt16 + numEltNeg = negVectorInt16 + numEltAbs = absVectorInt16 + numEltSignum = signumVectorInt16 + numEltSum1Inner = sum1VectorInt16 + numEltProduct1Inner = product1VectorInt16 + numEltSumFull = sumFullVectorInt16 + numEltProductFull = productFullVectorInt16 + numEltMinIndex _ = minindexVectorInt16 + numEltMaxIndex _ = maxindexVectorInt16 + numEltDotprodInner = dotprodinnerVectorInt16 + instance NumElt Int32 where numEltAdd = addVectorInt32 numEltSub = subVectorInt32 diff --git a/ops/Data/Array/Strided/Arith/Internal/Lists.hs b/ops/Data/Array/Strided/Arith/Internal/Lists.hs index 910a77c..27204d2 100644 --- a/ops/Data/Array/Strided/Arith/Internal/Lists.hs +++ b/ops/Data/Array/Strided/Arith/Internal/Lists.hs @@ -16,7 +16,9 @@ data ArithType = ArithType intTypesList :: [ArithType] intTypesList = - [ArithType ''Int32 "i32" + [ArithType ''Int8 "i8" + ,ArithType ''Int16 "i16" + ,ArithType ''Int32 "i32" ,ArithType ''Int64 "i64" ] -- cgit v1.2.3-70-g09d2