diff options
Diffstat (limited to 'candle-core/src/cuda_backend.rs')
-rw-r--r-- | candle-core/src/cuda_backend.rs | 223 |
1 files changed, 207 insertions, 16 deletions
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index 663f2319..00fd1d04 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -1,7 +1,7 @@ use crate::backend::{BackendDevice, BackendStorage}; use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; use crate::{CpuStorage, DType, Layout, Result, Shape, WithDType}; -use candle_kernels as kernels; +pub use candle_kernels as kernels; pub use cudarc; use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig}; use cudarc::driver::{ @@ -312,6 +312,13 @@ impl BackendDevice for CudaDevice { // cudarc changes. let elem_count = shape.elem_count(); let curand = self.curand.lock().unwrap(); + // curand can only generate an odd number of values. + // https://github.com/huggingface/candle/issues/734 + let elem_count_round = if elem_count % 2 == 1 { + elem_count + 1 + } else { + elem_count + }; let slice = match dtype { DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => { Err(CudaError::UnsupportedDtype { @@ -321,7 +328,7 @@ impl BackendDevice for CudaDevice { .w()? } DType::F32 => { - let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?; + let mut data = unsafe { self.alloc::<f32>(elem_count_round) }.w()?; curand .0 .fill_with_normal(&mut data, mean as f32, std as f32) @@ -329,7 +336,7 @@ impl BackendDevice for CudaDevice { CudaStorageSlice::F32(data) } DType::F64 => { - let mut data = unsafe { self.alloc::<f64>(elem_count) }.w()?; + let mut data = unsafe { self.alloc::<f64>(elem_count_round) }.w()?; curand.0.fill_with_normal(&mut data, mean, std).w()?; CudaStorageSlice::F64(data) } @@ -383,7 +390,7 @@ impl BackendDevice for CudaDevice { } #[derive(Debug)] -enum CudaStorageSlice { +pub enum CudaStorageSlice { U8(CudaSlice<u8>), U32(CudaSlice<u32>), I64(CudaSlice<i64>), @@ -394,7 +401,7 @@ enum CudaStorageSlice { } type S = CudaStorageSlice; -trait Map1 { +pub trait Map1 { fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>( &self, src: &CudaSlice<T>, @@ -416,7 +423,7 @@ trait Map1 { } } -trait Map2 { +pub trait Map2 { fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>( &self, src1: &CudaSlice<T>, @@ -441,7 +448,7 @@ trait Map2 { } } -trait Map2InPlace { +pub trait Map2InPlace { fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>( &self, dst: &mut CudaSlice<T>, @@ -472,7 +479,7 @@ trait Map2InPlace { } } -trait Map1Any { +pub trait Map1Any { fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>( &self, src: &CudaSlice<T>, @@ -495,7 +502,7 @@ trait Map1Any { } } -trait Map2Any { +pub trait Map2Any { fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>( &self, src1: &CudaSlice<T>, @@ -532,7 +539,7 @@ impl Map1 for Clone { } } -fn kernel_name<T: WithDType>(root: &str) -> String { +pub fn kernel_name<T: WithDType>(root: &str) -> String { let dtype = T::DTYPE.as_str(); format!("{root}_{dtype}") } @@ -593,6 +600,105 @@ impl Map1 for Elu { } } +struct Im2Col1D { + l_k: usize, + stride: usize, + dilation: usize, + padding: usize, +} + +impl Im2Col1D { + fn l_out(&self, l: usize) -> usize { + (l + 2 * self.padding - self.dilation * (self.l_k - 1) - 1) / self.stride + 1 + } +} + +impl Map1 for Im2Col1D { + fn f<T: DeviceRepr + WithDType>( + &self, + src: &CudaSlice<T>, + dev: &CudaDevice, + layout: &Layout, + ) -> Result<CudaSlice<T>> { + let shape = layout.shape(); + let dims = shape.dims(); + let l_out = self.l_out(dims[2]); + let dst_el = dims[0] * l_out * dims[1] * self.l_k; + let cfg = LaunchConfig::for_num_elems(dst_el as u32); + let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?; + let src = &src.slice(layout.start_offset()..); + let func = dev.get_or_load_func(&kernel_name::<T>("im2col1d"), kernels::CONV)?; + // SAFETY: Set later by running the kernel. + let dst = unsafe { dev.alloc::<T>(dst_el) }.w()?; + let params = ( + dst_el, + l_out, + self.l_k, + self.stride, + self.padding, + self.dilation, + &ds, + src, + &dst, + ); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }.w()?; + Ok(dst) + } +} + +struct Im2Col { + h_k: usize, + w_k: usize, + stride: usize, + dilation: usize, + padding: usize, +} + +impl Im2Col { + fn hw_out(&self, h: usize, w: usize) -> (usize, usize) { + let h_out = (h + 2 * self.padding - self.dilation * (self.h_k - 1) - 1) / self.stride + 1; + let w_out = (w + 2 * self.padding - self.dilation * (self.w_k - 1) - 1) / self.stride + 1; + (h_out, w_out) + } +} + +impl Map1 for Im2Col { + fn f<T: DeviceRepr + WithDType>( + &self, + src: &CudaSlice<T>, + dev: &CudaDevice, + layout: &Layout, + ) -> Result<CudaSlice<T>> { + let shape = layout.shape(); + let dims = shape.dims(); + let (h_out, w_out) = self.hw_out(dims[2], dims[3]); + let dst_el = dims[0] * h_out * w_out * dims[1] * self.h_k * self.w_k; + let cfg = LaunchConfig::for_num_elems(dst_el as u32); + let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?; + let src = &src.slice(layout.start_offset()..); + let func = dev.get_or_load_func(&kernel_name::<T>("im2col"), kernels::CONV)?; + // SAFETY: Set later by running the kernel. + let dst = unsafe { dev.alloc::<T>(dst_el) }.w()?; + let params = ( + dst_el, + h_out, + w_out, + self.h_k, + self.w_k, + self.stride, + self.padding, + self.dilation, + &ds, + src, + &dst, + ); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }.w()?; + Ok(dst) + } +} + struct Powf(f64); impl Map1 for Powf { fn f<T: DeviceRepr + WithDType>( @@ -1310,8 +1416,8 @@ fn slice_src_and_dst<'a, T>( #[derive(Debug)] pub struct CudaStorage { - slice: CudaStorageSlice, - device: CudaDevice, + pub slice: CudaStorageSlice, + pub device: CudaDevice, } pub trait CudaDType: Sized { @@ -1650,9 +1756,46 @@ impl BackendStorage for CudaStorage { kernel_l: &Layout, params: &crate::conv::ParamsConv1D, ) -> Result<Self> { + const USE_IM2COL_CONV1D: bool = true; + let device = self.device().clone(); - let slice = Conv1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?; - Ok(Self { slice, device }) + if !USE_IM2COL_CONV1D { + let slice = Conv1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?; + return Ok(Self { slice, device }); + } + + let col = Im2Col1D { + l_k: params.k_size, + stride: params.stride, + dilation: params.dilation, + padding: params.padding, + } + .map(&self.slice, &device, l)?; + let col = Self { slice: col, device }; + let l_out = params.l_out(); + let b = params.b_size; + let n = params.c_out; + let k = params.k_size * params.c_in; + let m = l_out; + let col_l = Layout::contiguous((b, m, k)); + let res = if kernel_l.is_contiguous() { + let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) + .transpose(1, 2)? + .broadcast_as((b, k, n))?; + col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? + } else { + // Make the kernel contiguous if not already the case. + let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?; + kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?; + let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) + .transpose(1, 2)? + .broadcast_as((b, k, n))?; + col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? + }; + let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?; + let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?; + res.copy_strided_src(&mut res_t, 0, &res_l)?; + Ok(res_t) } #[cfg(not(feature = "cudnn"))] @@ -1663,9 +1806,50 @@ impl BackendStorage for CudaStorage { kernel_l: &Layout, params: &crate::conv::ParamsConv2D, ) -> Result<Self> { + const USE_IM2COL_CONV2D: bool = true; + let device = self.device().clone(); - let slice = Conv2D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?; - Ok(Self { slice, device }) + if !USE_IM2COL_CONV2D { + let slice = Conv2D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?; + return Ok(Self { slice, device }); + } + + let col = Im2Col { + h_k: params.k_h, + w_k: params.k_w, + stride: params.stride, + dilation: params.dilation, + padding: params.padding, + } + .map(&self.slice, &device, l)?; + let col = Self { slice: col, device }; + let h_out = params.out_h(); + let w_out = params.out_w(); + let b = params.b_size; + let n = params.c_out; + let k = params.k_h * params.k_w * params.c_in; + let m = h_out * w_out; + let col_l = Layout::contiguous((b, m, k)); + let res = if kernel_l.is_contiguous() { + let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) + .transpose(1, 2)? + .broadcast_as((b, k, n))?; + col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? + } else { + // Make the kernel contiguous if not already the case. + let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?; + kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?; + let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) + .transpose(1, 2)? + .broadcast_as((b, k, n))?; + col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? + }; + let res_l = Layout::contiguous((b, h_out, w_out, n)) + .transpose(1, 2)? + .transpose(1, 3)?; + let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?; + res.copy_strided_src(&mut res_t, 0, &res_l)?; + Ok(res_t) } #[cfg(feature = "cudnn")] @@ -1770,6 +1954,10 @@ impl BackendStorage for CudaStorage { Ok(Self { slice, device }) } + fn upsample_nearest1d(&self, _: &Layout, _out_sz: usize) -> Result<Self> { + crate::bail!("upsample-nearest1d is not supported on cuda") + } + fn upsample_nearest2d(&self, l: &Layout, out_w: usize, out_h: usize) -> Result<Self> { let device = self.device().clone(); let slice = UpsampleNearest2D(out_w, out_h).map(&self.slice, &device, l)?; @@ -1889,6 +2077,9 @@ impl BackendStorage for CudaStorage { let src_shape = src_l.shape(); let dims = src_shape.dims(); let el_count = src_shape.elem_count(); + if el_count == 0 { + return Ok(()); + } let cfg = LaunchConfig::for_num_elems(el_count as u32); let dev = &self.device; let ds = dev.htod_copy([dims, src_l.stride()].concat()).w()?; |