diff options
Diffstat (limited to 'candle-kernels/src')
-rw-r--r-- | candle-kernels/src/indexing.cu | 14 |
1 files changed, 8 insertions, 6 deletions
diff --git a/candle-kernels/src/indexing.cu b/candle-kernels/src/indexing.cu index 0272a330..8fc69363 100644 --- a/candle-kernels/src/indexing.cu +++ b/candle-kernels/src/indexing.cu @@ -12,17 +12,18 @@ __device__ void index_select( const T *inp, T *out, const size_t left_size, - const size_t dim_size, + const size_t src_dim_size, + const size_t ids_dim_size, const size_t right_size ) { const size_t *dims = info; const size_t *strides = info + num_dims; bool b = is_contiguous(num_dims, dims, strides); for (unsigned int dst_i = blockIdx.x * blockDim.x + threadIdx.x; dst_i < numel; dst_i += blockDim.x * gridDim.x) { - unsigned int left_i = dst_i / (dim_size * right_size); - unsigned int id_i = dst_i / right_size % dim_size; + unsigned int left_i = dst_i / (ids_dim_size * right_size); + unsigned int id_i = dst_i / right_size % ids_dim_size; unsigned int right_i = dst_i % right_size; - unsigned int src_i = left_i * (dim_size * right_size) + ids[id_i] * right_size + right_i; + unsigned int src_i = left_i * (src_dim_size * right_size) + ids[id_i] * right_size + right_i; unsigned strided_i = b ? src_i : get_strided_index(src_i, num_dims, dims, strides); out[dst_i] = inp[strided_i]; } @@ -37,9 +38,10 @@ extern "C" __global__ void FN_NAME( \ const TYPENAME *inp, \ TYPENAME *out, \ const size_t left_size, \ - const size_t dim_size, \ + const size_t src_dim_size, \ + const size_t ids_dim_size, \ const size_t right_size \ -) { index_select(numel, num_dims, info, ids, inp, out, left_size, dim_size, right_size); } \ +) { index_select(numel, num_dims, info, ids, inp, out, left_size, src_dim_size, ids_dim_size, right_size); } \ template<typename T, typename I> __device__ void gather( |