Diffstat (limited to 'src')
-rw-r--r--  src/cuda_backend.rs | 58
1 file changed, 7 insertions(+), 51 deletions(-)
diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs
index 8234eeac..aa9a55f1 100644
--- a/src/cuda_backend.rs
+++ b/src/cuda_backend.rs
@@ -1,4 +1,5 @@
 use crate::{CpuStorage, DType, Shape};
+use candle_kernels as kernels;
 use cudarc::driver::{CudaFunction, CudaSlice, LaunchAsync, LaunchConfig};
 
 /// cudarc related errors
@@ -22,49 +23,6 @@ type Result<T> = std::result::Result<T, CudaError>;
 #[derive(Debug, Clone)]
 pub struct CudaDevice(std::sync::Arc<cudarc::driver::CudaDevice>);
 
-// TODO: Switch to pre-compiled PTX kernels rather than compiling on the fly.
-const AFFINE_CU: &str = r#"
-extern "C" __global__ void affine_f32(
-    const size_t numel,
-    const float *x,
-    float *y,
-    const float mul,
-    const float add
-) {
-    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
-    if (i >= numel) {
-        return;
-    }
-    y[i] = x[i] * mul + add;
-}
-
-extern "C" __global__ void affine_f64(
-    const size_t numel,
-    const double *x,
-    double *y,
-    const double mul,
-    const double add
-) {
-    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
-    if (i >= numel) {
-        return;
-    }
-    y[i] = x[i] * mul + add;
-}
-"#;
-
-const FILL_CU: &str = r#"
-template<typename T>
-__device__ void fill_with(T *buf, T value, const size_t numel) {
-    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
-        buf[i] = value;
-    }
-}
-extern "C" __global__ void fill_f16(__half *buf, __half value, const size_t numel) { fill_with(buf, value, numel); }
-extern "C" __global__ void fill_f32(float *buf, float value, const size_t numel) { fill_with(buf, value, numel); }
-extern "C" __global__ void fill_f64(double *buf, double value, const size_t numel) { fill_with(buf, value, numel); }
-"#;
-
 impl CudaDevice {
     pub(crate) fn new(ordinal: usize) -> Result<Self> {
         let device = cudarc::driver::CudaDevice::new(ordinal)?;
@@ -97,7 +55,7 @@ impl CudaDevice {
             DType::F32 => {
                 // SAFETY: Set later by running the fill kernel.
                 let data = unsafe { dev.alloc::<f32>(elem_count) }?;
-                let func = self.get_or_load_func("fill_f32", FILL_CU)?;
+                let func = self.get_or_load_func("fill_f32", kernels::FILL)?;
                 let params = (&data, v as f32, elem_count);
                 unsafe { func.launch(cfg, params) }?;
                 Ok(CudaStorage::F32(data))
@@ -105,7 +63,7 @@ impl CudaDevice {
             DType::F64 => {
                 // SAFETY: Set later by running the fill kernel.
                 let data = unsafe { dev.alloc::<f64>(elem_count) }?;
-                let func = self.get_or_load_func("fill_f64", FILL_CU)?;
+                let func = self.get_or_load_func("fill_f64", kernels::FILL)?;
                 let params = (&data, v, elem_count);
                 unsafe { func.launch(cfg, params) }?;
                 Ok(CudaStorage::F64(data))
@@ -133,13 +91,11 @@ impl CudaDevice {
     fn get_or_load_func(
         &self,
         module_name: &'static str,
-        source: &'static str,
+        ptx: &'static str,
     ) -> Result<CudaFunction> {
         let dev = &self.0;
         if !dev.has_func(module_name, module_name) {
-            // TODO: Pre-compile and load rather than compiling here.
-            let ptx = cudarc::nvrtc::compile_ptx(source)?;
-            dev.load_ptx(ptx, module_name, &[module_name])?;
+            dev.load_ptx(ptx.into(), module_name, &[module_name])?;
         }
         dev.get_func(module_name, module_name)
             // Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is
@@ -185,7 +141,7 @@ impl CudaStorage {
         let dev = self.device();
         match self {
             Self::F32(arg) => {
-                let func = dev.get_or_load_func("affine_f32", AFFINE_CU)?;
+                let func = dev.get_or_load_func("affine_f32", kernels::AFFINE)?;
                 // SAFETY: if this function returns Ok(..), the kernel has been applied
                 // and has set the initially unset memory.
                 let out = unsafe { dev.0.alloc::<f32>(elem_count) }?;
@@ -195,7 +151,7 @@ impl CudaStorage {
                 Ok(Self::F32(out))
             }
             Self::F64(arg) => {
-                let func = dev.get_or_load_func("affine_f64", AFFINE_CU)?;
+                let func = dev.get_or_load_func("affine_f64", kernels::AFFINE)?;
                 // SAFETY: if this function returns Ok(..), the kernel has been applied
                 // and has set the initially unset memory.
                 let out = unsafe { dev.0.alloc::<f64>(elem_count) }?;
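The kernels::FILL and kernels::AFFINE constants referenced above are &'static str PTX blobs exported by the new candle_kernels crate, so the runtime NVRTC compilation step (and its TODO) goes away; the ptx.into() call converts the embedded string into the Ptx value that cudarc's load_ptx expects. A rough sketch of how such a kernels crate can be put together, assuming a build script that shells out to nvcc and kernel sources named src/affine.cu and src/fill.cu (the actual candle-kernels build may differ):

// build.rs: compile each .cu file to PTX at build time.
use std::{env, path::PathBuf, process::Command};

fn main() {
    let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
    for name in ["affine", "fill"] {
        let src = format!("src/{name}.cu");
        let ptx = out_dir.join(format!("{name}.ptx"));
        // `nvcc --ptx` stops after PTX generation instead of producing a binary.
        let status = Command::new("nvcc")
            .arg("--ptx")
            .arg(&src)
            .arg("-o")
            .arg(&ptx)
            .status()
            .expect("failed to spawn nvcc");
        assert!(status.success(), "nvcc failed on {src}");
        println!("cargo:rerun-if-changed={src}");
    }
}

// lib.rs: embed the generated PTX as &'static str constants.
pub const AFFINE: &str = include_str!(concat!(env!("OUT_DIR"), "/affine.ptx"));
pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));

Compiling once at build time trades NVRTC's per-process compilation cost for a build-time dependency on the CUDA toolchain; the PTX is still JIT-compiled to machine code by the driver on first load, so it remains portable across GPU architectures.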