Diffstat (limited to 'candle-core/src/cuda_backend/mod.rs')
-rw-r--r--  candle-core/src/cuda_backend/mod.rs  75
1 file changed, 73 insertions(+), 2 deletions(-)
diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs
index 1ea9beaf..88f325f4 100644
--- a/candle-core/src/cuda_backend/mod.rs
+++ b/candle-core/src/cuda_backend/mod.rs
@@ -630,6 +630,31 @@ impl<'a> Map2 for Conv2D<'a> {
     }
 }
 
+struct Col2Im1D {
+    stride: usize,
+}
+
+impl Map1 for Col2Im1D {
+    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+        &self,
+        col: &CudaSlice<T>,
+        dev: &CudaDevice,
+        l: &Layout,
+    ) -> Result<CudaSlice<T>> {
+        let (b_size, l_in, c_out, k_size) = l.shape().dims4()?;
+        let stride = self.stride;
+        let l_out = (l_in - 1) * stride + k_size;
+        let dst_el = b_size * c_out * l_out;
+        let mut im = unsafe { dev.alloc::<T>(dst_el) }.w()?;
+
+        let cfg = LaunchConfig::for_num_elems(dst_el as u32);
+        let params = (dst_el, l_out, l_in, c_out, k_size, stride, col, &mut im);
+        let func = dev.get_or_load_func(&kernel_name::<T>("col2im1d"), kernels::CONV)?;
+        unsafe { func.launch(cfg, params) }.w()?;
+        Ok(im)
+    }
+}
+
 struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
 impl<'a> Map2 for ConvTranspose1D<'a> {
     fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
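
The `col2im1d` kernel launched above folds a (b_size, l_in, c_out, k_size) column buffer back into a (b_size, c_out, l_out) image, accumulating wherever the strided windows overlap; judging by the launch configuration, the CUDA kernel assigns one thread per destination element and gathers the same sums. A minimal CPU sketch of that contract, assuming row-major layouts and the dilation = 1, padding = 0 conditions under which the fast path below is taken (this `col2im1d` is an illustrative stand-in, not code from the patch):

// Hypothetical CPU reference for the `col2im1d` kernel contract (not part of
// this patch). `col` holds (b_size, l_in, c_out, k_size) row-major; the result
// holds (b_size, c_out, l_out) with l_out = (l_in - 1) * stride + k_size.
fn col2im1d(
    col: &[f32],
    b_size: usize,
    l_in: usize,
    c_out: usize,
    k_size: usize,
    stride: usize,
) -> Vec<f32> {
    let l_out = (l_in - 1) * stride + k_size;
    let mut im = vec![0f32; b_size * c_out * l_out];
    for b in 0..b_size {
        for l in 0..l_in {
            for c in 0..c_out {
                for k in 0..k_size {
                    let src = ((b * l_in + l) * c_out + c) * k_size + k;
                    let dst = (b * c_out + c) * l_out + l * stride + k;
                    // Windows overlap whenever stride < k_size, so accumulate.
                    im[dst] += col[src];
                }
            }
        }
    }
    im
}

fn main() {
    // k_size = stride = 2: windows don't overlap, so with l_in = 2 the four
    // column values land verbatim in the l_out = (2 - 1) * 2 + 2 = 4 outputs.
    let col = [1.0f32, 2.0, 3.0, 4.0]; // (b_size, l_in, c_out, k_size) = (1, 2, 1, 2)
    assert_eq!(col2im1d(&col, 1, 2, 1, 2, 2), vec![1.0, 2.0, 3.0, 4.0]);
}
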
@@ -1366,9 +1391,55 @@ impl BackendStorage for CudaStorage {
         kernel_l: &Layout,
         params: &crate::conv::ParamsConvTranspose1D,
     ) -> Result<Self> {
+        const USE_COL2IM_CONV1D_TR: bool = true;
+
         let device = self.device().clone();
-        let slice =
-            ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
+        let can_use_col2im = kernel_l.is_contiguous()
+            && params.dilation == 1
+            && params.padding == 0
+            && params.output_padding == 0;
+        let slice = if USE_COL2IM_CONV1D_TR && can_use_col2im {
+            let (b_size, c_in, l_in) = l.shape().dims3()?;
+            let (c_in2, c_out, k_size) = kernel_l.shape().dims3()?;
+            if !kernel_l.is_contiguous() {
+                crate::bail!(
+                    "convtr1d: the second argument (kernel) has to be contiguous {kernel_l:?}"
+                )
+            }
+            if c_in != c_in2 {
+                crate::bail!(
+                    "convtr1d: shape mismatch on c_in {:?} {:?}",
+                    l.shape(),
+                    kernel_l.shape()
+                )
+            }
+            let col = {
+                // This merges the last two dimensions of the kernel together.
+                let kernel_l_mm = Layout::new(
+                    (b_size, c_in, k_size * c_out).into(),
+                    vec![0, k_size * c_out, 1],
+                    kernel_l.start_offset(),
+                );
+                self.matmul(
+                    kernel,
+                    (
+                        b_size,
+                        /* m */ l_in,
+                        /* n */ c_out * k_size,
+                        /* k */ c_in,
+                    ),
+                    &l.transpose(1, 2)?,
+                    &kernel_l_mm,
+                )?
+            };
+            let col_l = Layout::contiguous((b_size, l_in, c_out, k_size));
+            Col2Im1D {
+                stride: params.stride,
+            }
+            .map(&col.slice, &device, &col_l)?
+        } else {
+            ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?
+        };
         Ok(Self { slice, device })
     }
 
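
The fast path decomposes the transposed convolution into a batched matmul plus the col2im fold: `transpose(1, 2)` views the input as (b_size, l_in, c_in), `kernel_l_mm` views the kernel as a (c_in, c_out * k_size) matrix broadcast over the batch (its batch stride is 0), and the product col[b][li][co][ki] = sum_ci input[b][ci][li] * kernel[ci][co][ki] has shape (b_size, l_in, c_out, k_size), which `Col2Im1D` then accumulates into out[b][co][li * stride + ki]. Composed, the two steps match a direct transposed convolution; a naive sketch under the same assumptions (dilation = 1, no padding or output padding), with `conv_transpose1d_ref` as a hypothetical name rather than an API in the crate:

// Hypothetical naive reference (not part of the patch) for what the fast path
// computes: a transposed 1d convolution with dilation = 1 and no padding.
// input: (b_size, c_in, l_in), kernel: (c_in, c_out, k_size), both row-major.
fn conv_transpose1d_ref(
    input: &[f32],
    kernel: &[f32],
    (b_size, c_in, l_in): (usize, usize, usize),
    (c_out, k_size, stride): (usize, usize, usize),
) -> Vec<f32> {
    let l_out = (l_in - 1) * stride + k_size;
    let mut out = vec![0f32; b_size * c_out * l_out];
    for b in 0..b_size {
        for ci in 0..c_in {
            for li in 0..l_in {
                let x = input[(b * c_in + ci) * l_in + li];
                for co in 0..c_out {
                    for ki in 0..k_size {
                        // The matmul sums over ci; col2im1d then scatters each
                        // (li, ki) pair to output position li * stride + ki.
                        let w = kernel[(ci * c_out + co) * k_size + ki];
                        out[(b * c_out + co) * l_out + li * stride + ki] += x * w;
                    }
                }
            }
        }
    }
    out
}

For example, with b_size = 1, c_in = 4, c_out = 8, k_size = 3, l_in = 10 and stride = 2, the matmul produces a (1, 10, 24) buffer that is reinterpreted as (1, 10, 8, 3) and folded into a (1, 8, 21) output, since (10 - 1) * 2 + 3 = 21.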