diff options
author | Gonzalo <456459+grzuy@users.noreply.github.com> | 2023-10-05 14:46:13 -0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-05 18:46:13 +0100 |
commit | 8f7973958c55324a24f0c514e7ac6ded6681980f (patch) | |
tree | 57db1dd7d15dd128481cec0c7a0012404e1361ba /candle-kernels/src | |
parent | f0c619a4af0810200c6749f63c2474962419a84e (diff) | |
download | candle-8f7973958c55324a24f0c514e7ac6ded6681980f.tar.gz candle-8f7973958c55324a24f0c514e7ac6ded6681980f.tar.bz2 candle-8f7973958c55324a24f0c514e7ac6ded6681980f.zip |
fix: fix index_select cuda kernel for src target dim different than ids dim when selecting dim > 0 (#1037)
* fix: fix index_select cuda kernel for src target dim different than ids dim when selecting dim > 0
* cargo fmt
Diffstat (limited to 'candle-kernels/src')
-rw-r--r-- | candle-kernels/src/indexing.cu | 14 |
1 files changed, 8 insertions, 6 deletions
diff --git a/candle-kernels/src/indexing.cu b/candle-kernels/src/indexing.cu
index 0272a330..8fc69363 100644
--- a/candle-kernels/src/indexing.cu
+++ b/candle-kernels/src/indexing.cu
@@ -12,17 +12,18 @@ __device__ void index_select(
     const T *inp,
     T *out,
     const size_t left_size,
-    const size_t dim_size,
+    const size_t src_dim_size,
+    const size_t ids_dim_size,
     const size_t right_size
 ) {
     const size_t *dims = info;
     const size_t *strides = info + num_dims;
     bool b = is_contiguous(num_dims, dims, strides);
     for (unsigned int dst_i = blockIdx.x * blockDim.x + threadIdx.x; dst_i < numel; dst_i += blockDim.x * gridDim.x) {
-        unsigned int left_i = dst_i / (dim_size * right_size);
-        unsigned int id_i = dst_i / right_size % dim_size;
+        unsigned int left_i = dst_i / (ids_dim_size * right_size);
+        unsigned int id_i = dst_i / right_size % ids_dim_size;
         unsigned int right_i = dst_i % right_size;
-        unsigned int src_i = left_i * (dim_size * right_size) + ids[id_i] * right_size + right_i;
+        unsigned int src_i = left_i * (src_dim_size * right_size) + ids[id_i] * right_size + right_i;
         unsigned strided_i = b ? src_i : get_strided_index(src_i, num_dims, dims, strides);
         out[dst_i] = inp[strided_i];
     }
@@ -37,9 +38,10 @@ extern "C" __global__ void FN_NAME( \
     const TYPENAME *inp, \
     TYPENAME *out, \
     const size_t left_size, \
-    const size_t dim_size, \
+    const size_t src_dim_size, \
+    const size_t ids_dim_size, \
     const size_t right_size \
-) { index_select(numel, num_dims, info, ids, inp, out, left_size, dim_size, right_size); } \
+) { index_select(numel, num_dims, info, ids, inp, out, left_size, src_dim_size, ids_dim_size, right_size); } \
 
 template<typename T, typename I>
 __device__ void gather(