diff options
Diffstat (limited to 'candle-core')
-rw-r--r-- | candle-core/src/backprop.rs | 12 | ||||
-rw-r--r-- | candle-core/src/op.rs | 15 | ||||
-rw-r--r-- | candle-core/src/storage.rs | 18 | ||||
-rw-r--r-- | candle-core/src/tensor.rs | 29 |
4 files changed, 74 insertions, 0 deletions
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs index 2dff0a5a..0eab508e 100644 --- a/candle-core/src/backprop.rs +++ b/candle-core/src/backprop.rs @@ -55,6 +55,11 @@ impl Tensor { kernel: rhs, .. } + | Op::Conv2D { + arg: lhs, + kernel: rhs, + .. + } | Op::CustomOp2(lhs, rhs, _) | Op::Binary(lhs, rhs, _) | Op::Gather(lhs, rhs, _) @@ -81,6 +86,8 @@ impl Tensor { } } Op::Reshape(node) + | Op::UpsampleNearest2D(node) + | Op::AvgPool2D { arg: node, .. } | Op::Copy(node) | Op::Broadcast(node) | Op::Cmp(node, _) @@ -163,6 +170,11 @@ impl Tensor { *f_sum_grad = f_sum_grad.add(&f_grad)?; } Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?, + Op::Conv2D { .. } => Err(Error::BackwardNotSupported { op: "conv2d" })?, + Op::AvgPool2D { .. } => Err(Error::BackwardNotSupported { op: "avg-pool2d" })?, + Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported { + op: "upsample-nearest2d", + })?, Op::Gather(arg, indexes, dim) => { let sum_grad = grads.or_insert(arg)?; *sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?; diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index b4ebca51..aea8b733 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -80,6 +80,21 @@ pub enum Op { stride: usize, }, + #[allow(dead_code)] + Conv2D { + arg: Tensor, + kernel: Tensor, + padding: usize, + stride: usize, + }, + + AvgPool2D { + arg: Tensor, + kernel_size: (usize, usize), + stride: (usize, usize), + }, + UpsampleNearest2D(Tensor), + Cat(Vec<Tensor>, usize), #[allow(dead_code)] // add is currently unused. diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs index 1e1ef305..cbca4fc4 100644 --- a/candle-core/src/storage.rs +++ b/candle-core/src/storage.rs @@ -266,6 +266,24 @@ impl Storage { } } + pub(crate) fn avg_pool2d( + &self, + _layout: &Layout, + _kernel_size: (usize, usize), + _stride: (usize, usize), + ) -> Result<Self> { + todo!() + } + + pub(crate) fn upsample_nearest2d( + &self, + _layout: &Layout, + _h: usize, + _w: usize, + ) -> Result<Self> { + todo!() + } + pub(crate) fn where_cond( &self, layout: &Layout, diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index f7bd894a..ffa4bf8c 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -817,6 +817,35 @@ impl Tensor { Ok(from_storage(storage, out_dims, op, false)) } + pub fn conv2d(&self, _kernel: &Self, _padding: usize, _stride: usize) -> Result<Self> { + todo!() + } + + pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> { + let (n, c, _h, _w) = self.dims4()?; + let op = BackpropOp::new1(self, Op::UpsampleNearest2D); + let storage = self + .storage() + .upsample_nearest2d(self.layout(), target_h, target_w)?; + Ok(from_storage(storage, (n, c, target_h, target_w), op, false)) + } + + pub fn avg_pool2d(&self, kernel_size: (usize, usize), stride: (usize, usize)) -> Result<Self> { + let (n, c, h, w) = self.dims4()?; + // https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html#torch.nn.AvgPool2d + let h_out = (h - kernel_size.0) / stride.0 + 1; + let w_out = (w - kernel_size.1) / stride.1 + 1; + let op = BackpropOp::new1(self, |arg| Op::AvgPool2D { + arg, + kernel_size, + stride, + }); + let storage = self + .storage() + .avg_pool2d(self.layout(), kernel_size, stride)?; + Ok(from_storage(storage, (n, c, h_out, w_out), op, false)) + } + /// Returns the matrix-multiplication of the input tensor with the other provided tensor. /// /// # Arguments |