summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/cuda_backend.rs58
1 files changed, 7 insertions, 51 deletions
diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs
index 8234eeac..aa9a55f1 100644
--- a/src/cuda_backend.rs
+++ b/src/cuda_backend.rs
@@ -1,4 +1,5 @@
use crate::{CpuStorage, DType, Shape};
+use candle_kernels as kernels;
use cudarc::driver::{CudaFunction, CudaSlice, LaunchAsync, LaunchConfig};
/// cudarc related errors
@@ -22,49 +23,6 @@ type Result<T> = std::result::Result<T, CudaError>;
#[derive(Debug, Clone)]
pub struct CudaDevice(std::sync::Arc<cudarc::driver::CudaDevice>);
-// TODO: Switch to pre-compiled PTX kernels rather than compiling on the fly.
-const AFFINE_CU: &str = r#"
-extern "C" __global__ void affine_f32(
- const size_t numel,
- const float *x,
- float *y,
- const float mul,
- const float add
-) {
- unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
- if (i >= numel) {
- return;
- }
- y[i] = x[i] * mul + add;
-}
-
-extern "C" __global__ void affine_f64(
- const size_t numel,
- const double *x,
- double *y,
- const double mul,
- const double add
-) {
- unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
- if (i >= numel) {
- return;
- }
- y[i] = x[i] * mul + add;
-}
-"#;
-
-const FILL_CU: &str = r#"
-template<typename T>
-__device__ void fill_with(T *buf, T value, const size_t numel) {
- for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
- buf[i] = value;
- }
-}
-extern "C" __global__ void fill_f16(__half *buf, __half value, const size_t numel) { fill_with(buf, value, numel); }
-extern "C" __global__ void fill_f32(float *buf, float value, const size_t numel) { fill_with(buf, value, numel); }
-extern "C" __global__ void fill_f64(double *buf, double value, const size_t numel) { fill_with(buf, value, numel); }
-"#;
-
impl CudaDevice {
pub(crate) fn new(ordinal: usize) -> Result<Self> {
let device = cudarc::driver::CudaDevice::new(ordinal)?;
@@ -97,7 +55,7 @@ impl CudaDevice {
DType::F32 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { dev.alloc::<f32>(elem_count) }?;
- let func = self.get_or_load_func("fill_f32", FILL_CU)?;
+ let func = self.get_or_load_func("fill_f32", kernels::FILL)?;
let params = (&data, v as f32, elem_count);
unsafe { func.launch(cfg, params) }?;
Ok(CudaStorage::F32(data))
@@ -105,7 +63,7 @@ impl CudaDevice {
DType::F64 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { dev.alloc::<f64>(elem_count) }?;
- let func = self.get_or_load_func("fill_f64", FILL_CU)?;
+ let func = self.get_or_load_func("fill_f64", kernels::FILL)?;
let params = (&data, v, elem_count);
unsafe { func.launch(cfg, params) }?;
Ok(CudaStorage::F64(data))
@@ -133,13 +91,11 @@ impl CudaDevice {
fn get_or_load_func(
&self,
module_name: &'static str,
- source: &'static str,
+ ptx: &'static str,
) -> Result<CudaFunction> {
let dev = &self.0;
if !dev.has_func(module_name, module_name) {
- // TODO: Pre-compile and load rather than compiling here.
- let ptx = cudarc::nvrtc::compile_ptx(source)?;
- dev.load_ptx(ptx, module_name, &[module_name])?;
+ dev.load_ptx(ptx.into(), module_name, &[module_name])?;
}
dev.get_func(module_name, module_name)
// Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is
@@ -185,7 +141,7 @@ impl CudaStorage {
let dev = self.device();
match self {
Self::F32(arg) => {
- let func = dev.get_or_load_func("affine_f32", AFFINE_CU)?;
+ let func = dev.get_or_load_func("affine_f32", kernels::AFFINE)?;
// SAFETY: if this function returns Ok(..), the kernel has been applied
// and has set the initially unset memory.
let out = unsafe { dev.0.alloc::<f32>(elem_count) }?;
@@ -195,7 +151,7 @@ impl CudaStorage {
Ok(Self::F32(out))
}
Self::F64(arg) => {
- let func = dev.get_or_load_func("affine_f64", AFFINE_CU)?;
+ let func = dev.get_or_load_func("affine_f64", kernels::AFFINE)?;
// SAFETY: if this function returns Ok(..), the kernel has been applied
// and has set the initially unset memory.
let out = unsafe { dev.0.alloc::<f64>(elem_count) }?;