-rw-r--r--   candle-core/src/cpu_backend/mod.rs  |  4
-rw-r--r--   candle-core/src/cuda_backend/mod.rs | 75
-rw-r--r--   candle-kernels/src/conv.cu          | 65
3 files changed, 140 insertions(+), 4 deletions(-)
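This change makes conv-transpose1d on CUDA run as a batched matmul followed by a new col2im1d kernel, mirroring the im2col path already used for the forward convolutions. The fast path only triggers when dilation == 1, padding == 0 and output_padding == 0; under those assumptions the general transposed-conv output length l_out = (l_in - 1) * stride - 2 * padding + dilation * (k_size - 1) + output_padding + 1 collapses to the l_out = (l_in - 1) * stride + k_size used throughout the diff. The cpu_backend hunk is just the matching rename of the USE_IM2COL_CONV1D_TR toggle to USE_COL2IM_CONV1D_TR.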
diff --git a/candle-core/src/cpu_backend/mod.rs b/candle-core/src/cpu_backend/mod.rs
index 299b1e6e..18b73e9b 100644
--- a/candle-core/src/cpu_backend/mod.rs
+++ b/candle-core/src/cpu_backend/mod.rs
@@ -10,7 +10,7 @@ pub use utils::{
 };
 
 const USE_IM2COL_CONV1D: bool = true;
-const USE_IM2COL_CONV1D_TR: bool = true;
+const USE_COL2IM_CONV1D_TR: bool = true;
 const USE_IM2COL_CONV2D: bool = true;
 
 // TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
@@ -2249,7 +2249,7 @@ impl BackendStorage for CpuStorage {
             && params.dilation == 1
             && params.padding == 0
             && params.output_padding == 0;
-        if USE_IM2COL_CONV1D_TR && can_use_col2im {
+        if USE_COL2IM_CONV1D_TR && can_use_col2im {
             let (b_size, c_in, l_in) = l.shape().dims3()?;
             let (c_in2, c_out, k_size) = kernel_l.shape().dims3()?;
             if !kernel_l.is_contiguous() {
diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs
index 1ea9beaf..88f325f4 100644
--- a/candle-core/src/cuda_backend/mod.rs
+++ b/candle-core/src/cuda_backend/mod.rs
@@ -630,6 +630,31 @@ impl<'a> Map2 for Conv2D<'a> {
     }
 }
 
+struct Col2Im1D {
+    stride: usize,
+}
+
+impl Map1 for Col2Im1D {
+    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+        &self,
+        col: &CudaSlice<T>,
+        dev: &CudaDevice,
+        l: &Layout,
+    ) -> Result<CudaSlice<T>> {
+        let (b_size, l_in, c_out, k_size) = l.shape().dims4()?;
+        let stride = self.stride;
+        let l_out = (l_in - 1) * stride + k_size;
+        let dst_el = b_size * c_out * l_out;
+        let mut im = unsafe { dev.alloc::<T>(dst_el) }.w()?;
+
+        let cfg = LaunchConfig::for_num_elems(dst_el as u32);
+        let params = (dst_el, l_out, l_in, c_out, k_size, stride, col, &mut im);
+        let func = dev.get_or_load_func(&kernel_name::<T>("col2im1d"), kernels::CONV)?;
+        unsafe { func.launch(cfg, params) }.w()?;
+        Ok(im)
+    }
+}
+
 struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
 impl<'a> Map2 for ConvTranspose1D<'a> {
     fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
@@ -1366,9 +1391,55 @@ impl BackendStorage for CudaStorage {
         kernel_l: &Layout,
         params: &crate::conv::ParamsConvTranspose1D,
     ) -> Result<Self> {
+        const USE_COL2IM_CONV1D_TR: bool = true;
+
         let device = self.device().clone();
-        let slice =
-            ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
+        let can_use_col2im = kernel_l.is_contiguous()
+            && params.dilation == 1
+            && params.padding == 0
+            && params.output_padding == 0;
+        let slice = if USE_COL2IM_CONV1D_TR && can_use_col2im {
+            let (b_size, c_in, l_in) = l.shape().dims3()?;
+            let (c_in2, c_out, k_size) = kernel_l.shape().dims3()?;
+            if !kernel_l.is_contiguous() {
+                crate::bail!(
+                    "convtr1d: the second argument (kernel) has to be contiguous {kernel_l:?}"
+                )
+            }
+            if c_in != c_in2 {
+                crate::bail!(
+                    "convtr1d: shape mismatch on c_in {:?} {:?}",
+                    l.shape(),
+                    kernel_l.shape()
+                )
+            }
+            let col = {
+                // This merges the last two dimensions of the kernel together.
+                let kernel_l_mm = Layout::new(
+                    (b_size, c_in, k_size * c_out).into(),
+                    vec![0, k_size * c_out, 1],
+                    kernel_l.start_offset(),
+                );
+                self.matmul(
+                    kernel,
+                    (
+                        b_size,
+                        /* m */ l_in,
+                        /* n */ c_out * k_size,
+                        /* k */ c_in,
+                    ),
+                    &l.transpose(1, 2)?,
+                    &kernel_l_mm,
+                )?
+            };
+            let col_l = Layout::contiguous((b_size, l_in, c_out, k_size));
+            Col2Im1D {
+                stride: params.stride,
+            }
+            .map(&col.slice, &device, &col_l)?
+        } else {
+            ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?
+        };
         Ok(Self { slice, device })
     }
 
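To see what Col2Im1D computes once the matmul has produced the (b_size, l_in, c_out, k_size) col buffer, here is a minimal plain-Rust reference sketch. It is not part of the commit; the name col2im1d_ref and the f32 specialization are illustrative only.

// Scatter-style reference for the col2im1d kernel: the col element at
// (b, l, c, k) is accumulated into the output at (b, c, l * stride + k).
fn col2im1d_ref(
    col: &[f32],
    b_size: usize,
    l_in: usize,
    c_out: usize,
    k_size: usize,
    stride: usize,
) -> Vec<f32> {
    let l_out = (l_in - 1) * stride + k_size;
    let mut im = vec![0f32; b_size * c_out * l_out];
    for b in 0..b_size {
        for l in 0..l_in {
            for c in 0..c_out {
                for k in 0..k_size {
                    // Contiguous (b, l, c, k) source index, (b, c, l*stride+k) target.
                    let src_i = ((b * l_in + l) * c_out + c) * k_size + k;
                    let dst_i = (b * c_out + c) * l_out + l * stride + k;
                    im[dst_i] += col[src_i];
                }
            }
        }
    }
    im
}

On the GPU this scatter form would race whenever stride < k_size, since several (l, k) pairs then map to the same output position; that is why the CUDA kernel below inverts the mapping and gathers one output element per thread.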
diff --git a/candle-kernels/src/conv.cu b/candle-kernels/src/conv.cu
index fed920f1..fa834faa 100644
--- a/candle-kernels/src/conv.cu
+++ b/candle-kernels/src/conv.cu
@@ -98,6 +98,50 @@ __device__ void im2col1d(
 }
 
 template <typename T>
+__device__ void col2im1d(
+    const size_t dst_el,
+    const size_t l_out,
+    const size_t l_in,
+    const size_t c_out,
+    const size_t k_size,
+    const size_t stride,
+    const T *src,
+    T *dst
+) {
+    const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
+    // src: (b_size, l_in, c_out, l_k)
+    // dst: (b_size, c_out, l_out)
+    if (dst_i >= dst_el) {
+        return;
+    }
+
+    const size_t dst_s0 = c_out * l_out;
+    const size_t dst_s1 = l_out;
+    const size_t src_s0 = c_out * k_size * l_in;
+    const size_t src_s1 = c_out * k_size;
+    const size_t src_s2 = k_size;
+
+    size_t tmp_dst_i = dst_i;
+    const size_t b_idx = tmp_dst_i / dst_s0;
+    tmp_dst_i -= b_idx * dst_s0;
+    const size_t c_idx = tmp_dst_i / dst_s1;
+    tmp_dst_i -= c_idx * dst_s1;
+    const int l_out_idx = tmp_dst_i;
+
+    dst[dst_i] = static_cast<T>(0);
+
+    int l_in_idx = l_out_idx / stride;
+    int k0 = l_out_idx - l_in_idx * stride;
+    // l_out_idx = l_in_idx * stride + k0
+    for (; k0 < k_size && l_in_idx >= 0; k0 += stride, --l_in_idx) {
+        if (l_in_idx < l_in) {
+            const size_t src_i = b_idx * src_s0 + l_in_idx * src_s1 + c_idx * src_s2 + k0;
+            dst[dst_i] += src[src_i];
+        }
+    }
+}
+
+template <typename T>
 __device__ void im2col(
     const size_t dst_numel,
     const size_t h_out,
@@ -542,6 +586,20 @@ extern "C" __global__ void FN_NAME( \
     im2col1d<TYPENAME>(dst_numel, l_out, l_k, stride, padding, dilation, info, src, dst); \
 } \
 
+#define COL2IM1D_OP(TYPENAME, FN_NAME) \
+extern "C" __global__ void FN_NAME( \
+    const size_t dst_el, \
+    const size_t l_out, \
+    const size_t l_in, \
+    const size_t c_out, \
+    const size_t k_size, \
+    const size_t stride, \
+    const TYPENAME *src, \
+    TYPENAME *dst \
+) { \
+    col2im1d<TYPENAME>(dst_el, l_out, l_in, c_out, k_size, stride, src, dst); \
+} \
+
 #define IM2COL_OP(TYPENAME, FN_NAME) \
 extern "C" __global__ void FN_NAME( \
     const size_t dst_numel, \
@@ -643,6 +701,7 @@ MAX_POOL2D_OP(__nv_bfloat16, max_pool2d_bf16)
 UPSAMPLE_NEAREST2D_OP(__nv_bfloat16, upsample_nearest2d_bf16)
 IM2COL_OP(__nv_bfloat16, im2col_bf16)
 IM2COL1D_OP(__nv_bfloat16, im2col1d_bf16)
+COL2IM1D_OP(__nv_bfloat16, col2im1d_bf16)
 #endif
 
 #if __CUDA_ARCH__ >= 530
@@ -655,6 +714,7 @@ MAX_POOL2D_OP(__half, max_pool2d_f16)
 UPSAMPLE_NEAREST2D_OP(__half, upsample_nearest2d_f16)
 IM2COL_OP(__half, im2col_f16)
 IM2COL1D_OP(__half, im2col1d_f16)
+COL2IM1D_OP(__half, col2im1d_f16)
 #endif
 
 CONV1D_OP(float, float, conv1d_f32)
@@ -701,3 +761,8 @@ IM2COL1D_OP(float, im2col1d_f32)
 IM2COL1D_OP(double, im2col1d_f64)
 IM2COL1D_OP(uint8_t, im2col1d_u8)
 IM2COL1D_OP(uint32_t, im2col1d_u32)
+
+COL2IM1D_OP(float, col2im1d_f32)
+COL2IM1D_OP(double, col2im1d_f64)
+COL2IM1D_OP(uint8_t, col2im1d_u8)
+COL2IM1D_OP(uint32_t, col2im1d_u32)
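The subtle part of the kernel is the gather loop: for a fixed l_out_idx it has to visit exactly the (l_in_idx, k0) pairs with l_out_idx = l_in_idx * stride + k0 and 0 <= k0 < k_size. Below is a standalone Rust sketch checking that enumeration against brute force; the helper names are hypothetical and not part of the commit.

// Mirrors the kernel's `for (; k0 < k_size && l_in_idx >= 0; k0 += stride, --l_in_idx)`
// walk: start from the largest candidate l_in_idx and slide down the diagonal.
fn kernel_walk(l_out_idx: usize, stride: usize, k_size: usize, l_in: usize) -> Vec<(usize, usize)> {
    let mut pairs = Vec::new();
    let mut l_in_idx = l_out_idx / stride;
    let mut k0 = l_out_idx - l_in_idx * stride;
    while k0 < k_size {
        if l_in_idx < l_in {
            pairs.push((l_in_idx, k0));
        }
        if l_in_idx == 0 {
            break; // the kernel stops once l_in_idx would go negative
        }
        l_in_idx -= 1;
        k0 += stride;
    }
    pairs
}

fn main() {
    let (stride, k_size, l_in) = (2usize, 3usize, 5usize);
    let l_out = (l_in - 1) * stride + k_size;
    for l_out_idx in 0..l_out {
        // Brute force: every (l, k) whose scatter target is l_out_idx.
        let mut expected: Vec<(usize, usize)> = (0..l_in)
            .flat_map(|l| (0..k_size).map(move |k| (l, k)))
            .filter(|&(l, k)| l * stride + k == l_out_idx)
            .collect();
        expected.sort();
        let mut got = kernel_walk(l_out_idx, stride, k_size, l_in);
        got.sort();
        assert_eq!(got, expected);
    }
    println!("gather enumeration matches the scatter targets for all {l_out} output positions");
}

The descending walk also explains why the kernel uses int rather than size_t for l_in_idx and k0: the loop condition l_in_idx >= 0 relies on a signed decrement to terminate cleanly.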