diff options
Diffstat (limited to 'candle-core/src/cuda_backend.rs')
-rw-r--r-- | candle-core/src/cuda_backend.rs | 62 |
1 files changed, 57 insertions, 5 deletions
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 25d73f9b..b7756fa6 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -1149,6 +1149,55 @@ impl<'a> Map2 for Conv2D<'a> {
     }
 }
 
+struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
+impl<'a> Map2 for ConvTranspose1D<'a> {
+    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+        &self,
+        inp: &CudaSlice<T>,
+        inp_l: &Layout,
+        k: &CudaSlice<T>,
+        k_l: &Layout,
+        dev: &CudaDevice,
+    ) -> Result<CudaSlice<T>> {
+        // Kernel shape: (c_in_k, c_out, l_k)
+        // Input shape: (b_size, c_in, l_in)
+        let p = &self.0;
+        let l_out = p.l_out();
+        let dst_el = p.c_out * l_out * p.b_size;
+        let inp = &inp.slice(inp_l.start_offset()..);
+        let k = &k.slice(k_l.start_offset()..);
+        let shape = inp_l.shape();
+        let dims = shape.dims();
+        let el = shape.elem_count();
+
+        // SAFETY: Set later by running the kernel.
+        let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
+        let cfg = LaunchConfig::for_num_elems(dst_el as u32);
+        let func = dev.get_or_load_func(&kernel_name::<T>("conv_transpose1d"), kernels::CONV)?;
+        let ds = if dims.len() == 3 {
+            [dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat()
+        } else {
+            crate::bail!("unexpected input shape for conv_transpose1d {dims:?}")
+        };
+        let ds = dev.htod_copy(ds).w()?;
+        let params = (
+            el,
+            l_out,
+            p.stride,
+            p.padding,
+            p.output_padding,
+            p.dilation,
+            &ds,
+            inp,
+            k,
+            &out,
+        );
+        // SAFETY: ffi.
+        unsafe { func.launch(cfg, params) }.w()?;
+        Ok(out)
+    }
+}
+
 struct ConvTranspose2D<'a>(&'a crate::conv::ParamsConvTranspose2D);
 impl<'a> Map2 for ConvTranspose2D<'a> {
     fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
@@ -1810,12 +1859,15 @@ impl BackendStorage for CudaStorage {
 
     fn conv_transpose1d(
         &self,
-        _: &Layout,
-        _: &Self,
-        _: &Layout,
-        _: &crate::conv::ParamsConvTranspose1D,
+        l: &Layout,
+        kernel: &Self,
+        kernel_l: &Layout,
+        params: &crate::conv::ParamsConvTranspose1D,
     ) -> Result<Self> {
-        todo!()
+        let device = self.device().clone();
+        let slice =
+            ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
+        Ok(Self { slice, device })
     }
 
     #[cfg(not(feature = "cudnn"))]