aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-11-15 11:12:57 +0100
committerTom Smeding <tom@tomsmeding.com>2025-11-15 11:13:10 +0100
commitad3f44c8b170298e63b8b57ee02cb88fbbd210fc (patch)
tree60d4457b26640cb3dd016d3ac419f654ecf06a8e
parent11db054607e476c68be5681d99d96630642637e6 (diff)
arith: Support Int8 and Int16
-rw-r--r--cbits/arith.c6
-rw-r--r--ops/Data/Array/Strided/Arith/Internal.hs30
-rw-r--r--ops/Data/Array/Strided/Arith/Internal/Lists.hs4
3 files changed, 38 insertions, 2 deletions
diff --git a/cbits/arith.c b/cbits/arith.c
index f19b01e..1066463 100644
--- a/cbits/arith.c
+++ b/cbits/arith.c
@@ -20,6 +20,8 @@
// Shorter names, due to CPP used both in function names and in C types.
+typedef int8_t i8;
+typedef int16_t i16;
typedef int32_t i32;
typedef int64_t i64;
@@ -248,6 +250,8 @@ void oxarrays_stats_print_all(void) {
#define GEN_ABS(x) \
_Generic((x), \
+ i8: abs, \
+ i16: abs, \
int: abs, \
long: labs, \
long long: llabs, \
@@ -738,7 +742,7 @@ enum redop_tag_t {
* Generate all the functions *
*****************************************************************************/
-#define INT_TYPES_XLIST X(i32) X(i64)
+#define INT_TYPES_XLIST X(i8) X(i16) X(i32) X(i64)
#define FLOAT_TYPES_XLIST X(double) X(float)
#define NUM_TYPES_XLIST INT_TYPES_XLIST FLOAT_TYPES_XLIST
diff --git a/ops/Data/Array/Strided/Arith/Internal.hs b/ops/Data/Array/Strided/Arith/Internal.hs
index 5802573..6364802 100644
--- a/ops/Data/Array/Strided/Arith/Internal.hs
+++ b/ops/Data/Array/Strided/Arith/Internal.hs
@@ -714,6 +714,36 @@ class NumElt a where
numEltMaxIndex :: SNat n -> Array n a -> [Int]
numEltDotprodInner :: SNat n -> Array (n + 1) a -> Array (n + 1) a -> Array n a
+instance NumElt Int8 where
+ numEltAdd = addVectorInt8
+ numEltSub = subVectorInt8
+ numEltMul = mulVectorInt8
+ numEltNeg = negVectorInt8
+ numEltAbs = absVectorInt8
+ numEltSignum = signumVectorInt8
+ numEltSum1Inner = sum1VectorInt8
+ numEltProduct1Inner = product1VectorInt8
+ numEltSumFull = sumFullVectorInt8
+ numEltProductFull = productFullVectorInt8
+ numEltMinIndex _ = minindexVectorInt8
+ numEltMaxIndex _ = maxindexVectorInt8
+ numEltDotprodInner = dotprodinnerVectorInt8
+
+instance NumElt Int16 where
+ numEltAdd = addVectorInt16
+ numEltSub = subVectorInt16
+ numEltMul = mulVectorInt16
+ numEltNeg = negVectorInt16
+ numEltAbs = absVectorInt16
+ numEltSignum = signumVectorInt16
+ numEltSum1Inner = sum1VectorInt16
+ numEltProduct1Inner = product1VectorInt16
+ numEltSumFull = sumFullVectorInt16
+ numEltProductFull = productFullVectorInt16
+ numEltMinIndex _ = minindexVectorInt16
+ numEltMaxIndex _ = maxindexVectorInt16
+ numEltDotprodInner = dotprodinnerVectorInt16
+
instance NumElt Int32 where
numEltAdd = addVectorInt32
numEltSub = subVectorInt32
diff --git a/ops/Data/Array/Strided/Arith/Internal/Lists.hs b/ops/Data/Array/Strided/Arith/Internal/Lists.hs
index 910a77c..27204d2 100644
--- a/ops/Data/Array/Strided/Arith/Internal/Lists.hs
+++ b/ops/Data/Array/Strided/Arith/Internal/Lists.hs
@@ -16,7 +16,9 @@ data ArithType = ArithType
intTypesList :: [ArithType]
intTypesList =
- [ArithType ''Int32 "i32"
+ [ArithType ''Int8 "i8"
+ ,ArithType ''Int16 "i16"
+ ,ArithType ''Int32 "i32"
,ArithType ''Int64 "i64"
]