summaryrefslogtreecommitdiff
path: root/candle-core/src/backprop.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/backprop.rs')
-rw-r--r--candle-core/src/backprop.rs2
1 files changed, 1 insertions, 1 deletions
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index f4f90373..c6d55e61 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -256,7 +256,7 @@ impl Tensor {
// we scale the gradient for this case).
let node_upsampled = node.upsample_nearest2d(h, w)?;
let mask = arg.eq(&node_upsampled)?.to_dtype(arg.dtype())?;
- let avg = mask.avg_pool2d(*kernel_size, *stride)?;
+ let avg = mask.avg_pool2d_with_stride(*kernel_size, *stride)?;
let grad_arg = ((grad * avg)?.upsample_nearest2d(h, w)? * mask)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad_arg)?;