summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGonzalo <456459+grzuy@users.noreply.github.com>2023-10-05 14:46:13 -0300
committerGitHub <noreply@github.com>2023-10-05 18:46:13 +0100
commit8f7973958c55324a24f0c514e7ac6ded6681980f (patch)
tree57db1dd7d15dd128481cec0c7a0012404e1361ba
parentf0c619a4af0810200c6749f63c2474962419a84e (diff)
downloadcandle-8f7973958c55324a24f0c514e7ac6ded6681980f.tar.gz
candle-8f7973958c55324a24f0c514e7ac6ded6681980f.tar.bz2
candle-8f7973958c55324a24f0c514e7ac6ded6681980f.zip
fix: fix index_select cuda kernel for src target dim different than ids dim when selecting dim > 0 (#1037)
* fix: fix index_select cuda kernel for src target dim different than ids dim when selecting dim > 0 * cargo fmt
-rw-r--r--candle-core/src/cuda_backend.rs6
-rw-r--r--candle-core/tests/tensor_tests.rs9
-rw-r--r--candle-kernels/src/indexing.cu14
3 files changed, 21 insertions, 8 deletions
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 1599425f..f7518067 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -891,7 +891,8 @@ impl<'a> Map1 for IndexSelect<'a> {
};
let left_size: usize = src_l.dims()[..self.2].iter().product();
let right_size: usize = src_l.dims()[self.2 + 1..].iter().product();
- let dim_size = ids_shape.elem_count();
+ let src_dim_size = src_l.dims()[self.2];
+ let ids_dim_size = ids_shape.elem_count();
let dst_el = ids_shape.elem_count() * left_size * right_size;
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;
@@ -905,7 +906,8 @@ impl<'a> Map1 for IndexSelect<'a> {
&src,
&out,
left_size,
- dim_size,
+ src_dim_size,
+ ids_dim_size,
right_size,
);
// SAFETY: ffi.
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index 2e867b26..a50f3a6c 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -680,6 +680,15 @@ fn index_select(device: &Device) -> Result<()> {
[3.0, 4.0, 5.0],
]
);
+
+ // Test when selecting dim > 0 with ids size different from elem count of
+ // target dim in source/input.
+ let ids = Tensor::new(&[1u32, 0u32, 1u32], device)?;
+ let t = Tensor::arange(1f32, 5f32, device)?.reshape((2, 2))?;
+ assert_eq!(t.to_vec2::<f32>()?, &[[1.0, 2.0], [3.0, 4.0]]);
+ let hs = t.index_select(&ids, 1)?;
+ assert_eq!(hs.to_vec2::<f32>()?, &[[2.0, 1.0, 2.0], [4.0, 3.0, 4.0]]);
+
Ok(())
}
diff --git a/candle-kernels/src/indexing.cu b/candle-kernels/src/indexing.cu
index 0272a330..8fc69363 100644
--- a/candle-kernels/src/indexing.cu
+++ b/candle-kernels/src/indexing.cu
@@ -12,17 +12,18 @@ __device__ void index_select(
const T *inp,
T *out,
const size_t left_size,
- const size_t dim_size,
+ const size_t src_dim_size,
+ const size_t ids_dim_size,
const size_t right_size
) {
const size_t *dims = info;
const size_t *strides = info + num_dims;
bool b = is_contiguous(num_dims, dims, strides);
for (unsigned int dst_i = blockIdx.x * blockDim.x + threadIdx.x; dst_i < numel; dst_i += blockDim.x * gridDim.x) {
- unsigned int left_i = dst_i / (dim_size * right_size);
- unsigned int id_i = dst_i / right_size % dim_size;
+ unsigned int left_i = dst_i / (ids_dim_size * right_size);
+ unsigned int id_i = dst_i / right_size % ids_dim_size;
unsigned int right_i = dst_i % right_size;
- unsigned int src_i = left_i * (dim_size * right_size) + ids[id_i] * right_size + right_i;
+ unsigned int src_i = left_i * (src_dim_size * right_size) + ids[id_i] * right_size + right_i;
unsigned strided_i = b ? src_i : get_strided_index(src_i, num_dims, dims, strides);
out[dst_i] = inp[strided_i];
}
@@ -37,9 +38,10 @@ extern "C" __global__ void FN_NAME( \
const TYPENAME *inp, \
TYPENAME *out, \
const size_t left_size, \
- const size_t dim_size, \
+ const size_t src_dim_size, \
+ const size_t ids_dim_size, \
const size_t right_size \
-) { index_select(numel, num_dims, info, ids, inp, out, left_size, dim_size, right_size); } \
+) { index_select(numel, num_dims, info, ids, inp, out, left_size, src_dim_size, ids_dim_size, right_size); } \
template<typename T, typename I>
__device__ void gather(