diff options
Diffstat (limited to 'candle-kernels/src/binary_op_macros.cuh')
-rw-r--r-- | candle-kernels/src/binary_op_macros.cuh | 24 |
1 files changed, 14 insertions, 10 deletions
diff --git a/candle-kernels/src/binary_op_macros.cuh b/candle-kernels/src/binary_op_macros.cuh index 219ee09c..05d0c3df 100644 --- a/candle-kernels/src/binary_op_macros.cuh +++ b/candle-kernels/src/binary_op_macros.cuh @@ -1,13 +1,13 @@ #include "cuda_utils.cuh" -#define BINARY_OP(TYPENAME, FN_NAME, FUNC) \ +#define BINARY_OP_OUT(TYPENAME, OUT_TYPENAME, FN_NAME, FUNC) \ extern "C" __global__ void FN_NAME( \ const size_t numel, \ const size_t num_dims, \ const size_t *dims_and_strides, \ const TYPENAME *lhs, \ const TYPENAME *rhs, \ - TYPENAME *out \ + OUT_TYPENAME *out \ ) { \ const size_t *dims = dims_and_strides; \ const size_t *lhs_strides = dims_and_strides + 1 * num_dims; \ @@ -16,8 +16,8 @@ extern "C" __global__ void FN_NAME( \ bool rhs_cont = is_contiguous(num_dims, dims, rhs_strides); \ if (lhs_cont && rhs_cont) { \ for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \ - TYPENAME x = lhs ? lhs[i] : out[i]; \ - TYPENAME y = rhs ? rhs[i] : out[i]; \ + TYPENAME x = lhs[i]; \ + TYPENAME y = rhs[i]; \ out[i] = FUNC; \ } \ } else if (lhs_cont) { \ @@ -29,8 +29,8 @@ extern "C" __global__ void FN_NAME( \ rhs_i += i_dim * rhs_strides[d]; \ tmp_i /= dims[d]; \ } \ - TYPENAME x = lhs ? lhs[i] : out[i]; \ - TYPENAME y = rhs ? rhs[rhs_i] : out[i]; \ + TYPENAME x = lhs[i]; \ + TYPENAME y = rhs[rhs_i]; \ out[i] = FUNC; \ } \ } else if (rhs_cont) { \ @@ -42,8 +42,8 @@ extern "C" __global__ void FN_NAME( \ lhs_i += i_dim * lhs_strides[d]; \ tmp_i /= dims[d]; \ } \ - TYPENAME x = lhs ? lhs[lhs_i] : out[i]; \ - TYPENAME y = rhs ? rhs[i] : out[i]; \ + TYPENAME x = lhs[lhs_i]; \ + TYPENAME y = rhs[i]; \ out[i] = FUNC; \ } \ } else { \ @@ -57,9 +57,13 @@ extern "C" __global__ void FN_NAME( \ rhs_i += i_dim * rhs_strides[d]; \ tmp_i /= dims[d]; \ } \ - TYPENAME x = lhs ? lhs[lhs_i] : out[i]; \ - TYPENAME y = rhs ? rhs[rhs_i] : out[i]; \ + TYPENAME x = lhs[lhs_i]; \ + TYPENAME y = rhs[rhs_i]; \ out[i] = FUNC; \ } \ } \ } \ + + +#define BINARY_OP(TYPENAME, FN_NAME, FUNC) \ + BINARY_OP_OUT(TYPENAME, TYPENAME, FN_NAME, FUNC) |