summaryrefslogtreecommitdiff
path: root/candle-core/src/backprop.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/backprop.rs')
-rw-r--r--candle-core/src/backprop.rs12
1 files changed, 12 insertions, 0 deletions
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 2dff0a5a..0eab508e 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -55,6 +55,11 @@ impl Tensor {
kernel: rhs,
..
}
+ | Op::Conv2D {
+ arg: lhs,
+ kernel: rhs,
+ ..
+ }
| Op::CustomOp2(lhs, rhs, _)
| Op::Binary(lhs, rhs, _)
| Op::Gather(lhs, rhs, _)
@@ -81,6 +86,8 @@ impl Tensor {
}
}
Op::Reshape(node)
+ | Op::UpsampleNearest2D(node)
+ | Op::AvgPool2D { arg: node, .. }
| Op::Copy(node)
| Op::Broadcast(node)
| Op::Cmp(node, _)
@@ -163,6 +170,11 @@ impl Tensor {
*f_sum_grad = f_sum_grad.add(&f_grad)?;
}
Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
+ Op::Conv2D { .. } => Err(Error::BackwardNotSupported { op: "conv2d" })?,
+ Op::AvgPool2D { .. } => Err(Error::BackwardNotSupported { op: "avg-pool2d" })?,
+ Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
+ op: "upsample-nearest2d",
+ })?,
Op::Gather(arg, indexes, dim) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;