aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--cbits/arith.c114
-rw-r--r--cbits/arith_lists.h10
-rw-r--r--ox-arrays.cabal1
-rw-r--r--src/Data/Array/Nested/Internal/Arith.hs77
-rw-r--r--src/Data/Array/Nested/Internal/Arith/Foreign.hs35
-rw-r--r--src/Data/Array/Nested/Internal/Arith/Lists.hs58
-rw-r--r--src/Data/Array/Nested/Internal/Arith/Lists/TH.hs78
7 files changed, 287 insertions, 86 deletions
diff --git a/cbits/arith.c b/cbits/arith.c
index 002910c..65cdb41 100644
--- a/cbits/arith.c
+++ b/cbits/arith.c
@@ -1,28 +1,42 @@
+#include <stdio.h>
#include <stdint.h>
#include <stdlib.h>
#include <stdbool.h>
#include <string.h>
#include <math.h>
+// These are the wrapper macros used in arith_lists.h. Preset them to empty to
+// avoid having to touch macros unrelated to the particular operation set below.
+#define LIST_BINOP(name, id, hsop)
+#define LIST_UNOP(name, id, _)
+#define LIST_REDOP(name, id, _)
+
+
+// Shorter names, due to CPP used both in function names and in C types.
typedef int32_t i32;
typedef int64_t i64;
+
+/*****************************************************************************
+ * Kernel functions *
+ *****************************************************************************/
+
#define COMM_OP(name, op, typ) \
- void oxarop_ ## name ## _ ## typ ## _sv(i64 n, typ *out, typ x, const typ *y) { \
+ static void oxarop_op_ ## name ## _ ## typ ## _sv(i64 n, typ *out, typ x, const typ *y) { \
for (i64 i = 0; i < n; i++) out[i] = x op y[i]; \
} \
- void oxarop_ ## name ## _ ## typ ## _vv(i64 n, typ *out, const typ *x, const typ *y) { \
+ static void oxarop_op_ ## name ## _ ## typ ## _vv(i64 n, typ *out, const typ *x, const typ *y) { \
for (i64 i = 0; i < n; i++) out[i] = x[i] op y[i]; \
}
#define NONCOMM_OP(name, op, typ) \
COMM_OP(name, op, typ) \
- void oxarop_ ## name ## _ ## typ ## _vs(i64 n, typ *out, const typ *x, typ y) { \
+ static void oxarop_op_ ## name ## _ ## typ ## _vs(i64 n, typ *out, const typ *x, typ y) { \
for (i64 i = 0; i < n; i++) out[i] = x[i] op y; \
}
#define UNARY_OP(name, op, typ) \
- void oxarop_ ## name ## _ ## typ(i64 n, typ *out, const typ *x) { \
+ static void oxarop_op_ ## name ## _ ## typ(i64 n, typ *out, const typ *x) { \
for (i64 i = 0; i < n; i++) out[i] = op(x[i]); \
}
@@ -68,7 +82,7 @@ typedef int64_t i64;
// Reduces along the innermost dimension.
// 'out' will be filled densely in linearisation order.
#define REDUCE1_OP(name, op, typ) \
- void oxarop_ ## name ## _ ## typ(i64 rank, const i64 *shape, const i64 *strides, typ *out, const typ *arr) { \
+ static void oxarop_op_ ## name ## _ ## typ(i64 rank, const i64 *shape, const i64 *strides, typ *out, const typ *arr) { \
if (strides[rank - 1] == 1) { \
TARRAY_WALK_NOINNER(again1, rank, shape, strides, { \
typ accum = arr[arrlinidx]; \
@@ -88,6 +102,91 @@ typedef int64_t i64;
} \
}
+
+/*****************************************************************************
+ * Entry point functions *
+ *****************************************************************************/
+
+// Reports an invalid operation tag on stderr and then aborts the process.
+// 'name' is inserted into the message to identify which entry point family
+// received the bad tag; 'tag' is the offending value. Never returns.
+__attribute__((noreturn, cold))
+static void wrong_op(const char *name, int tag) {
+ fprintf(stderr, "ox-arrays: Invalid operation tag passed to %s C code: %d\n", name, tag);
+ abort();
+}
+
+// Tags identifying the binary operations. Names and numeric values are taken
+// from the LIST_BINOP entries in arith_lists.h; LIST_BINOP is reset to empty
+// afterwards so later inclusions of the list file ignore these entries.
+enum binop_tag_t {
+#undef LIST_BINOP
+#define LIST_BINOP(name, id, hsop) name = id,
+#include "arith_lists.h"
+#undef LIST_BINOP
+#define LIST_BINOP(name, id, hsop)
+};
+
+// Defines the exported binary entry points for element type 'typ':
+//   oxarop_binary_<typ>_sv  (scalar  op vector)
+//   oxarop_binary_<typ>_vs  (vector  op scalar)
+//   oxarop_binary_<typ>_vv  (vector  op vector)
+// Each dispatches on 'tag' to the matching static kernel and aborts via
+// wrong_op() on an unknown tag.
+//
+// 'tag' is declared as plain 'int' rather than 'enum binop_tag_t': the Haskell
+// side imports these functions with a CInt first argument, and the integer
+// type underlying an enum is implementation-defined, so 'int' is the only
+// declaration that is guaranteed to match the foreign import.
+//
+// In the _vs variant, the commutative operations (add, mul) reuse the _sv
+// kernel with the operands swapped; only sub has a dedicated _vs kernel.
+#define ENTRY_BINARY_OPS(typ) \
+  void oxarop_binary_ ## typ ## _sv(int tag, i64 n, typ *out, typ x, const typ *y) { \
+    switch (tag) { \
+      case BO_ADD: oxarop_op_add_ ## typ ## _sv(n, out, x, y); break; \
+      case BO_SUB: oxarop_op_sub_ ## typ ## _sv(n, out, x, y); break; \
+      case BO_MUL: oxarop_op_mul_ ## typ ## _sv(n, out, x, y); break; \
+      default: wrong_op("binary_sv", tag); \
+    } \
+  } \
+  void oxarop_binary_ ## typ ## _vs(int tag, i64 n, typ *out, const typ *x, typ y) { \
+    switch (tag) { \
+      case BO_ADD: oxarop_op_add_ ## typ ## _sv(n, out, y, x); break; \
+      case BO_SUB: oxarop_op_sub_ ## typ ## _vs(n, out, x, y); break; \
+      case BO_MUL: oxarop_op_mul_ ## typ ## _sv(n, out, y, x); break; \
+      default: wrong_op("binary_vs", tag); \
+    } \
+  } \
+  void oxarop_binary_ ## typ ## _vv(int tag, i64 n, typ *out, const typ *x, const typ *y) { \
+    switch (tag) { \
+      case BO_ADD: oxarop_op_add_ ## typ ## _vv(n, out, x, y); break; \
+      case BO_SUB: oxarop_op_sub_ ## typ ## _vv(n, out, x, y); break; \
+      case BO_MUL: oxarop_op_mul_ ## typ ## _vv(n, out, x, y); break; \
+      default: wrong_op("binary_vv", tag); \
+    } \
+  }
+
+// Tags identifying the unary operations. Names and numeric values are taken
+// from the LIST_UNOP entries in arith_lists.h; LIST_UNOP is reset to empty
+// afterwards so later inclusions of the list file ignore these entries.
+enum unop_tag_t {
+#undef LIST_UNOP
+#define LIST_UNOP(name, id, _) name = id,
+#include "arith_lists.h"
+#undef LIST_UNOP
+#define LIST_UNOP(name, id, _)
+};
+
+// Defines the exported unary entry point oxarop_unary_<typ> for element type
+// 'typ'. It dispatches on 'tag' to the matching static kernel and aborts via
+// wrong_op() on an unknown tag.
+//
+// 'tag' is declared as plain 'int' rather than 'enum unop_tag_t': the Haskell
+// side imports this function with a CInt first argument, and the integer type
+// underlying an enum is implementation-defined.
+#define ENTRY_UNARY_OPS(typ) \
+  void oxarop_unary_ ## typ(int tag, i64 n, typ *out, const typ *x) { \
+    switch (tag) { \
+      case UO_NEG: oxarop_op_neg_ ## typ(n, out, x); break; \
+      case UO_ABS: oxarop_op_abs_ ## typ(n, out, x); break; \
+      case UO_SIGNUM: oxarop_op_signum_ ## typ(n, out, x); break; \
+      default: wrong_op("unary", tag); \
+    } \
+  }
+
+// Tags identifying the reduction operations. Names and numeric values are
+// taken from the LIST_REDOP entries in arith_lists.h; LIST_REDOP is reset to
+// empty afterwards so later inclusions of the list file ignore these entries.
+enum redop_tag_t {
+#undef LIST_REDOP
+#define LIST_REDOP(name, id, _) name = id,
+#include "arith_lists.h"
+#undef LIST_REDOP
+#define LIST_REDOP(name, id, _)
+};
+
+// Defines the exported reduction entry point oxarop_reduce_<typ> for element
+// type 'typ'. It dispatches on 'tag' to the matching static kernel and aborts
+// via wrong_op() on an unknown tag.
+//
+// 'tag' is declared as plain 'int' rather than 'enum redop_tag_t': the Haskell
+// side imports this function with a CInt first argument, and the integer type
+// underlying an enum is implementation-defined.
+#define ENTRY_REDUCE_OPS(typ) \
+  void oxarop_reduce_ ## typ(int tag, i64 rank, const i64 *shape, const i64 *strides, typ *out, const typ *arr) { \
+    switch (tag) { \
+      case RO_SUM1: oxarop_op_sum1_ ## typ(rank, shape, strides, out, arr); break; \
+      case RO_PRODUCT1: oxarop_op_product1_ ## typ(rank, shape, strides, out, arr); break; \
+      default: wrong_op("reduce", tag); \
+    } \
+  }
+
+
+/*****************************************************************************
+ * Generate all the functions *
+ *****************************************************************************/
+
#define NUM_TYPES_LOOP_XLIST \
X(i32) X(i64) X(double) X(float)
@@ -99,6 +198,9 @@ typedef int64_t i64;
UNARY_OP(abs, GEN_ABS, typ) \
UNARY_OP(signum, GEN_SIGNUM, typ) \
REDUCE1_OP(sum1, +, typ) \
- REDUCE1_OP(product1, *, typ)
+ REDUCE1_OP(product1, *, typ) \
+ ENTRY_BINARY_OPS(typ) \
+ ENTRY_UNARY_OPS(typ) \
+ ENTRY_REDUCE_OPS(typ)
NUM_TYPES_LOOP_XLIST
#undef X
diff --git a/cbits/arith_lists.h b/cbits/arith_lists.h
new file mode 100644
index 0000000..c7495e8
--- /dev/null
+++ b/cbits/arith_lists.h
@@ -0,0 +1,10 @@
+LIST_BINOP(BO_ADD, 1, +)
+LIST_BINOP(BO_SUB, 2, -)
+LIST_BINOP(BO_MUL, 3, *)
+
+LIST_UNOP(UO_NEG, 1,)
+LIST_UNOP(UO_ABS, 2,)
+LIST_UNOP(UO_SIGNUM, 3,)
+
+LIST_REDOP(RO_SUM1, 1,)
+LIST_REDOP(RO_PRODUCT1, 2,)
diff --git a/ox-arrays.cabal b/ox-arrays.cabal
index af985f4..875c54e 100644
--- a/ox-arrays.cabal
+++ b/ox-arrays.cabal
@@ -13,6 +13,7 @@ library
Data.Array.Nested.Internal.Arith
Data.Array.Nested.Internal.Arith.Foreign
Data.Array.Nested.Internal.Arith.Lists
+ Data.Array.Nested.Internal.Arith.Lists.TH
build-depends:
base >=4.18 && <4.20,
deepseq,
diff --git a/src/Data/Array/Nested/Internal/Arith.hs b/src/Data/Array/Nested/Internal/Arith.hs
index 4bfc043..7484455 100644
--- a/src/Data/Array/Nested/Internal/Arith.hs
+++ b/src/Data/Array/Nested/Internal/Arith.hs
@@ -182,14 +182,13 @@ class NumElt a where
$(fmap concat . forM typesList $ \arithtype -> do
let ttyp = conT (atType arithtype)
- fmap concat . forM binopsList $ \arithop -> do
+ fmap concat . forM [minBound..maxBound] $ \arithop -> do
let name = mkName (aboName arithop ++ "Vector" ++ nameBase (atType arithtype))
- cnamebase = "c_" ++ aboName arithop ++ "_" ++ atCName arithtype
- c_ss = varE (aboScalFun arithop arithtype)
- c_sv = varE $ mkName (cnamebase ++ "_sv")
- c_vs | aboComm arithop == NonComm = varE $ mkName (cnamebase ++ "_vs")
- | otherwise = [| flipOp $c_sv |]
- c_vv = varE $ mkName (cnamebase ++ "_vv")
+ cnamebase = "c_binary_" ++ atCName arithtype
+ c_ss = varE (aboNumOp arithop)
+ c_sv = varE (mkName (cnamebase ++ "_sv")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))
+ c_vs = varE (mkName (cnamebase ++ "_vs")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))
+ c_vv = varE (mkName (cnamebase ++ "_vv")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))
sequence [SigD name <$>
[t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |]
,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |]
@@ -197,9 +196,9 @@ $(fmap concat . forM typesList $ \arithtype -> do
$(fmap concat . forM typesList $ \arithtype -> do
let ttyp = conT (atType arithtype)
- fmap concat . forM unopsList $ \arithop -> do
+ fmap concat . forM [minBound..maxBound] $ \arithop -> do
let name = mkName (auoName arithop ++ "Vector" ++ nameBase (atType arithtype))
- c_op = varE $ mkName ("c_" ++ auoName arithop ++ "_" ++ atCName arithtype)
+ c_op = varE (mkName ("c_unary_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (auoEnum arithop)))
sequence [SigD name <$>
[t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |]
,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |]
@@ -207,10 +206,10 @@ $(fmap concat . forM typesList $ \arithtype -> do
$(fmap concat . forM typesList $ \arithtype -> do
let ttyp = conT (atType arithtype)
- fmap concat . forM redopsList $ \redop -> do
- let name = mkName (aroName redop ++ "Vector" ++ nameBase (atType arithtype))
- c_op = varE $ mkName ("c_" ++ aroName redop ++ "_" ++ atCName arithtype)
- c_scale_op = varE $ mkName ("c_mul_" ++ atCName arithtype ++ "_sv")
+ fmap concat . forM [minBound..maxBound] $ \arithop -> do
+ let name = mkName (aroName arithop ++ "Vector" ++ nameBase (atType arithtype))
+ c_op = varE (mkName ("c_reduce_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum arithop)))
+ c_scale_op = varE (mkName ("c_binary_" ++ atCName arithtype ++ "_sv")) `appE` litE (integerL (fromIntegral (aboEnum BO_MUL)))
sequence [SigD name <$>
[t| forall n. SNat n -> RS.Array (n + 1) $ttyp -> RS.Array n $ttyp |]
,do body <- [| \sn -> vectorRedInnerOp sn id id $c_scale_op $c_op |]
@@ -297,21 +296,41 @@ instance NumElt Double where
numEltProduct1Inner = product1VectorDouble
instance NumElt Int where
- numEltAdd = intWidBranch2 @Int (+) c_add_i32_sv (flipOp c_add_i32_sv) c_add_i32_vv c_add_i64_sv (flipOp c_add_i64_sv) c_add_i64_vv
- numEltSub = intWidBranch2 @Int (-) c_sub_i32_sv (flipOp c_sub_i32_sv) c_sub_i32_vv c_sub_i64_sv (flipOp c_sub_i64_sv) c_sub_i64_vv
- numEltMul = intWidBranch2 @Int (*) c_mul_i32_sv (flipOp c_mul_i32_sv) c_mul_i32_vv c_mul_i64_sv (flipOp c_mul_i64_sv) c_mul_i64_vv
- numEltNeg = intWidBranch1 @Int c_neg_i32 c_neg_i64
- numEltAbs = intWidBranch1 @Int c_abs_i32 c_abs_i64
- numEltSignum = intWidBranch1 @Int c_signum_i32 c_signum_i64
- numEltSum1Inner = intWidBranchRed @Int c_mul_i32_sv c_sum1_i32 c_mul_i64_sv c_sum1_i64
- numEltProduct1Inner = intWidBranchRed @Int c_mul_i32_sv c_product1_i32 c_mul_i64_sv c_product1_i64
+ numEltAdd = intWidBranch2 @Int (+)
+ (c_binary_i32_sv (aboEnum BO_ADD)) (flipOp (c_binary_i32_sv (aboEnum BO_ADD))) (c_binary_i32_vv (aboEnum BO_ADD))
+ (c_binary_i64_sv (aboEnum BO_ADD)) (flipOp (c_binary_i64_sv (aboEnum BO_ADD))) (c_binary_i64_vv (aboEnum BO_ADD))
+ -- Subtraction is not commutative, so the vector-scalar slot must use the
+ -- dedicated _vs entry point. flipOp applied to the _sv function is only a
+ -- valid stand-in for commutative operations such as (+) and (*).
+ numEltSub = intWidBranch2 @Int (-)
+     (c_binary_i32_sv (aboEnum BO_SUB)) (c_binary_i32_vs (aboEnum BO_SUB)) (c_binary_i32_vv (aboEnum BO_SUB))
+     (c_binary_i64_sv (aboEnum BO_SUB)) (c_binary_i64_vs (aboEnum BO_SUB)) (c_binary_i64_vv (aboEnum BO_SUB))
+ numEltMul = intWidBranch2 @Int (*)
+ (c_binary_i32_sv (aboEnum BO_MUL)) (flipOp (c_binary_i32_sv (aboEnum BO_MUL))) (c_binary_i32_vv (aboEnum BO_MUL))
+ (c_binary_i64_sv (aboEnum BO_MUL)) (flipOp (c_binary_i64_sv (aboEnum BO_MUL))) (c_binary_i64_vv (aboEnum BO_MUL))
+ numEltNeg = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_NEG)) (c_unary_i64 (auoEnum UO_NEG))
+ numEltAbs = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_ABS)) (c_unary_i64 (auoEnum UO_ABS))
+ numEltSignum = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_SIGNUM)) (c_unary_i64 (auoEnum UO_SIGNUM))
+ numEltSum1Inner = intWidBranchRed @Int
+ (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_SUM1))
+ (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_SUM1))
+ numEltProduct1Inner = intWidBranchRed @Int
+ (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1))
+ (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1))
instance NumElt CInt where
- numEltAdd = intWidBranch2 @CInt (+) c_add_i32_sv (flipOp c_add_i32_sv) c_add_i32_vv c_add_i64_sv (flipOp c_add_i64_sv) c_add_i64_vv
- numEltSub = intWidBranch2 @CInt (-) c_sub_i32_sv (flipOp c_sub_i32_sv) c_sub_i32_vv c_sub_i64_sv (flipOp c_sub_i64_sv) c_sub_i64_vv
- numEltMul = intWidBranch2 @CInt (*) c_mul_i32_sv (flipOp c_mul_i32_sv) c_mul_i32_vv c_mul_i64_sv (flipOp c_mul_i64_sv) c_mul_i64_vv
- numEltNeg = intWidBranch1 @CInt c_neg_i32 c_neg_i64
- numEltAbs = intWidBranch1 @CInt c_abs_i32 c_abs_i64
- numEltSignum = intWidBranch1 @CInt c_signum_i32 c_signum_i64
- numEltSum1Inner = intWidBranchRed @CInt c_mul_i32_sv c_sum1_i32 c_mul_i64_sv c_sum1_i64
- numEltProduct1Inner = intWidBranchRed @CInt c_mul_i32_sv c_product1_i32 c_mul_i64_sv c_product1_i64
+ numEltAdd = intWidBranch2 @CInt (+)
+ (c_binary_i32_sv (aboEnum BO_ADD)) (flipOp (c_binary_i32_sv (aboEnum BO_ADD))) (c_binary_i32_vv (aboEnum BO_ADD))
+ (c_binary_i64_sv (aboEnum BO_ADD)) (flipOp (c_binary_i64_sv (aboEnum BO_ADD))) (c_binary_i64_vv (aboEnum BO_ADD))
+ -- Subtraction is not commutative, so the vector-scalar slot must use the
+ -- dedicated _vs entry point. flipOp applied to the _sv function is only a
+ -- valid stand-in for commutative operations such as (+) and (*).
+ numEltSub = intWidBranch2 @CInt (-)
+     (c_binary_i32_sv (aboEnum BO_SUB)) (c_binary_i32_vs (aboEnum BO_SUB)) (c_binary_i32_vv (aboEnum BO_SUB))
+     (c_binary_i64_sv (aboEnum BO_SUB)) (c_binary_i64_vs (aboEnum BO_SUB)) (c_binary_i64_vv (aboEnum BO_SUB))
+ numEltMul = intWidBranch2 @CInt (*)
+ (c_binary_i32_sv (aboEnum BO_MUL)) (flipOp (c_binary_i32_sv (aboEnum BO_MUL))) (c_binary_i32_vv (aboEnum BO_MUL))
+ (c_binary_i64_sv (aboEnum BO_MUL)) (flipOp (c_binary_i64_sv (aboEnum BO_MUL))) (c_binary_i64_vv (aboEnum BO_MUL))
+ numEltNeg = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_NEG)) (c_unary_i64 (auoEnum UO_NEG))
+ numEltAbs = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_ABS)) (c_unary_i64 (auoEnum UO_ABS))
+ numEltSignum = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_SIGNUM)) (c_unary_i64 (auoEnum UO_SIGNUM))
+ numEltSum1Inner = intWidBranchRed @CInt
+ (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_SUM1))
+ (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_SUM1))
+ numEltProduct1Inner = intWidBranchRed @CInt
+ (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1))
+ (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1))
diff --git a/src/Data/Array/Nested/Internal/Arith/Foreign.hs b/src/Data/Array/Nested/Internal/Arith/Foreign.hs
index f84b1c5..49effa1 100644
--- a/src/Data/Array/Nested/Internal/Arith/Foreign.hs
+++ b/src/Data/Array/Nested/Internal/Arith/Foreign.hs
@@ -5,6 +5,7 @@ module Data.Array.Nested.Internal.Arith.Foreign where
import Control.Monad
import Data.Int
import Data.Maybe
+import Foreign.C.Types
import Foreign.Ptr
import Language.Haskell.TH
@@ -13,28 +14,24 @@ import Data.Array.Nested.Internal.Arith.Lists
$(fmap concat . forM typesList $ \arithtype -> do
let ttyp = conT (atType arithtype)
- fmap concat . forM binopsList $ \arithop -> do
- let base = aboName arithop ++ "_" ++ atCName arithtype
- sequence $ catMaybes
- [Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_sv") (mkName ("c_" ++ base ++ "_sv")) <$>
- [t| Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |])
- ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vv") (mkName ("c_" ++ base ++ "_vv")) <$>
- [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
- ,guard (aboComm arithop == NonComm) >>
- Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$>
- [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |])
- ])
+ let base = "binary_" ++ atCName arithtype
+ sequence $ catMaybes
+ [Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_sv") (mkName ("c_" ++ base ++ "_sv")) <$>
+ [t| CInt -> Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |])
+ ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vv") (mkName ("c_" ++ base ++ "_vv")) <$>
+ [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
+ ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$>
+ [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |])
+ ])
$(fmap concat . forM typesList $ \arithtype -> do
let ttyp = conT (atType arithtype)
- forM unopsList $ \arithop -> do
- let base = auoName arithop ++ "_" ++ atCName arithtype
- ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
- [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
+ let base = "unary_" ++ atCName arithtype
+ pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
+ [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
$(fmap concat . forM typesList $ \arithtype -> do
let ttyp = conT (atType arithtype)
- forM redopsList $ \redop -> do
- let base = aroName redop ++ "_" ++ atCName arithtype
- ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
- [t| Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
+ let base = "reduce_" ++ atCName arithtype
+ pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
+ [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
diff --git a/src/Data/Array/Nested/Internal/Arith/Lists.hs b/src/Data/Array/Nested/Internal/Arith/Lists.hs
index 78fe24a..91e50ad 100644
--- a/src/Data/Array/Nested/Internal/Arith/Lists.hs
+++ b/src/Data/Array/Nested/Internal/Arith/Lists.hs
@@ -1,13 +1,13 @@
-{-# LANGUAGE TemplateHaskellQuotes #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE TemplateHaskell #-}
module Data.Array.Nested.Internal.Arith.Lists where
+import Data.Char
import Data.Int
-
import Language.Haskell.TH
+import Data.Array.Nested.Internal.Arith.Lists.TH
-data Commutative = Comm | NonComm
- deriving (Show, Eq)
data ArithType = ArithType
{ atType :: Name -- ''Int32
@@ -22,36 +22,30 @@ typesList =
,ArithType ''Double "double"
]
-data ArithBOp = ArithBOp
- { aboName :: String -- "add"
- , aboComm :: Commutative -- Comm
- , aboScalFun :: ArithType -> Name -- \_ -> '(+)
- }
+-- data ArithBOp = BO_ADD | BO_SUB | BO_MUL deriving (Show, Enum, Bounded)
+$(genArithDataType Binop "ArithBOp")
-binopsList :: [ArithBOp]
-binopsList =
- [ArithBOp "add" Comm (\_ -> '(+))
- ,ArithBOp "sub" NonComm (\_ -> '(-))
- ,ArithBOp "mul" Comm (\_ -> '(*))
- ]
+$(genArithNameFun Binop ''ArithBOp "aboName" (map toLower . drop 3))
+$(genArithEnumFun Binop ''ArithBOp "aboEnum")
-data ArithUOp = ArithUOp
- { auoName :: String -- "neg"
- }
+$(do clauses <- readArithLists Binop
+ (\name _num hsop -> return (Clause [ConP (mkName name) [] []]
+ (NormalB (VarE 'mkName `AppE` LitE (StringL hsop)))
+ []))
+ return
+ sequence [SigD (mkName "aboNumOp") <$> [t| ArithBOp -> Name |]
+ ,return $ FunD (mkName "aboNumOp") clauses])
-unopsList :: [ArithUOp]
-unopsList =
- [ArithUOp "neg"
- ,ArithUOp "abs"
- ,ArithUOp "signum"
- ]
-data ArithRedOp = ArithRedOp
- { aroName :: String -- "sum"
- }
+-- data ArithUOp = UO_NEG | UO_ABS | UO_SIGNUM | ... deriving (Show, Enum, Bounded)
+$(genArithDataType Unop "ArithUOp")
-redopsList :: [ArithRedOp]
-redopsList =
- [ArithRedOp "sum1"
- ,ArithRedOp "product1"
- ]
+$(genArithNameFun Unop ''ArithUOp "auoName" (map toLower . drop 3))
+$(genArithEnumFun Unop ''ArithUOp "auoEnum")
+
+
+-- data ArithRedOp = RO_SUM1 | RO_PRODUCT1 deriving (Show, Enum, Bounded)
+$(genArithDataType Redop "ArithRedOp")
+
+$(genArithNameFun Redop ''ArithRedOp "aroName" (map toLower . drop 3))
+$(genArithEnumFun Redop ''ArithRedOp "aroEnum")
diff --git a/src/Data/Array/Nested/Internal/Arith/Lists/TH.hs b/src/Data/Array/Nested/Internal/Arith/Lists/TH.hs
new file mode 100644
index 0000000..b748b97
--- /dev/null
+++ b/src/Data/Array/Nested/Internal/Arith/Lists/TH.hs
@@ -0,0 +1,78 @@
+{-# LANGUAGE TemplateHaskellQuotes #-}
+module Data.Array.Nested.Internal.Arith.Lists.TH where
+
+import Control.Monad
+import Control.Monad.IO.Class
+import Data.Maybe
+import Foreign.C.Types
+import Language.Haskell.TH
+import Language.Haskell.TH.Syntax (addDependentFile)
+import Text.Read
+
+
+data OpKind = Binop | Unop | Redop
+ deriving (Show, Eq)
+
+-- | Read @cbits/arith_lists.h@ at compile time and process the entries of one
+-- operation kind.
+--
+-- Every non-blank line of the file must have the form
+-- @LIST_\<KIND\>(name, num, aux)@ with @KIND@ one of @BINOP@, @UNOP@ or
+-- @REDOP@; any other line aborts compilation with an 'error'. For each line
+-- whose kind equals @targetkind@, @fop@ is run on the name, the numeric id and
+-- the (possibly empty) auxiliary field, in file order. The collected results
+-- are then handed to @fcombine@.
+--
+-- The path is relative, so this presumably relies on the compiler being run
+-- from the package root (TODO confirm for all build setups).
+readArithLists :: OpKind
+               -> (String -> Int -> String -> Q a)
+               -> ([a] -> Q r)
+               -> Q r
+readArithLists targetkind fop fcombine = do
+  -- Register the header as a dependency of the module containing the splice,
+  -- so that changing the list of operations triggers recompilation.
+  addDependentFile "cbits/arith_lists.h"
+  lns <- liftIO $ lines <$> readFile "cbits/arith_lists.h"
+
+  mvals <- forM lns $ \line -> do
+    if null (dropWhile (== ' ') line)
+      then return Nothing
+      else do let (kind, name, num, aux) = parseLine line
+              if kind == targetkind
+                then Just <$> fop name num aux
+                else return Nothing
+
+  fcombine (catMaybes mvals)
+  where
+    -- Split one line into (kind, name, numeric id, auxiliary field).
+    parseLine s0
+      | ("LIST_", s1) <- splitAt 5 s0
+      , (kindstr, '(' : s2) <- break (== '(') s1
+      , (f1, ',' : s3) <- parseField s2
+      , (f2, ',' : s4) <- parseField s3
+      , (f3, ')' : _) <- parseField s4
+      , Just kind <- parseKind kindstr
+      , let name = f1
+      , Just num <- readMaybe f2
+      , let aux = f3
+      = (kind, name, num, aux)
+      | otherwise
+      = error $ "readArithLists: unrecognised line in cbits/arith_lists.h: " ++ show s0
+
+    -- Skip leading spaces, then take everything up to the next ',' or ')'.
+    parseField s = break (`elem` ",)") (dropWhile (== ' ') s)
+
+    parseKind "BINOP" = Just Binop
+    parseKind "UNOP" = Just Unop
+    parseKind "REDOP" = Just Redop
+    parseKind _ = Nothing
+
+-- | Generate a data type named @dtname@ with one nullary constructor per entry
+-- of the given kind in @cbits/arith_lists.h@ (constructor names are the entry
+-- names, in file order), deriving 'Show', 'Enum' and 'Bounded'.
+genArithDataType :: OpKind -> String -> Q [Dec]
+genArithDataType kind dtname = do
+ cons <- readArithLists kind
+ (\name _num _ -> return $ NormalC (mkName name) [])
+ return
+ return [DataD [] (mkName dtname) [] Nothing cons [DerivClause Nothing [ConT ''Show, ConT ''Enum, ConT ''Bounded]]]
+
+-- | Generate a function @funname :: dtname -> String@ with one clause per
+-- entry of the given kind in @cbits/arith_lists.h@. Each clause matches the
+-- constructor named after the entry and returns @nametrans@ applied to that
+-- entry name; @nametrans@ is evaluated at compile time and only its result is
+-- embedded as a string literal.
+genArithNameFun :: OpKind -> Name -> String -> (String -> String) -> Q [Dec]
+genArithNameFun kind dtname funname nametrans = do
+ clauses <- readArithLists kind
+ (\name _num _ -> return (Clause [ConP (mkName name) [] []]
+ (NormalB (LitE (StringL (nametrans name))))
+ []))
+ return
+ return [SigD (mkName funname) (ArrowT `AppT` ConT dtname `AppT` ConT ''String)
+ ,FunD (mkName funname) clauses]
+
+-- | Generate a function @funname :: dtname -> CInt@ with one clause per entry
+-- of the given kind in @cbits/arith_lists.h@. Each clause matches the
+-- constructor named after the entry and returns the numeric id written in the
+-- header for that entry, as an integer literal.
+genArithEnumFun :: OpKind -> Name -> String -> Q [Dec]
+genArithEnumFun kind dtname funname = do
+ clauses <- readArithLists kind
+ (\name num _ -> return (Clause [ConP (mkName name) [] []]
+ (NormalB (LitE (IntegerL (fromIntegral num))))
+ []))
+ return
+ return [SigD (mkName funname) (ArrowT `AppT` ConT dtname `AppT` ConT ''CInt)
+ ,FunD (mkName funname) clauses]