summaryrefslogtreecommitdiff
path: root/src/cuda_backend.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/cuda_backend.rs')
-rw-r--r--src/cuda_backend.rs109
1 files changed, 74 insertions, 35 deletions
diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs
index 5706a2e6..8234eeac 100644
--- a/src/cuda_backend.rs
+++ b/src/cuda_backend.rs
@@ -1,7 +1,23 @@
-use crate::{CpuStorage, DType, Error, Result, Shape};
-use cudarc::driver::{CudaSlice, LaunchAsync, LaunchConfig};
+use crate::{CpuStorage, DType, Shape};
+use cudarc::driver::{CudaFunction, CudaSlice, LaunchAsync, LaunchConfig};
-pub type CudaError = cudarc::driver::DriverError;
+/// cudarc related errors
+#[derive(thiserror::Error, Debug)]
+pub enum CudaError {
+ #[error(transparent)]
+ Cuda(#[from] cudarc::driver::DriverError),
+
+ #[error(transparent)]
+ Compiler(#[from] cudarc::nvrtc::CompileError),
+
+ #[error("{op} only supports contiguous tensors")]
+ RequiresContiguous { op: &'static str },
+
+ #[error("missing kernel '{module_name}'")]
+ MissingKernel { module_name: &'static str },
+}
+
+type Result<T> = std::result::Result<T, CudaError>;
#[derive(Debug, Clone)]
pub struct CudaDevice(std::sync::Arc<cudarc::driver::CudaDevice>);
@@ -21,6 +37,20 @@ extern "C" __global__ void affine_f32(
}
y[i] = x[i] * mul + add;
}
+
+extern "C" __global__ void affine_f64(
+ const size_t numel,
+ const double *x,
+ double *y,
+ const double mul,
+ const double add
+) {
+ unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
+ if (i >= numel) {
+ return;
+ }
+ y[i] = x[i] * mul + add;
+}
"#;
const FILL_CU: &str = r#"
@@ -61,34 +91,23 @@ impl CudaDevice {
pub(crate) fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
let elem_count = shape.elem_count();
+ let cfg = LaunchConfig::for_num_elems(elem_count as u32);
let dev = &self.0;
match dtype {
DType::F32 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { dev.alloc::<f32>(elem_count) }?;
- let module_name = "fill_f32";
- if !dev.has_func(module_name, module_name) {
- let ptx = cudarc::nvrtc::compile_ptx(FILL_CU).unwrap();
- dev.load_ptx(ptx, module_name, &[module_name])?;
- }
- let fwd_fn = dev.get_func(module_name, module_name).unwrap();
- let cfg = LaunchConfig::for_num_elems(elem_count as u32);
+ let func = self.get_or_load_func("fill_f32", FILL_CU)?;
let params = (&data, v as f32, elem_count);
- unsafe { fwd_fn.launch(cfg, params) }?;
+ unsafe { func.launch(cfg, params) }?;
Ok(CudaStorage::F32(data))
}
DType::F64 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { dev.alloc::<f64>(elem_count) }?;
- let module_name = "fill_f64";
- if !dev.has_func(module_name, module_name) {
- let ptx = cudarc::nvrtc::compile_ptx(FILL_CU).unwrap();
- dev.load_ptx(ptx, module_name, &[module_name])?;
- }
- let fwd_fn = dev.get_func(module_name, module_name).unwrap();
- let cfg = LaunchConfig::for_num_elems(elem_count as u32);
+ let func = self.get_or_load_func("fill_f64", FILL_CU)?;
let params = (&data, v, elem_count);
- unsafe { fwd_fn.launch(cfg, params) }?;
+ unsafe { func.launch(cfg, params) }?;
Ok(CudaStorage::F64(data))
}
}
@@ -110,6 +129,23 @@ impl CudaDevice {
}
}
}
+
+ fn get_or_load_func(
+ &self,
+ module_name: &'static str,
+ source: &'static str,
+ ) -> Result<CudaFunction> {
+ let dev = &self.0;
+ if !dev.has_func(module_name, module_name) {
+ // TODO: Pre-compile and load rather than compiling here.
+ let ptx = cudarc::nvrtc::compile_ptx(source)?;
+ dev.load_ptx(ptx, module_name, &[module_name])?;
+ }
+ dev.get_func(module_name, module_name)
+ // Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is
+ // able to only build the error value if needed.
+ .ok_or(CudaError::MissingKernel { module_name })
+ }
}
#[derive(Debug, Clone)]
@@ -140,30 +176,33 @@ impl CudaStorage {
mul: f64,
add: f64,
) -> Result<Self> {
+ if !shape.is_contiguous(stride) {
+ return Err(CudaError::RequiresContiguous { op: "affine" });
+ }
+
+ let elem_count = shape.elem_count();
+ let cfg = LaunchConfig::for_num_elems(elem_count as u32);
+ let dev = self.device();
match self {
Self::F32(arg) => {
- if !shape.is_contiguous(stride) {
- return Err(Error::RequiresContiguous { op: "affine" });
- }
- let dev = arg.device();
- let module_name = "affine_f32";
- if !dev.has_func(module_name, module_name) {
- let ptx = cudarc::nvrtc::compile_ptx(AFFINE_CU).unwrap();
- dev.load_ptx(ptx, module_name, &[module_name])?;
- }
- let elem_count = shape.elem_count();
- let fwd_fn = dev.get_func(module_name, module_name).unwrap();
- let cfg = LaunchConfig::for_num_elems(elem_count as u32);
+ let func = dev.get_or_load_func("affine_f32", AFFINE_CU)?;
// SAFETY: if this function returns Ok(..), the kernel has been applied
// and has set the initially unset memory.
- let out = unsafe { dev.alloc::<f32>(elem_count) }?;
+ let out = unsafe { dev.0.alloc::<f32>(elem_count) }?;
let params = (elem_count, arg, &out, mul as f32, add as f32);
// SAFETY: well, well, well...
- unsafe { fwd_fn.launch(cfg, params) }?;
+ unsafe { func.launch(cfg, params) }?;
Ok(Self::F32(out))
}
- Self::F64(_) => {
- todo!()
+ Self::F64(arg) => {
+ let func = dev.get_or_load_func("affine_f64", AFFINE_CU)?;
+ // SAFETY: if this function returns Ok(..), the kernel has been applied
+ // and has set the initially unset memory.
+ let out = unsafe { dev.0.alloc::<f64>(elem_count) }?;
+ let params = (elem_count, arg, &out, mul, add);
+ // SAFETY: well, well, well...
+ unsafe { func.launch(cfg, params) }?;
+ Ok(Self::F64(out))
}
}
}