summaryrefslogtreecommitdiff
path: root/candle-core/src/backprop.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/backprop.rs')
-rw-r--r--candle-core/src/backprop.rs26
1 files changed, 22 insertions, 4 deletions
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index fc0c79a2..c152f31f 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -114,7 +114,7 @@ impl Tensor {
| Op::Unary(_node, UnaryOp::Round) => nodes,
Op::Reshape(node)
| Op::UpsampleNearest1D(node)
- | Op::UpsampleNearest2D(node)
+ | Op::UpsampleNearest2D { arg: node, .. }
| Op::AvgPool2D { arg: node, .. }
| Op::MaxPool2D { arg: node, .. }
| Op::Copy(node)
@@ -350,9 +350,27 @@ impl Tensor {
Op::UpsampleNearest1D { .. } => Err(Error::BackwardNotSupported {
op: "upsample-nearest1d",
})?,
- Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
- op: "upsample-nearest2d",
- })?,
+ Op::UpsampleNearest2D {
+ arg,
+ target_h,
+ target_w,
+ } => {
+ let (_n, c, h, w) = arg.dims4()?;
+ if target_h % h != 0 || target_w % w != 0 {
+ crate::bail!("backward not supported for non integer upscaling factors")
+ }
+ let scale_h = target_h / h;
+ let scale_w = target_w / w;
+
+ if scale_h != scale_w {
+ crate::bail!("backward not supported for non uniform upscaling factors")
+ };
+ let kernel =
+ Tensor::ones((c, 1, scale_h, scale_w), arg.dtype(), arg.device())?;
+ let conv_sum = grad.conv2d(&kernel, 0, scale_h, 1, c)?;
+ let sum_grad = grads.or_insert(arg)?;
+ *sum_grad = conv_sum;
+ }
Op::SliceScatter0(lhs, rhs, start_rhs) => {
let rhs_sum_grad = grads.or_insert(rhs)?;
let rhs_grad = grad.narrow(0, *start_rhs, rhs.dim(0)?)?;