summaryrefslogtreecommitdiff
path: root/candle-core/src/tensor.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r--candle-core/src/tensor.rs30
1 files changed, 28 insertions, 2 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index ffa4bf8c..adba7376 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -817,8 +817,34 @@ impl Tensor {
Ok(from_storage(storage, out_dims, op, false))
}
- pub fn conv2d(&self, _kernel: &Self, _padding: usize, _stride: usize) -> Result<Self> {
- todo!()
+ pub fn conv2d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
+ let (b_size, c_in, i_h, i_w) = self.dims4()?;
+ let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?;
+ if c_in != c_in_k {
+ crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
+ }
+ let params = crate::conv::ParamsConv2D {
+ b_size,
+ i_h,
+ i_w,
+ k_h,
+ k_w,
+ c_out,
+ c_in,
+ padding,
+ stride,
+ };
+ let storage =
+ self.storage()
+ .conv2d(self.layout(), &kernel.storage(), kernel.layout(), &params)?;
+ let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv2D {
+ arg,
+ kernel,
+ padding,
+ stride,
+ });
+ let out_dims = params.out_dims();
+ Ok(from_storage(storage, out_dims, op, false))
}
pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> {