summaryrefslogtreecommitdiff
path: root/src/cuda_backend.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/cuda_backend.rs')
-rw-r--r--src/cuda_backend.rs51
1 files changed, 51 insertions, 0 deletions
diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs
index 7858e542..5706a2e6 100644
--- a/src/cuda_backend.rs
+++ b/src/cuda_backend.rs
@@ -23,6 +23,18 @@ extern "C" __global__ void affine_f32(
}
"#;
+const FILL_CU: &str = r#"
+template<typename T>
+__device__ void fill_with(T *buf, T value, const size_t numel) {
+ for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
+ buf[i] = value;
+ }
+}
+extern "C" __global__ void fill_f16(__half *buf, __half value, const size_t numel) { fill_with(buf, value, numel); }
+extern "C" __global__ void fill_f32(float *buf, float value, const size_t numel) { fill_with(buf, value, numel); }
+extern "C" __global__ void fill_f64(double *buf, double value, const size_t numel) { fill_with(buf, value, numel); }
+"#;
+
impl CudaDevice {
pub(crate) fn new(ordinal: usize) -> Result<Self> {
let device = cudarc::driver::CudaDevice::new(ordinal)?;
@@ -47,6 +59,45 @@ impl CudaDevice {
}
}
+ pub(crate) fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
+ let elem_count = shape.elem_count();
+ let dev = &self.0;
+ match dtype {
+ DType::F32 => {
+ // SAFETY: Set later by running the fill kernel.
+ let data = unsafe { dev.alloc::<f32>(elem_count) }?;
+ let module_name = "fill_f32";
+ if !dev.has_func(module_name, module_name) {
+ let ptx = cudarc::nvrtc::compile_ptx(FILL_CU).unwrap();
+ dev.load_ptx(ptx, module_name, &[module_name])?;
+ }
+ let fwd_fn = dev.get_func(module_name, module_name).unwrap();
+ let cfg = LaunchConfig::for_num_elems(elem_count as u32);
+ let params = (&data, v as f32, elem_count);
+ unsafe { fwd_fn.launch(cfg, params) }?;
+ Ok(CudaStorage::F32(data))
+ }
+ DType::F64 => {
+ // SAFETY: Set later by running the fill kernel.
+ let data = unsafe { dev.alloc::<f64>(elem_count) }?;
+ let module_name = "fill_f64";
+ if !dev.has_func(module_name, module_name) {
+ let ptx = cudarc::nvrtc::compile_ptx(FILL_CU).unwrap();
+ dev.load_ptx(ptx, module_name, &[module_name])?;
+ }
+ let fwd_fn = dev.get_func(module_name, module_name).unwrap();
+ let cfg = LaunchConfig::for_num_elems(elem_count as u32);
+ let params = (&data, v, elem_count);
+ unsafe { fwd_fn.launch(cfg, params) }?;
+ Ok(CudaStorage::F64(data))
+ }
+ }
+ }
+
+ pub(crate) fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
+ self.const_impl(1., shape, dtype)
+ }
+
pub(crate) fn cuda_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
match storage {
CpuStorage::F32(storage) => {