diff options

| field | value | date |
|---|---|---|
| author | Tom Smeding <tom@tomsmeding.com> | 2025-03-05 22:09:50 +0100 |
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-05 22:09:50 +0100 |
| commit | 984e5315768dd190a97069167daf970c17c3c867 (patch) | |
| tree | 7db8a4173a4198ba0a3fafa54799fd7273f0bfdb | |
| parent | 37eec011de921504dc16fd16ec9bb0e5008347fd (diff) | |

arith: Unary float ops on strided arrays without normalisation

| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | cbits/arith.c | 92 |
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith.hs | 4 |
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith/Foreign.hs | 2 |

3 files changed, 46 insertions, 52 deletions
diff --git a/cbits/arith.c b/cbits/arith.c index 6ea197d..8d0700d 100644 --- a/cbits/arith.c +++ b/cbits/arith.c @@ -200,11 +200,6 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {      for (i64 i = 0; i < n; i++) out[i] = op(x[i], y); \    } -#define UNARY_OP(name, op, typ) \ -  static void oxarop_op_ ## name ## _ ## typ(i64 n, typ *out, const typ *x) { \ -    for (i64 i = 0; i < n; i++) out[i] = op(x[i]); \ -  } -  #define UNARY_OP_STRIDED(name, op, typ) \    static void oxarop_op_ ## name ## _ ## typ ## _strided(i64 rank, typ *out, const i64 *shape, const i64 *strides, const typ *arr) { \      /* fprintf(stderr, "oxarop_op_" #name "_" #typ "_strided: rank=%ld shape=", rank); \ @@ -451,29 +446,29 @@ enum funop_tag_t {  #define LIST_FUNOP(name, id, _)  }; -#define ENTRY_FUNARY_OPS(typ) \ -  void oxarop_funary_ ## typ(enum funop_tag_t tag, i64 n, typ *out, const typ *x) { \ +#define ENTRY_FUNARY_STRIDED_OPS(typ) \ +  void oxarop_funary_ ## typ ## _strided(enum funop_tag_t tag, i64 rank, typ *out, const i64 *shape, const i64 *strides, const typ *x) { \      switch (tag) { \ -      case FU_RECIP: oxarop_op_recip_ ## typ(n, out, x); break; \ -      case FU_EXP: oxarop_op_exp_ ## typ(n, out, x); break; \ -      case FU_LOG: oxarop_op_log_ ## typ(n, out, x); break; \ -      case FU_SQRT: oxarop_op_sqrt_ ## typ(n, out, x); break; \ -      case FU_SIN: oxarop_op_sin_ ## typ(n, out, x); break; \ -      case FU_COS: oxarop_op_cos_ ## typ(n, out, x); break; \ -      case FU_TAN: oxarop_op_tan_ ## typ(n, out, x); break; \ -      case FU_ASIN: oxarop_op_asin_ ## typ(n, out, x); break; \ -      case FU_ACOS: oxarop_op_acos_ ## typ(n, out, x); break; \ -      case FU_ATAN: oxarop_op_atan_ ## typ(n, out, x); break; \ -      case FU_SINH: oxarop_op_sinh_ ## typ(n, out, x); break; \ -      case FU_COSH: oxarop_op_cosh_ ## typ(n, out, x); break; \ -      case FU_TANH: oxarop_op_tanh_ ## typ(n, out, x); break; \ -      case FU_ASINH: oxarop_op_asinh_ ## typ(n, 
out, x); break; \ -      case FU_ACOSH: oxarop_op_acosh_ ## typ(n, out, x); break; \ -      case FU_ATANH: oxarop_op_atanh_ ## typ(n, out, x); break; \ -      case FU_LOG1P: oxarop_op_log1p_ ## typ(n, out, x); break; \ -      case FU_EXPM1: oxarop_op_expm1_ ## typ(n, out, x); break; \ -      case FU_LOG1PEXP: oxarop_op_log1pexp_ ## typ(n, out, x); break; \ -      case FU_LOG1MEXP: oxarop_op_log1mexp_ ## typ(n, out, x); break; \ +      case FU_RECIP:    oxarop_op_recip_    ## typ ## _strided(rank, out, shape, strides, x); break; \ +      case FU_EXP:      oxarop_op_exp_      ## typ ## _strided(rank, out, shape, strides, x); break; \ +      case FU_LOG:      oxarop_op_log_      ## typ ## _strided(rank, out, shape, strides, x); break; \ +      case FU_SQRT:     oxarop_op_sqrt_     ## typ ## _strided(rank, out, shape, strides, x); break; \ +      case FU_SIN:      oxarop_op_sin_      ## typ ## _strided(rank, out, shape, strides, x); break; \ +      case FU_COS:      oxarop_op_cos_      ## typ ## _strided(rank, out, shape, strides, x); break; \ +      case FU_TAN:      oxarop_op_tan_      ## typ ## _strided(rank, out, shape, strides, x); break; \ +      case FU_ASIN:     oxarop_op_asin_     ## typ ## _strided(rank, out, shape, strides, x); break; \ +      case FU_ACOS:     oxarop_op_acos_     ## typ ## _strided(rank, out, shape, strides, x); break; \ +      case FU_ATAN:     oxarop_op_atan_     ## typ ## _strided(rank, out, shape, strides, x); break; \ +      case FU_SINH:     oxarop_op_sinh_     ## typ ## _strided(rank, out, shape, strides, x); break; \ +      case FU_COSH:     oxarop_op_cosh_     ## typ ## _strided(rank, out, shape, strides, x); break; \ +      case FU_TANH:     oxarop_op_tanh_     ## typ ## _strided(rank, out, shape, strides, x); break; \ +      case FU_ASINH:    oxarop_op_asinh_    ## typ ## _strided(rank, out, shape, strides, x); break; \ +      case FU_ACOSH:    oxarop_op_acosh_    ## typ ## _strided(rank, out, shape, strides, x); break; \ +      
case FU_ATANH:    oxarop_op_atanh_    ## typ ## _strided(rank, out, shape, strides, x); break; \ +      case FU_LOG1P:    oxarop_op_log1p_    ## typ ## _strided(rank, out, shape, strides, x); break; \ +      case FU_EXPM1:    oxarop_op_expm1_    ## typ ## _strided(rank, out, shape, strides, x); break; \ +      case FU_LOG1PEXP: oxarop_op_log1pexp_ ## typ ## _strided(rank, out, shape, strides, x); break; \ +      case FU_LOG1MEXP: oxarop_op_log1mexp_ ## typ ## _strided(rank, out, shape, strides, x); break; \        default: wrong_op("unary", tag); \      } \    } @@ -538,29 +533,28 @@ NUM_TYPES_XLIST    NONCOMM_OP(fdiv, /, typ) \    PREFIX_BINOP(pow, GEN_POW, typ) \    PREFIX_BINOP(logbase, GEN_LOGBASE, typ) \ -  /* TODO: when replaced with UNARY_OP_STRIDED, remove UNARY_OP entirely */ \ -  UNARY_OP(recip, 1.0/, typ) \ -  UNARY_OP(exp, GEN_EXP, typ) \ -  UNARY_OP(log, GEN_LOG, typ) \ -  UNARY_OP(sqrt, GEN_SQRT, typ) \ -  UNARY_OP(sin, GEN_SIN, typ) \ -  UNARY_OP(cos, GEN_COS, typ) \ -  UNARY_OP(tan, GEN_TAN, typ) \ -  UNARY_OP(asin, GEN_ASIN, typ) \ -  UNARY_OP(acos, GEN_ACOS, typ) \ -  UNARY_OP(atan, GEN_ATAN, typ) \ -  UNARY_OP(sinh, GEN_SINH, typ) \ -  UNARY_OP(cosh, GEN_COSH, typ) \ -  UNARY_OP(tanh, GEN_TANH, typ) \ -  UNARY_OP(asinh, GEN_ASINH, typ) \ -  UNARY_OP(acosh, GEN_ACOSH, typ) \ -  UNARY_OP(atanh, GEN_ATANH, typ) \ -  UNARY_OP(log1p, GEN_LOG1P, typ) \ -  UNARY_OP(expm1, GEN_EXPM1, typ) \ -  UNARY_OP(log1pexp, GEN_LOG1PEXP, typ) \ -  UNARY_OP(log1mexp, GEN_LOG1MEXP, typ) \ +  UNARY_OP_STRIDED(recip, 1.0/, typ) \ +  UNARY_OP_STRIDED(exp, GEN_EXP, typ) \ +  UNARY_OP_STRIDED(log, GEN_LOG, typ) \ +  UNARY_OP_STRIDED(sqrt, GEN_SQRT, typ) \ +  UNARY_OP_STRIDED(sin, GEN_SIN, typ) \ +  UNARY_OP_STRIDED(cos, GEN_COS, typ) \ +  UNARY_OP_STRIDED(tan, GEN_TAN, typ) \ +  UNARY_OP_STRIDED(asin, GEN_ASIN, typ) \ +  UNARY_OP_STRIDED(acos, GEN_ACOS, typ) \ +  UNARY_OP_STRIDED(atan, GEN_ATAN, typ) \ +  UNARY_OP_STRIDED(sinh, GEN_SINH, typ) \ +  UNARY_OP_STRIDED(cosh, 
GEN_COSH, typ) \ +  UNARY_OP_STRIDED(tanh, GEN_TANH, typ) \ +  UNARY_OP_STRIDED(asinh, GEN_ASINH, typ) \ +  UNARY_OP_STRIDED(acosh, GEN_ACOSH, typ) \ +  UNARY_OP_STRIDED(atanh, GEN_ATANH, typ) \ +  UNARY_OP_STRIDED(log1p, GEN_LOG1P, typ) \ +  UNARY_OP_STRIDED(expm1, GEN_EXPM1, typ) \ +  UNARY_OP_STRIDED(log1pexp, GEN_LOG1PEXP, typ) \ +  UNARY_OP_STRIDED(log1mexp, GEN_LOG1MEXP, typ) \    ENTRY_FBINARY_OPS(typ) \ -  ENTRY_FUNARY_OPS(typ) +  ENTRY_FUNARY_STRIDED_OPS(typ)  FLOAT_TYPES_XLIST  #undef X diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs index a403d3c..11ee3fe 100644 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -460,10 +460,10 @@ $(fmap concat . forM floatTypesList $ \arithtype -> do      let ttyp = conT (atType arithtype)      fmap concat . forM [minBound..maxBound] $ \arithop -> do        let name = mkName (afuoName arithop ++ "Vector" ++ nameBase (atType arithtype)) -          c_op = varE (mkName ("c_funary_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (afuoEnum arithop))) +          c_op_strided = varE (mkName ("c_funary_" ++ atCName arithtype ++ "_strided")) `appE` litE (integerL (fromIntegral (afuoEnum arithop)))        sequence [SigD name <$>                       [t| forall n. 
SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |] -               ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |] +               ,do body <- [| \sn -> liftOpEltwise1 sn id id $c_op_strided |]                     return $ FunD name [Clause [] (NormalB body) []]])  mulWithInt :: Num a => a -> Int -> a diff --git a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs index b53eb36..a60b717 100644 --- a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs +++ b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs @@ -29,7 +29,7 @@ $(do          [("fbinary_" ++ tyn ++ "_vv",      [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |])          ,("fbinary_" ++ tyn ++ "_sv",      [t| CInt -> Int64 -> Ptr $ttyp ->     $ttyp -> Ptr $ttyp -> IO () |])          ,("fbinary_" ++ tyn ++ "_vs",      [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp ->     $ttyp -> IO () |]) -        ,("funary_" ++ tyn,                [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) +        ,("funary_" ++ tyn ++ "_strided",  [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])          ]    let generate types imports =  | 
