summaryrefslogtreecommitdiff
path: root/candle-core/src/cuda_backend.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/cuda_backend.rs')
-rw-r--r--candle-core/src/cuda_backend.rs62
1 files changed, 57 insertions, 5 deletions
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 25d73f9b..b7756fa6 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -1149,6 +1149,55 @@ impl<'a> Map2 for Conv2D<'a> {
}
}
+struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
+impl<'a> Map2 for ConvTranspose1D<'a> {
+ fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+ &self,
+ inp: &CudaSlice<T>,
+ inp_l: &Layout,
+ k: &CudaSlice<T>,
+ k_l: &Layout,
+ dev: &CudaDevice,
+ ) -> Result<CudaSlice<T>> {
+ // Kernel shape: (c_in_k, c_out, l_k)
+ // Input shape: (b_size, c_in, l_in)
+ let p = &self.0;
+ let l_out = p.l_out();
+ let dst_el = p.c_out * l_out * p.b_size;
+ let inp = &inp.slice(inp_l.start_offset()..);
+ let k = &k.slice(k_l.start_offset()..);
+ let shape = inp_l.shape();
+ let dims = shape.dims();
+ let el = shape.elem_count();
+
+ // SAFETY: Set later by running the kernel.
+ let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
+ let cfg = LaunchConfig::for_num_elems(dst_el as u32);
+ let func = dev.get_or_load_func(&kernel_name::<T>("conv_transpose1d"), kernels::CONV)?;
+ let ds = if dims.len() == 3 {
+ [dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat()
+ } else {
+ crate::bail!("unexpected input shape for conv_transpose1d {dims:?}")
+ };
+ let ds = dev.htod_copy(ds).w()?;
+ let params = (
+ el,
+ l_out,
+ p.stride,
+ p.padding,
+ p.output_padding,
+ p.dilation,
+ &ds,
+ inp,
+ k,
+ &out,
+ );
+ // SAFETY: ffi.
+ unsafe { func.launch(cfg, params) }.w()?;
+ Ok(out)
+ }
+}
+
struct ConvTranspose2D<'a>(&'a crate::conv::ParamsConvTranspose2D);
impl<'a> Map2 for ConvTranspose2D<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
@@ -1810,12 +1859,15 @@ impl BackendStorage for CudaStorage {
fn conv_transpose1d(
&self,
- _: &Layout,
- _: &Self,
- _: &Layout,
- _: &crate::conv::ParamsConvTranspose1D,
+ l: &Layout,
+ kernel: &Self,
+ kernel_l: &Layout,
+ params: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> {
- todo!()
+ let device = self.device().clone();
+ let slice =
+ ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
+ Ok(Self { slice, device })
}
#[cfg(not(feature = "cudnn"))]