diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-12-31 09:06:10 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-12-31 09:06:10 +0100 |
commit | e38e2a85dd21cbb07dbca381ac3755f2b7909605 (patch) | |
tree | 8ec060b3f83c8786a83316217351d04bbcc99b87 | |
parent | 460616fc845f8b8540d00e4ef00bcc38f5cdbf0e (diff) | |
download | candle-e38e2a85dd21cbb07dbca381ac3755f2b7909605.tar.gz candle-e38e2a85dd21cbb07dbca381ac3755f2b7909605.tar.bz2 candle-e38e2a85dd21cbb07dbca381ac3755f2b7909605.zip |
Fix a cuda warning. (#2693)
-rw-r--r-- | candle-core/src/sort.rs | 83 |
1 file changed, 44 insertions, 39 deletions
diff --git a/candle-core/src/sort.rs b/candle-core/src/sort.rs index 614a37fe..0ebb1835 100644 --- a/candle-core/src/sort.rs +++ b/candle-core/src/sort.rs @@ -52,6 +52,49 @@ impl ArgSort { } } +#[cfg(feature = "cuda")] +mod cuda { + use super::*; + use crate::cuda_backend::cudarc::driver::{ + CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits, + }; + use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, WrapErr}; + use crate::{CudaDevice, WithDType}; + + impl crate::cuda_backend::Map1Any for ArgSort { + fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>( + &self, + src: &CudaSlice<T>, + dev: &CudaDevice, + layout: &crate::Layout, + _wrap: W, + ) -> Result<S> { + let slice = match layout.contiguous_offsets() { + None => crate::bail!("input has to be contiguous"), + Some((o1, o2)) => src.slice(o1..o2), + }; + let elem_count = layout.shape().elem_count(); + let dst = unsafe { dev.alloc::<u32>(elem_count) }.w()?; + let func = if self.asc { + dev.get_or_load_func(&kernel_name::<T>("asort_asc"), kernels::SORT)? + } else { + dev.get_or_load_func(&kernel_name::<T>("asort_desc"), kernels::SORT)? 
+ }; + let ncols = self.last_dim; + let nrows = elem_count / ncols; + let ncols_pad = next_power_of_2(ncols); + let params = (&slice, &dst, ncols as i32, ncols_pad as i32); + let cfg = LaunchConfig { + grid_dim: (1, nrows as u32, 1), + block_dim: (ncols_pad as u32, 1, 1), + shared_mem_bytes: (ncols_pad * std::mem::size_of::<u32>()) as u32, + }; + unsafe { func.launch(cfg, params) }.w()?; + Ok(S::U32(dst)) + } + } +} + impl crate::CustomOp1 for ArgSort { fn name(&self) -> &'static str { "argsort" @@ -81,46 +124,8 @@ impl crate::CustomOp1 for ArgSort { storage: &crate::CudaStorage, layout: &crate::Layout, ) -> Result<(crate::CudaStorage, crate::Shape)> { - use crate::cuda_backend::cudarc::driver::{ - CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits, - }; - use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, Map1Any, WrapErr}; - use crate::{CudaDevice, WithDType}; - - impl Map1Any for ArgSort { - fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>( - &self, - src: &CudaSlice<T>, - dev: &CudaDevice, - layout: &crate::Layout, - _wrap: W, - ) -> Result<S> { - let slice = match layout.contiguous_offsets() { - None => crate::bail!("input has to be contiguous"), - Some((o1, o2)) => src.slice(o1..o2), - }; - let elem_count = layout.shape().elem_count(); - let dst = unsafe { dev.alloc::<u32>(elem_count) }.w()?; - let func = if self.asc { - dev.get_or_load_func(&kernel_name::<T>("asort_asc"), kernels::SORT)? - } else { - dev.get_or_load_func(&kernel_name::<T>("asort_desc"), kernels::SORT)? 
- }; - let ncols = self.last_dim; - let nrows = elem_count / ncols; - let ncols_pad = next_power_of_2(ncols); - let params = (&slice, &dst, ncols as i32, ncols_pad as i32); - let cfg = LaunchConfig { - grid_dim: (1, nrows as u32, 1), - block_dim: (ncols_pad as u32, 1, 1), - shared_mem_bytes: (ncols_pad * std::mem::size_of::<u32>()) as u32, - }; - unsafe { func.launch(cfg, params) }.w()?; - Ok(S::U32(dst)) - } - } - use crate::backend::BackendStorage; + use crate::cuda_backend::Map1Any; let dev = storage.device(); let slice = self.map(&storage.slice, dev, layout)?; let dst = crate::cuda_backend::CudaStorage { |