summaryrefslogtreecommitdiff
path: root/candle-core/src/tensor.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-08-08 07:04:32 +0200
committerGitHub <noreply@github.com>2023-08-08 06:04:32 +0100
commitb5bb5e056d838ad23a95e3feaf464bdca2b677cd (patch)
treefdb3dda8d7143e1221cd5e844960f63250ede877 /candle-core/src/tensor.rs
parentd0d7010682a8fb0678842b1711dfc45e2269ebf5 (diff)
downloadcandle-b5bb5e056d838ad23a95e3feaf464bdca2b677cd.tar.gz
candle-b5bb5e056d838ad23a95e3feaf464bdca2b677cd.tar.bz2
candle-b5bb5e056d838ad23a95e3feaf464bdca2b677cd.zip
Add more conv2d support. (#340)
* Add more conv2d support. * Conv2d cpu work. * Conv2d output shape.
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r--candle-core/src/tensor.rs30
1 files changed, 28 insertions, 2 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index ffa4bf8c..adba7376 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -817,8 +817,34 @@ impl Tensor {
Ok(from_storage(storage, out_dims, op, false))
}
- pub fn conv2d(&self, _kernel: &Self, _padding: usize, _stride: usize) -> Result<Self> {
- todo!()
+ pub fn conv2d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
+ let (b_size, c_in, i_h, i_w) = self.dims4()?;
+ let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?;
+ if c_in != c_in_k {
+ crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
+ }
+ let params = crate::conv::ParamsConv2D {
+ b_size,
+ i_h,
+ i_w,
+ k_h,
+ k_w,
+ c_out,
+ c_in,
+ padding,
+ stride,
+ };
+ let storage =
+ self.storage()
+ .conv2d(self.layout(), &kernel.storage(), kernel.layout(), &params)?;
+ let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv2D {
+ arg,
+ kernel,
+ padding,
+ stride,
+ });
+ let out_dims = params.out_dims();
+ Ok(from_storage(storage, out_dims, op, false))
}
pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> {