diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-07 17:15:38 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-07 16:15:38 +0100 |
commit | 2345b8ce3f8ebab6e04d6ea25f7c809efb037995 (patch) | |
tree | a1c74ed8d29d1f14d329eab6e1900749b041bbdd /candle-core/src/tensor.rs | |
parent | f53a333ea91233b41dd946c2c30213c79b4d1cb3 (diff) | |
download | candle-2345b8ce3f8ebab6e04d6ea25f7c809efb037995.tar.gz candle-2345b8ce3f8ebab6e04d6ea25f7c809efb037995.tar.bz2 candle-2345b8ce3f8ebab6e04d6ea25f7c809efb037995.zip |
Skeleton for the avg-pool2d and upsample-nearest2d ops. (#337)
* Skeleton for the avg-pool2d and upsample-nearest2d ops.
* Preliminary conv2d support.
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r-- | candle-core/src/tensor.rs | 29 |
1 files changed, 29 insertions, 0 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index f7bd894a..ffa4bf8c 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -817,6 +817,35 @@ impl Tensor { Ok(from_storage(storage, out_dims, op, false)) } + pub fn conv2d(&self, _kernel: &Self, _padding: usize, _stride: usize) -> Result<Self> { + todo!() + } + + pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> { + let (n, c, _h, _w) = self.dims4()?; + let op = BackpropOp::new1(self, Op::UpsampleNearest2D); + let storage = self + .storage() + .upsample_nearest2d(self.layout(), target_h, target_w)?; + Ok(from_storage(storage, (n, c, target_h, target_w), op, false)) + } + + pub fn avg_pool2d(&self, kernel_size: (usize, usize), stride: (usize, usize)) -> Result<Self> { + let (n, c, h, w) = self.dims4()?; + // https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html#torch.nn.AvgPool2d + let h_out = (h - kernel_size.0) / stride.0 + 1; + let w_out = (w - kernel_size.1) / stride.1 + 1; + let op = BackpropOp::new1(self, |arg| Op::AvgPool2D { + arg, + kernel_size, + stride, + }); + let storage = self + .storage() + .avg_pool2d(self.layout(), kernel_size, stride)?; + Ok(from_storage(storage, (n, c, h_out, w_out), op, false)) + } + /// Returns the matrix-multiplication of the input tensor with the other provided tensor. /// /// # Arguments |