author    laurent <laurent.mazare@gmail.com>    2023-06-22 08:33:32 +0100
committer laurent <laurent.mazare@gmail.com>    2023-06-22 08:33:32 +0100
commit    0a758ffa0523629336e7224fa181dd1e76d8919c
tree      8338a53820ea5a124f97ea9b1fca91788f0f4e4f
parent    fc26bab3ede511c3c4d2f1afb15f58eb6c588c94
Add the fill kernel and use it for 'ones'.
Diffstat (limited to 'src')
-rw-r--r--  src/cuda_backend.rs       | 51
-rw-r--r--  src/device.rs             |  5
-rw-r--r--  src/dummy_cuda_backend.rs |  4
3 files changed, 56 insertions(+), 4 deletions(-)
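This commit adds a templated grid-stride fill kernel, JIT-compiles it with NVRTC (via cudarc) the first time a given dtype needs it, and switches the CUDA path of Device::ones from "build the ones on the host, then copy" to an on-device fill. For orientation, the host-side pattern that the new const_impl follows is sketched below as a standalone function; the function name, the f32-only kernel source, and the final copy back to the host are illustrative additions, not part of the patch, and the sketch assumes the same cudarc driver/nvrtc API that the diff itself uses.

use cudarc::driver::{CudaDevice, LaunchAsync, LaunchConfig};

// Trimmed, f32-only version of the FILL_CU source added by this commit.
const FILL_F32_CU: &str = r#"
extern "C" __global__ void fill_f32(float *buf, float value, const size_t numel) {
    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
        buf[i] = value;
    }
}
"#;

fn fill_f32_standalone(value: f32, numel: usize) -> Vec<f32> {
    let dev = CudaDevice::new(0).expect("no CUDA device 0");
    // Compile and load the PTX only if this device has not seen the kernel yet;
    // later calls reuse the cached module, exactly as const_impl does below.
    if !dev.has_func("fill_f32", "fill_f32") {
        let ptx = cudarc::nvrtc::compile_ptx(FILL_F32_CU).expect("nvrtc compilation failed");
        dev.load_ptx(ptx, "fill_f32", &["fill_f32"]).expect("loading PTX failed");
    }
    let func = dev.get_func("fill_f32", "fill_f32").unwrap();
    // SAFETY: every element is written by the kernel launch below.
    let buf = unsafe { dev.alloc::<f32>(numel) }.expect("allocation failed");
    let cfg = LaunchConfig::for_num_elems(numel as u32);
    // The parameter tuple mirrors the kernel signature (float *buf, float value, size_t numel).
    unsafe { func.launch(cfg, (&buf, value, numel)) }.expect("kernel launch failed");
    // Copy back to the host only to make the sketch self-contained.
    dev.dtoh_sync_copy(&buf).expect("device-to-host copy failed")
}

LaunchConfig::for_num_elems builds a one-dimensional launch configuration from the element count, and the grid-stride loop in fill_with keeps the kernel correct for any grid size, so no per-shape tuning is needed.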
diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs
index 7858e542..5706a2e6 100644
--- a/src/cuda_backend.rs
+++ b/src/cuda_backend.rs
@@ -23,6 +23,18 @@ extern "C" __global__ void affine_f32(
 }
 "#;
 
+const FILL_CU: &str = r#"
+template<typename T>
+__device__ void fill_with(T *buf, T value, const size_t numel) {
+    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
+        buf[i] = value;
+    }
+}
+extern "C" __global__ void fill_f16(__half *buf, __half value, const size_t numel) { fill_with(buf, value, numel); }
+extern "C" __global__ void fill_f32(float *buf, float value, const size_t numel) { fill_with(buf, value, numel); }
+extern "C" __global__ void fill_f64(double *buf, double value, const size_t numel) { fill_with(buf, value, numel); }
+"#;
+
 impl CudaDevice {
     pub(crate) fn new(ordinal: usize) -> Result<Self> {
         let device = cudarc::driver::CudaDevice::new(ordinal)?;
@@ -47,6 +59,45 @@ impl CudaDevice {
         }
     }
 
+    pub(crate) fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
+        let elem_count = shape.elem_count();
+        let dev = &self.0;
+        match dtype {
+            DType::F32 => {
+                // SAFETY: Set later by running the fill kernel.
+                let data = unsafe { dev.alloc::<f32>(elem_count) }?;
+                let module_name = "fill_f32";
+                if !dev.has_func(module_name, module_name) {
+                    let ptx = cudarc::nvrtc::compile_ptx(FILL_CU).unwrap();
+                    dev.load_ptx(ptx, module_name, &[module_name])?;
+                }
+                let fwd_fn = dev.get_func(module_name, module_name).unwrap();
+                let cfg = LaunchConfig::for_num_elems(elem_count as u32);
+                let params = (&data, v as f32, elem_count);
+                unsafe { fwd_fn.launch(cfg, params) }?;
+                Ok(CudaStorage::F32(data))
+            }
+            DType::F64 => {
+                // SAFETY: Set later by running the fill kernel.
+                let data = unsafe { dev.alloc::<f64>(elem_count) }?;
+                let module_name = "fill_f64";
+                if !dev.has_func(module_name, module_name) {
+                    let ptx = cudarc::nvrtc::compile_ptx(FILL_CU).unwrap();
+                    dev.load_ptx(ptx, module_name, &[module_name])?;
+                }
+                let fwd_fn = dev.get_func(module_name, module_name).unwrap();
+                let cfg = LaunchConfig::for_num_elems(elem_count as u32);
+                let params = (&data, v, elem_count);
+                unsafe { fwd_fn.launch(cfg, params) }?;
+                Ok(CudaStorage::F64(data))
+            }
+        }
+    }
+
+    pub(crate) fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
+        self.const_impl(1., shape, dtype)
+    }
+
     pub(crate) fn cuda_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
         match storage {
             CpuStorage::F32(storage) => {
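Two notes on the hunk above. FILL_CU also compiles a fill_f16 entry point that const_impl does not dispatch to yet, and the DType::F32 and DType::F64 arms are identical except for the element type and kernel name. A hypothetical way to share that body is sketched here; it is not part of the patch, the helper name launch_fill and its bounds are mine, and it assumes it would live in cuda_backend.rs next to FILL_CU.

use std::sync::Arc;
use cudarc::driver::{CudaDevice, CudaSlice, DeviceRepr, DriverError, LaunchAsync, LaunchConfig};

// Allocates an uninitialized buffer and fills it with `value` using one of the
// fill_* kernels compiled from FILL_CU; `kernel_name` doubles as the module name,
// matching the convention used in const_impl.
fn launch_fill<T: DeviceRepr>(
    dev: &Arc<CudaDevice>,
    kernel_name: &'static str, // "fill_f32" or "fill_f64"
    value: T,
    elem_count: usize,
) -> Result<CudaSlice<T>, DriverError> {
    // SAFETY: the buffer is fully written by the fill kernel below.
    let data = unsafe { dev.alloc::<T>(elem_count) }?;
    if !dev.has_func(kernel_name, kernel_name) {
        let ptx = cudarc::nvrtc::compile_ptx(FILL_CU).unwrap();
        dev.load_ptx(ptx, kernel_name, &[kernel_name])?;
    }
    let func = dev.get_func(kernel_name, kernel_name).unwrap();
    let cfg = LaunchConfig::for_num_elems(elem_count as u32);
    unsafe { func.launch(cfg, (&data, value, elem_count)) }?;
    Ok(data)
}

With such a helper, const_impl would shrink to a per-dtype match that picks the kernel name, casts v, and wraps the returned slice in the matching CudaStorage variant.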
diff --git a/src/device.rs b/src/device.rs
index e522cd42..ab7bad26 100644
--- a/src/device.rs
+++ b/src/device.rs
@@ -82,10 +82,7 @@ impl Device {
                 Ok(Storage::Cpu(storage))
             }
             Device::Cuda(device) => {
-                // TODO: Instead of allocating memory on the host and transfering it,
-                // allocate some zeros on the device and use a shader to set them to 1.
-                let storage = CpuStorage::ones_impl(shape, dtype);
-                let storage = device.cuda_from_cpu_storage(&storage)?;
+                let storage = device.ones_impl(shape, dtype)?;
                 Ok(Storage::Cuda(storage))
             }
         }
diff --git a/src/dummy_cuda_backend.rs b/src/dummy_cuda_backend.rs
index 85b5f598..2eb393c1 100644
--- a/src/dummy_cuda_backend.rs
+++ b/src/dummy_cuda_backend.rs
@@ -25,6 +25,10 @@ impl CudaDevice {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
+    pub(crate) fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<CudaStorage> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
     pub(crate) fn cuda_from_cpu_storage(&self, _: &CpuStorage) -> Result<CudaStorage> {
         Err(Error::NotCompiledWithCudaSupport)
     }
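The stub above mirrors the signature added to the real CUDA backend so that builds without CUDA support still type-check; every new pub(crate) method on CudaDevice needs a counterpart here. The usual wiring that makes this work is cfg-based module selection along the following lines; this is only a sketch of the common pattern, since the actual lib.rs wiring is not part of this diff.

// In lib.rs (sketch, not shown in this diff): pick the real backend when the
// `cuda` feature is enabled, and the dummy one otherwise, under the same path.
#[cfg(feature = "cuda")]
mod cuda_backend;
#[cfg(not(feature = "cuda"))]
mod dummy_cuda_backend;
#[cfg(not(feature = "cuda"))]
use dummy_cuda_backend as cuda_backend;

Either way, callers such as device.rs reach CudaDevice::ones_impl through the same path, so the patch keeps both configurations compiling.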