diff options
Diffstat (limited to 'candle-core/src/cuda_backend.rs')
-rw-r--r-- | candle-core/src/cuda_backend.rs | 15 |
1 files changed, 13 insertions, 2 deletions
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index ed696368..cd06e8d7 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -960,7 +960,9 @@ impl<'a> Map2 for Conv1D<'a> { crate::bail!("unexpected input shape for conv1d {dims:?}") }; let ds = dev.htod_copy(ds).w()?; - let params = (el, l_out, p.stride, p.padding, &ds, inp, k, &out); + let params = ( + el, l_out, p.stride, p.padding, p.dilation, &ds, inp, k, &out, + ); // SAFETY: ffi. unsafe { func.launch(cfg, params) }.w()?; Ok(out) @@ -998,7 +1000,9 @@ impl<'a> Map2 for Conv2D<'a> { crate::bail!("unexpected input shape for conv2d {dims:?}") }; let ds = dev.htod_copy(ds).w()?; - let params = (el, out_w, out_h, p.stride, p.padding, &ds, inp, k, &out); + let params = ( + el, out_w, out_h, p.stride, p.padding, p.dilation, &ds, inp, k, &out, + ); // SAFETY: ffi. unsafe { func.launch(cfg, params) }.w()?; Ok(out) @@ -1018,6 +1022,12 @@ impl<'a> Map2 for ConvTranspose2D<'a> { // Kernel shape: (c_in_k, c_out, h_k, w_k) // Input shape: (b_size, c_in, h_in, w_in) let p = &self.0; + if p.dilation != 1 { + crate::bail!( + "dilation {} is not supported for conv-transpose2d", + p.dilation + ) + } let (out_w, out_h) = (p.out_w(), p.out_h()); let dst_el = p.c_out * out_w * out_h * p.b_size; let inp = &inp.slice(inp_l.start_offset()..); @@ -1043,6 +1053,7 @@ impl<'a> Map2 for ConvTranspose2D<'a> { p.stride, p.padding, p.output_padding, + p.dilation, &ds, inp, k, |