diff options
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r-- | candle-core/src/tensor.rs | 30 |
1 files changed, 28 insertions, 2 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index ffa4bf8c..adba7376 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -817,8 +817,34 @@ impl Tensor { Ok(from_storage(storage, out_dims, op, false)) } - pub fn conv2d(&self, _kernel: &Self, _padding: usize, _stride: usize) -> Result<Self> { - todo!() + pub fn conv2d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> { + let (b_size, c_in, i_h, i_w) = self.dims4()?; + let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?; + if c_in != c_in_k { + crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})") + } + let params = crate::conv::ParamsConv2D { + b_size, + i_h, + i_w, + k_h, + k_w, + c_out, + c_in, + padding, + stride, + }; + let storage = + self.storage() + .conv2d(self.layout(), &kernel.storage(), kernel.layout(), ¶ms)?; + let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv2D { + arg, + kernel, + padding, + stride, + }); + let out_dims = params.out_dims(); + Ok(from_storage(storage, out_dims, op, false)) } pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> { |