author    Laurent Mazare <laurent.mazare@gmail.com>  2023-08-29 10:17:59 +0100
committer GitHub <noreply@github.com>  2023-08-29 10:17:59 +0100
commit    d0a330448d7c7dad242f2a6bafca29b8f53dc119 (patch)
tree      0a3c194a321fd08eaf2a4e8ae1d846012deda72c
parent    4b8d57ba15471f8f321e89a0114bffb97fe4b618 (diff)
Backprop support for pooling ops. (#652)
* Backprop support for pooling ops.
* Max-pool gradient.
-rw-r--r--  candle-core/src/backprop.rs | 37
1 file changed, 35 insertions(+), 2 deletions(-)
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 4366f3b6..9ecdee4f 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -219,8 +219,41 @@ impl Tensor {
Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {
op: "conv-transpose2d",
})?,
- Op::AvgPool2D { .. } => Err(Error::BackwardNotSupported { op: "avg-pool2d" })?,
- Op::MaxPool2D { .. } => Err(Error::BackwardNotSupported { op: "max-pool2d" })?,
+ Op::AvgPool2D {
+ arg,
+ kernel_size,
+ stride,
+ } => {
+ if kernel_size != stride {
+ crate::bail!("backward not supported for avgpool2d if ksize {kernel_size:?} != stride {stride:?}")
+ }
+ let (_n, _c, h, w) = arg.dims4()?;
+ let grad_arg = grad.upsample_nearest2d(h, w)?;
+ let grad_arg =
+ (grad_arg * (1f64 / (kernel_size.0 * kernel_size.1) as f64))?;
+ let sum_grad = grads.or_insert(arg)?;
+ *sum_grad = sum_grad.add(&grad_arg)?;
+ }
+ Op::MaxPool2D {
+ arg,
+ kernel_size,
+ stride,
+ } => {
+ if kernel_size != stride {
+ crate::bail!("backward not supported for maxpool2d if ksize {kernel_size:?} != stride {stride:?}")
+ }
+ let (_n, _c, h, w) = arg.dims4()?;
+ // For computing the max-pool gradient, we compute a mask where a 1 means
+ // that the element is the maximum, then we apply this mask to the
+ // upsampled gradient (taking into account that multiple maxima may exist,
+ // in which case we scale the gradient accordingly).
+ let node_upsampled = node.upsample_nearest2d(h, w)?;
+ let mask = arg.eq(&node_upsampled)?.to_dtype(arg.dtype())?;
+ let avg = mask.avg_pool2d(*kernel_size, *stride)?;
+ let grad_arg = ((grad * avg)?.upsample_nearest2d(h, w)? * mask)?;
+ let sum_grad = grads.or_insert(arg)?;
+ *sum_grad = sum_grad.add(&grad_arg)?;
+ }
Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
op: "upsample-nearest2d",
})?,
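
The AvgPool2D branch above spreads each upstream gradient value uniformly over its pooling window: upsample_nearest2d copies the gradient to every input position, and the 1/(kh*kw) factor matches each element's weight in the window average. Because the code bails out unless kernel_size == stride, every input element belongs to exactly one window. A minimal sketch of exercising this path end to end, assuming candle-core's Var/backward autograd API and the two-argument avg_pool2d(kernel_size, stride) form used in this diff:

use candle_core::{Device, Result, Tensor, Var};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // 1x1x4x4 input; with a 2x2 kernel and a matching 2x2 stride, each
    // input element belongs to exactly one pooling window.
    let x = Var::from_tensor(&Tensor::arange(0f32, 16., &dev)?.reshape((1, 1, 4, 4))?)?;
    let y = x.avg_pool2d((2, 2), (2, 2))?; // 1x1x2x2 pooled output
    // sum_all gives an upstream gradient of 1 for every pooled value.
    let grads = y.sum_all()?.backward()?;
    let gx = grads.get(&x).expect("no gradient for x");
    // Each element enters its window's average with weight 1/(2*2), so
    // every entry of the printed gradient should be 0.25.
    println!("{gx}");
    Ok(())
}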
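
The MaxPool2D branch only routes gradient to elements that achieve their window's maximum: the pooled output is upsampled back to the input size and compared against the input to build a 0/1 mask, and an avg-pool of that mask (the fraction of maxima per window) provides the scale applied when a maximum is repeated. A standalone sketch of just this mask construction, replaying the same tensor ops outside of autograd (again assuming the two-argument pooling signatures from this diff; values and shapes are illustrative only):

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // A single 2x2 window whose maximum value 4 appears twice.
    let arg = Tensor::new(&[1f32, 4., 4., 2.], &dev)?.reshape((1, 1, 2, 2))?;
    let node = arg.max_pool2d((2, 2), (2, 2))?; // forward pass, pooled output [[4]]
    let (_n, _c, h, w) = arg.dims4()?;
    // Mask is 1 exactly where an input element equals its window's maximum.
    let node_up = node.upsample_nearest2d(h, w)?;
    let mask = arg.eq(&node_up)?.to_dtype(arg.dtype())?;
    // avg-pooling the mask yields the fraction of maxima per window
    // (2/4 = 0.5 here), the per-window scale used in the backward pass.
    let avg = mask.avg_pool2d((2, 2), (2, 2))?;
    println!("mask:\n{mask}\navg:\n{avg}");
    Ok(())
}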