diff options
Diffstat (limited to 'candle-core/src/cuda_backend.rs')
-rw-r--r-- | candle-core/src/cuda_backend.rs | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index 1599425f..f7518067 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -891,7 +891,8 @@ impl<'a> Map1 for IndexSelect<'a> { }; let left_size: usize = src_l.dims()[..self.2].iter().product(); let right_size: usize = src_l.dims()[self.2 + 1..].iter().product(); - let dim_size = ids_shape.elem_count(); + let src_dim_size = src_l.dims()[self.2]; + let ids_dim_size = ids_shape.elem_count(); let dst_el = ids_shape.elem_count() * left_size * right_size; let cfg = LaunchConfig::for_num_elems(dst_el as u32); let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?; @@ -905,7 +906,8 @@ impl<'a> Map1 for IndexSelect<'a> { &src, &out, left_size, - dim_size, + src_dim_size, + ids_dim_size, right_size, ); // SAFETY: ffi. |