-rw-r--r--  cbits/arith.c                                   26
-rw-r--r--  src/Data/Array/Mixed/Internal/Arith.hs          44
-rw-r--r--  src/Data/Array/Mixed/Internal/Arith/Foreign.hs   1
3 files changed, 35 insertions, 36 deletions
diff --git a/cbits/arith.c b/cbits/arith.c
index f08e456..6ea197d 100644
--- a/cbits/arith.c
+++ b/cbits/arith.c
@@ -113,7 +113,7 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
// Provides idx, outlinidx, arrlinidx.
#define TARRAY_WALK_NOINNER(again_label_name, rank, shape, strides, body) \
do { \
- i64 idx[(rank) - 1]; \
+ i64 idx[(rank) /* - 1 */]; /* Note: [zero-length VLA] */ \
memset(idx, 0, ((rank) - 1) * sizeof(idx[0])); \
i64 arrlinidx = 0; \
i64 outlinidx = 0; \
@@ -138,7 +138,7 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
// Provides idx, outlinidx, arrlinidx1, arrlinidx2.
#define TARRAY_WALK2_NOINNER(again_label_name, rank, shape, strides1, strides2, body) \
do { \
- i64 idx[(rank) - 1]; \
+ i64 idx[(rank) /* - 1 */]; /* Note: [zero-length VLA] */ \
memset(idx, 0, ((rank) - 1) * sizeof(idx[0])); \
i64 arrlinidx1 = 0, arrlinidx2 = 0; \
i64 outlinidx = 0; \
@@ -433,16 +433,6 @@ enum unop_tag_t {
#define LIST_UNOP(name, id, _)
};
-#define ENTRY_UNARY_OPS(typ) \
- void oxarop_unary_ ## typ(enum unop_tag_t tag, i64 n, typ *out, const typ *x) { \
- switch (tag) { \
- case UO_NEG: oxarop_op_neg_ ## typ(n, out, x); break; \
- case UO_ABS: oxarop_op_abs_ ## typ(n, out, x); break; \
- case UO_SIGNUM: oxarop_op_signum_ ## typ(n, out, x); break; \
- default: wrong_op("unary", tag); \
- } \
- }
-
#define ENTRY_UNARY_STRIDED_OPS(typ) \
void oxarop_unary_ ## typ ## _strided(enum unop_tag_t tag, i64 rank, typ *out, const i64 *shape, const i64 *strides, const typ *x) { \
switch (tag) { \
@@ -526,9 +516,6 @@ enum redop_tag_t {
COMM_OP(add, +, typ) \
NONCOMM_OP(sub, -, typ) \
COMM_OP(mul, *, typ) \
- UNARY_OP(neg, -, typ) \
- UNARY_OP(abs, GEN_ABS, typ) \
- UNARY_OP(signum, GEN_SIGNUM, typ) \
UNARY_OP_STRIDED(neg, -, typ) \
UNARY_OP_STRIDED(abs, GEN_ABS, typ) \
UNARY_OP_STRIDED(signum, GEN_SIGNUM, typ) \
@@ -537,7 +524,6 @@ enum redop_tag_t {
REDUCEFULL_OP(sumfull, +, typ) \
REDUCEFULL_OP(productfull, *, typ) \
ENTRY_BINARY_OPS(typ) \
- ENTRY_UNARY_OPS(typ) \
ENTRY_UNARY_STRIDED_OPS(typ) \
ENTRY_REDUCE1_OPS(typ) \
ENTRY_REDUCEFULL_OPS(typ) \
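The per-type operation block above is instantiated once per element type via an X-macro list (visible further down in this file: `#define X(typ) ...`, then `NUM_TYPES_XLIST` / `FLOAT_TYPES_XLIST`, then `#undef X`). As a minimal, self-contained sketch of that idiom — the type list and function names below are illustrative only, not the ones used in arith.c:

#include <stdio.h>

/* Illustrative type list; arith.c's real lists (NUM_TYPES_XLIST,
 * FLOAT_TYPES_XLIST) enumerate the supported element types. */
#define MY_TYPES_XLIST X(int) X(double)

/* Each expansion of X(typ) defines one function for the listed type. */
#define X(typ) \
  void print_size_ ## typ(void) { printf(#typ ": %zu\n", sizeof(typ)); }
MY_TYPES_XLIST
#undef X

int main(void) {
  print_size_int();     /* "int: 4" on typical platforms */
  print_size_double();  /* "double: 8" */
  return 0;
}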
@@ -552,6 +538,7 @@ NUM_TYPES_XLIST
NONCOMM_OP(fdiv, /, typ) \
PREFIX_BINOP(pow, GEN_POW, typ) \
PREFIX_BINOP(logbase, GEN_LOGBASE, typ) \
+ /* TODO: once these are converted to UNARY_OP_STRIDED, remove UNARY_OP entirely */ \
UNARY_OP(recip, 1.0/, typ) \
UNARY_OP(exp, GEN_EXP, typ) \
UNARY_OP(log, GEN_LOG, typ) \
@@ -576,3 +563,10 @@ NUM_TYPES_XLIST
ENTRY_FUNARY_OPS(typ)
FLOAT_TYPES_XLIST
#undef X
+
+// Note: [zero-length VLA]
+//
+// Zero-length variable-length arrays are not allowed in C(99). Thus whenever we
+// have a VLA that could sometimes turn out to be empty (e.g. `idx` in the
+// TARRAY_WALK_NOINNER macros), we tweak the length formula (typically by just
+// adding 1) so that it never ends up empty.
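To make the note above concrete, here is a small stand-alone illustration of the problem and of the over-allocation workaround; the function and variable names are invented for this example, and only the `idx`/`memset` pattern mirrors the macros above.

#include <stdio.h>
#include <string.h>

typedef long long i64;

/* With rank == 1, `i64 idx[rank - 1]` would be a zero-length VLA, which C99
 * forbids (a VLA length must evaluate to a value greater than zero).
 * Declaring `rank` elements over-allocates by at most one element and is
 * always valid; the memset still only initialises the rank - 1 used slots. */
static void walk_outer(i64 rank, const i64 *shape) {
  i64 idx[rank];                                /* not: idx[rank - 1] */
  memset(idx, 0, (size_t)(rank - 1) * sizeof(idx[0]));
  for (i64 i = 0; i < rank - 1; i++)
    printf("idx[%lld] = %lld (outer dim %lld)\n", i, idx[i], shape[i]);
}

int main(void) {
  i64 shape[1] = {5};
  walk_outer(1, shape);  /* rank 1: idx has length 1, loop body never runs */
  return 0;
}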
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs
index 123a4b5..58108f2 100644
--- a/src/Data/Array/Mixed/Internal/Arith.hs
+++ b/src/Data/Array/Mixed/Internal/Arith.hs
@@ -52,20 +52,27 @@ liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec)))
{-# NOINLINE liftOpEltwise1 #-}
liftOpEltwise1 :: (Storable a, Storable b)
=> SNat n
- -> (VS.Vector a -> VS.Vector b)
- -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr a -> IO ())
+ -> (Ptr a -> Ptr a')
+ -> (Ptr b -> Ptr b')
+ -> (Int64 -> Ptr b' -> Ptr Int64 -> Ptr Int64 -> Ptr a' -> IO ())
-> RS.Array n a -> RS.Array n b
-liftOpEltwise1 sn@SNat f_vec cf_strided (RS.A (RG.A sh (OI.T strides offset vec)))
- | Just (blockOff, blockSz) <- stridesDense sh offset strides =
- let vec' = f_vec (VS.slice blockOff blockSz vec)
- in RS.A (RG.A sh (OI.T strides (offset - blockOff) vec'))
+liftOpEltwise1 sn@SNat ptrconv1 ptrconv2 cf_strided (RS.A (RG.A sh (OI.T strides offset vec)))
+ -- TODO: less code duplication between these two branches
+ | Just (blockOff, blockSz) <- stridesDense sh offset strides = unsafePerformIO $ do
+ outv <- VSM.unsafeNew blockSz
+ VSM.unsafeWith outv $ \poutv ->
+ VS.unsafeWith (VS.singleton (fromIntegral blockSz)) $ \psh ->
+ VS.unsafeWith (VS.singleton 1) $ \pstrides ->
+ VS.unsafeWith (VS.slice blockOff blockSz vec) $ \pv ->
+ cf_strided 1 (ptrconv2 poutv) psh pstrides (ptrconv1 pv)
+ RS.A . RG.A sh . OI.T strides (offset - blockOff) <$> VS.unsafeFreeze outv
| otherwise = unsafePerformIO $ do
outv <- VSM.unsafeNew (product sh)
VSM.unsafeWith outv $ \poutv ->
VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral sh)) $ \psh ->
VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral strides)) $ \pstrides ->
VS.unsafeWith vec $ \pv ->
- cf_strided (fromIntegral (fromSNat sn)) poutv psh pstrides pv
+ cf_strided (fromIntegral (fromSNat sn)) (ptrconv2 poutv) psh pstrides (ptrconv1 pv)
RS.fromVector sh <$> VS.unsafeFreeze outv
-- TODO: test all the cases of this thing with various input strides
@@ -440,11 +447,10 @@ $(fmap concat . forM typesList $ \arithtype -> do
let ttyp = conT (atType arithtype)
fmap concat . forM [minBound..maxBound] $ \arithop -> do
let name = mkName (auoName arithop ++ "Vector" ++ nameBase (atType arithtype))
- c_op = varE (mkName ("c_unary_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (auoEnum arithop)))
c_op_strided = varE (mkName ("c_unary_" ++ atCName arithtype ++ "_strided")) `appE` litE (integerL (fromIntegral (auoEnum arithop)))
sequence [SigD name <$>
[t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |]
- ,do body <- [| \sn -> liftOpEltwise1 sn (vectorOp1 id $c_op) $c_op_strided |]
+ ,do body <- [| \sn -> liftOpEltwise1 sn id id $c_op_strided |]
return $ FunD name [Clause [] (NormalB body) []]])
$(fmap concat . forM floatTypesList $ \arithtype -> do
@@ -506,12 +512,12 @@ $(fmap concat . forM typesList $ \arithtype -> do
-- This branch is ostensibly a runtime branch, but will (hopefully) be
-- constant-folded away by GHC.
intWidBranch1 :: forall i n. (FiniteBits i, Storable i)
- => (Int64 -> Ptr Int32 -> Ptr Int32 -> IO ())
- -> (Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())
+ => (Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO ())
+ -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())
-> (SNat n -> RS.Array n i -> RS.Array n i)
intWidBranch1 f32 f64 sn
- | finiteBitSize (undefined :: i) == 32 = liftVEltwise1 sn (vectorOp1 @i @Int32 castPtr f32)
- | finiteBitSize (undefined :: i) == 64 = liftVEltwise1 sn (vectorOp1 @i @Int64 castPtr f64)
+ | finiteBitSize (undefined :: i) == 32 = liftOpEltwise1 sn castPtr castPtr f32
+ | finiteBitSize (undefined :: i) == 64 = liftOpEltwise1 sn castPtr castPtr f64
| otherwise = error "Unsupported Int width"
intWidBranch2 :: forall i n. (FiniteBits i, Storable i, Integral i)
@@ -666,9 +672,9 @@ instance NumElt Int where
numEltMul = intWidBranch2 @Int (*)
(c_binary_i32_sv (aboEnum BO_MUL)) (flipOp (c_binary_i32_sv (aboEnum BO_MUL))) (c_binary_i32_vv (aboEnum BO_MUL))
(c_binary_i64_sv (aboEnum BO_MUL)) (flipOp (c_binary_i64_sv (aboEnum BO_MUL))) (c_binary_i64_vv (aboEnum BO_MUL))
- numEltNeg = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_NEG)) (c_unary_i64 (auoEnum UO_NEG))
- numEltAbs = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_ABS)) (c_unary_i64 (auoEnum UO_ABS))
- numEltSignum = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_SIGNUM)) (c_unary_i64 (auoEnum UO_SIGNUM))
+ numEltNeg = intWidBranch1 @Int (c_unary_i32_strided (auoEnum UO_NEG)) (c_unary_i64_strided (auoEnum UO_NEG))
+ numEltAbs = intWidBranch1 @Int (c_unary_i32_strided (auoEnum UO_ABS)) (c_unary_i64_strided (auoEnum UO_ABS))
+ numEltSignum = intWidBranch1 @Int (c_unary_i32_strided (auoEnum UO_SIGNUM)) (c_unary_i64_strided (auoEnum UO_SIGNUM))
numEltSum1Inner = intWidBranchRed1 @Int
(c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_SUM))
(c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_SUM))
@@ -692,9 +698,9 @@ instance NumElt CInt where
numEltMul = intWidBranch2 @CInt (*)
(c_binary_i32_sv (aboEnum BO_MUL)) (flipOp (c_binary_i32_sv (aboEnum BO_MUL))) (c_binary_i32_vv (aboEnum BO_MUL))
(c_binary_i64_sv (aboEnum BO_MUL)) (flipOp (c_binary_i64_sv (aboEnum BO_MUL))) (c_binary_i64_vv (aboEnum BO_MUL))
- numEltNeg = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_NEG)) (c_unary_i64 (auoEnum UO_NEG))
- numEltAbs = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_ABS)) (c_unary_i64 (auoEnum UO_ABS))
- numEltSignum = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_SIGNUM)) (c_unary_i64 (auoEnum UO_SIGNUM))
+ numEltNeg = intWidBranch1 @CInt (c_unary_i32_strided (auoEnum UO_NEG)) (c_unary_i64_strided (auoEnum UO_NEG))
+ numEltAbs = intWidBranch1 @CInt (c_unary_i32_strided (auoEnum UO_ABS)) (c_unary_i64_strided (auoEnum UO_ABS))
+ numEltSignum = intWidBranch1 @CInt (c_unary_i32_strided (auoEnum UO_SIGNUM)) (c_unary_i64_strided (auoEnum UO_SIGNUM))
numEltSum1Inner = intWidBranchRed1 @CInt
(c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_SUM))
(c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_SUM))
diff --git a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
index 22c5b53..b53eb36 100644
--- a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
+++ b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
@@ -15,7 +15,6 @@ $(do
[("binary_" ++ tyn ++ "_vv", [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
,("binary_" ++ tyn ++ "_sv", [t| CInt -> Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |])
,("binary_" ++ tyn ++ "_vs", [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |])
- ,("unary_" ++ tyn, [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
,("unary_" ++ tyn ++ "_strided", [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])
,("reduce1_" ++ tyn, [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])
,("reducefull_" ++ tyn, [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO $ttyp |])