diff options
Diffstat (limited to 'candle-core/src/cuda_backend/mod.rs')
-rw-r--r-- | candle-core/src/cuda_backend/mod.rs | 75 |
1 files changed, 73 insertions, 2 deletions
diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs index 1ea9beaf..88f325f4 100644 --- a/candle-core/src/cuda_backend/mod.rs +++ b/candle-core/src/cuda_backend/mod.rs @@ -630,6 +630,31 @@ impl<'a> Map2 for Conv2D<'a> { } } +struct Col2Im1D { + stride: usize, +} + +impl Map1 for Col2Im1D { + fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>( + &self, + col: &CudaSlice<T>, + dev: &CudaDevice, + l: &Layout, + ) -> Result<CudaSlice<T>> { + let (b_size, l_in, c_out, k_size) = l.shape().dims4()?; + let stride = self.stride; + let l_out = (l_in - 1) * stride + k_size; + let dst_el = b_size * c_out * l_out; + let mut im = unsafe { dev.alloc::<T>(dst_el) }.w()?; + + let cfg = LaunchConfig::for_num_elems(dst_el as u32); + let params = (dst_el, l_out, l_in, c_out, k_size, stride, col, &mut im); + let func = dev.get_or_load_func(&kernel_name::<T>("col2im1d"), kernels::CONV)?; + unsafe { func.launch(cfg, params) }.w()?; + Ok(im) + } +} + struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D); impl<'a> Map2 for ConvTranspose1D<'a> { fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>( @@ -1366,9 +1391,55 @@ impl BackendStorage for CudaStorage { kernel_l: &Layout, params: &crate::conv::ParamsConvTranspose1D, ) -> Result<Self> { + const USE_COL2IM_CONV1D_TR: bool = true; + let device = self.device().clone(); - let slice = - ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?; + let can_use_col2im = kernel_l.is_contiguous() + && params.dilation == 1 + && params.padding == 0 + && params.output_padding == 0; + let slice = if USE_COL2IM_CONV1D_TR && can_use_col2im { + let (b_size, c_in, l_in) = l.shape().dims3()?; + let (c_in2, c_out, k_size) = kernel_l.shape().dims3()?; + if !kernel_l.is_contiguous() { + crate::bail!( + "convtr1d: the second argument (kernel) has to be contiguous {kernel_l:?}" + ) + } + if c_in != c_in2 { + crate::bail!( + "convtr1d: shape mismatch on c_in {:?} {:?}", + l.shape(), + kernel_l.shape() + ) + } + let col = { + // This merges the last two dimensions of the kernel together. + let kernel_l_mm = Layout::new( + (b_size, c_in, k_size * c_out).into(), + vec![0, k_size * c_out, 1], + kernel_l.start_offset(), + ); + self.matmul( + kernel, + ( + b_size, + /* m */ l_in, + /* n */ c_out * k_size, + /* k */ c_in, + ), + &l.transpose(1, 2)?, + &kernel_l_mm, + )? + }; + let col_l = Layout::contiguous((b_size, l_in, c_out, k_size)); + Col2Im1D { + stride: params.stride, + } + .map(&col.slice, &device, &col_l)? + } else { + ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)? + }; Ok(Self { slice, device }) } |