diff options
Diffstat (limited to 'candle-core/src/backprop.rs')
| -rw-r--r-- | candle-core/src/backprop.rs | 26 |
1 files changed, 22 insertions, 4 deletions
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs index fc0c79a2..c152f31f 100644 --- a/candle-core/src/backprop.rs +++ b/candle-core/src/backprop.rs @@ -114,7 +114,7 @@ impl Tensor { | Op::Unary(_node, UnaryOp::Round) => nodes, Op::Reshape(node) | Op::UpsampleNearest1D(node) - | Op::UpsampleNearest2D(node) + | Op::UpsampleNearest2D { arg: node, .. } | Op::AvgPool2D { arg: node, .. } | Op::MaxPool2D { arg: node, .. } | Op::Copy(node) @@ -350,9 +350,27 @@ impl Tensor { Op::UpsampleNearest1D { .. } => Err(Error::BackwardNotSupported { op: "upsample-nearest1d", })?, - Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported { - op: "upsample-nearest2d", - })?, + Op::UpsampleNearest2D { + arg, + target_h, + target_w, + } => { + let (_n, c, h, w) = arg.dims4()?; + if target_h % h != 0 || target_w % w != 0 { + crate::bail!("backward not supported for non integer upscaling factors") + } + let scale_h = target_h / h; + let scale_w = target_w / w; + + if scale_h != scale_w { + crate::bail!("backward not supported for non uniform upscaling factors") + }; + let kernel = + Tensor::ones((c, 1, scale_h, scale_w), arg.dtype(), arg.device())?; + let conv_sum = grad.conv2d(&kernel, 0, scale_h, 1, c)?; + let sum_grad = grads.or_insert(arg)?; + *sum_grad = conv_sum; + } Op::SliceScatter0(lhs, rhs, start_rhs) => { let rhs_sum_grad = grads.or_insert(rhs)?; let rhs_grad = grad.narrow(0, *start_rhs, rhs.dim(0)?)?; |
