summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-12-31 09:06:10 +0100
committerGitHub <noreply@github.com>2024-12-31 09:06:10 +0100
commite38e2a85dd21cbb07dbca381ac3755f2b7909605 (patch)
tree8ec060b3f83c8786a83316217351d04bbcc99b87
parent460616fc845f8b8540d00e4ef00bcc38f5cdbf0e (diff)
downloadcandle-e38e2a85dd21cbb07dbca381ac3755f2b7909605.tar.gz
candle-e38e2a85dd21cbb07dbca381ac3755f2b7909605.tar.bz2
candle-e38e2a85dd21cbb07dbca381ac3755f2b7909605.zip
Fix a cuda warning. (#2693)
-rw-r--r--candle-core/src/sort.rs83
1 files changed, 44 insertions, 39 deletions
diff --git a/candle-core/src/sort.rs b/candle-core/src/sort.rs
index 614a37fe..0ebb1835 100644
--- a/candle-core/src/sort.rs
+++ b/candle-core/src/sort.rs
@@ -52,6 +52,49 @@ impl ArgSort {
}
}
+#[cfg(feature = "cuda")]
+mod cuda {
+ use super::*;
+ use crate::cuda_backend::cudarc::driver::{
+ CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
+ };
+ use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, WrapErr};
+ use crate::{CudaDevice, WithDType};
+
+ impl crate::cuda_backend::Map1Any for ArgSort {
+ fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
+ &self,
+ src: &CudaSlice<T>,
+ dev: &CudaDevice,
+ layout: &crate::Layout,
+ _wrap: W,
+ ) -> Result<S> {
+ let slice = match layout.contiguous_offsets() {
+ None => crate::bail!("input has to be contiguous"),
+ Some((o1, o2)) => src.slice(o1..o2),
+ };
+ let elem_count = layout.shape().elem_count();
+ let dst = unsafe { dev.alloc::<u32>(elem_count) }.w()?;
+ let func = if self.asc {
+ dev.get_or_load_func(&kernel_name::<T>("asort_asc"), kernels::SORT)?
+ } else {
+ dev.get_or_load_func(&kernel_name::<T>("asort_desc"), kernels::SORT)?
+ };
+ let ncols = self.last_dim;
+ let nrows = elem_count / ncols;
+ let ncols_pad = next_power_of_2(ncols);
+ let params = (&slice, &dst, ncols as i32, ncols_pad as i32);
+ let cfg = LaunchConfig {
+ grid_dim: (1, nrows as u32, 1),
+ block_dim: (ncols_pad as u32, 1, 1),
+ shared_mem_bytes: (ncols_pad * std::mem::size_of::<u32>()) as u32,
+ };
+ unsafe { func.launch(cfg, params) }.w()?;
+ Ok(S::U32(dst))
+ }
+ }
+}
+
impl crate::CustomOp1 for ArgSort {
fn name(&self) -> &'static str {
"argsort"
@@ -81,46 +124,8 @@ impl crate::CustomOp1 for ArgSort {
storage: &crate::CudaStorage,
layout: &crate::Layout,
) -> Result<(crate::CudaStorage, crate::Shape)> {
- use crate::cuda_backend::cudarc::driver::{
- CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
- };
- use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, Map1Any, WrapErr};
- use crate::{CudaDevice, WithDType};
-
- impl Map1Any for ArgSort {
- fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
- &self,
- src: &CudaSlice<T>,
- dev: &CudaDevice,
- layout: &crate::Layout,
- _wrap: W,
- ) -> Result<S> {
- let slice = match layout.contiguous_offsets() {
- None => crate::bail!("input has to be contiguous"),
- Some((o1, o2)) => src.slice(o1..o2),
- };
- let elem_count = layout.shape().elem_count();
- let dst = unsafe { dev.alloc::<u32>(elem_count) }.w()?;
- let func = if self.asc {
- dev.get_or_load_func(&kernel_name::<T>("asort_asc"), kernels::SORT)?
- } else {
- dev.get_or_load_func(&kernel_name::<T>("asort_desc"), kernels::SORT)?
- };
- let ncols = self.last_dim;
- let nrows = elem_count / ncols;
- let ncols_pad = next_power_of_2(ncols);
- let params = (&slice, &dst, ncols as i32, ncols_pad as i32);
- let cfg = LaunchConfig {
- grid_dim: (1, nrows as u32, 1),
- block_dim: (ncols_pad as u32, 1, 1),
- shared_mem_bytes: (ncols_pad * std::mem::size_of::<u32>()) as u32,
- };
- unsafe { func.launch(cfg, params) }.w()?;
- Ok(S::U32(dst))
- }
- }
-
use crate::backend::BackendStorage;
+ use crate::cuda_backend::Map1Any;
let dev = storage.device();
let slice = self.map(&storage.slice, dev, layout)?;
let dst = crate::cuda_backend::CudaStorage {