aboutsummaryrefslogtreecommitdiff
path: root/ops/Data/Array
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-11-15 11:12:57 +0100
committerTom Smeding <tom@tomsmeding.com>2025-11-15 11:13:10 +0100
commitad3f44c8b170298e63b8b57ee02cb88fbbd210fc (patch)
tree60d4457b26640cb3dd016d3ac419f654ecf06a8e /ops/Data/Array
parent11db054607e476c68be5681d99d96630642637e6 (diff)
arith: Support Int8 and Int16
Diffstat (limited to 'ops/Data/Array')
-rw-r--r--ops/Data/Array/Strided/Arith/Internal.hs30
-rw-r--r--ops/Data/Array/Strided/Arith/Internal/Lists.hs4
2 files changed, 33 insertions, 1 deletions
diff --git a/ops/Data/Array/Strided/Arith/Internal.hs b/ops/Data/Array/Strided/Arith/Internal.hs
index 5802573..6364802 100644
--- a/ops/Data/Array/Strided/Arith/Internal.hs
+++ b/ops/Data/Array/Strided/Arith/Internal.hs
@@ -714,6 +714,36 @@ class NumElt a where
numEltMaxIndex :: SNat n -> Array n a -> [Int]
numEltDotprodInner :: SNat n -> Array (n + 1) a -> Array (n + 1) a -> Array n a
+instance NumElt Int8 where
+ numEltAdd = addVectorInt8
+ numEltSub = subVectorInt8
+ numEltMul = mulVectorInt8
+ numEltNeg = negVectorInt8
+ numEltAbs = absVectorInt8
+ numEltSignum = signumVectorInt8
+ numEltSum1Inner = sum1VectorInt8
+ numEltProduct1Inner = product1VectorInt8
+ numEltSumFull = sumFullVectorInt8
+ numEltProductFull = productFullVectorInt8
+ numEltMinIndex _ = minindexVectorInt8
+ numEltMaxIndex _ = maxindexVectorInt8
+ numEltDotprodInner = dotprodinnerVectorInt8
+
+instance NumElt Int16 where
+ numEltAdd = addVectorInt16
+ numEltSub = subVectorInt16
+ numEltMul = mulVectorInt16
+ numEltNeg = negVectorInt16
+ numEltAbs = absVectorInt16
+ numEltSignum = signumVectorInt16
+ numEltSum1Inner = sum1VectorInt16
+ numEltProduct1Inner = product1VectorInt16
+ numEltSumFull = sumFullVectorInt16
+ numEltProductFull = productFullVectorInt16
+ numEltMinIndex _ = minindexVectorInt16
+ numEltMaxIndex _ = maxindexVectorInt16
+ numEltDotprodInner = dotprodinnerVectorInt16
+
instance NumElt Int32 where
numEltAdd = addVectorInt32
numEltSub = subVectorInt32
diff --git a/ops/Data/Array/Strided/Arith/Internal/Lists.hs b/ops/Data/Array/Strided/Arith/Internal/Lists.hs
index 910a77c..27204d2 100644
--- a/ops/Data/Array/Strided/Arith/Internal/Lists.hs
+++ b/ops/Data/Array/Strided/Arith/Internal/Lists.hs
@@ -16,7 +16,9 @@ data ArithType = ArithType
intTypesList :: [ArithType]
intTypesList =
- [ArithType ''Int32 "i32"
+ [ArithType ''Int8 "i8"
+ ,ArithType ''Int16 "i16"
+ ,ArithType ''Int32 "i32"
,ArithType ''Int64 "i64"
]