diff options
Diffstat (limited to 'candle-core/src/backprop.rs')
-rw-r--r-- | candle-core/src/backprop.rs | 12 |
1 files changed, 12 insertions, 0 deletions
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs index 2dff0a5a..0eab508e 100644 --- a/candle-core/src/backprop.rs +++ b/candle-core/src/backprop.rs @@ -55,6 +55,11 @@ impl Tensor { kernel: rhs, .. } + | Op::Conv2D { + arg: lhs, + kernel: rhs, + .. + } | Op::CustomOp2(lhs, rhs, _) | Op::Binary(lhs, rhs, _) | Op::Gather(lhs, rhs, _) @@ -81,6 +86,8 @@ impl Tensor { } } Op::Reshape(node) + | Op::UpsampleNearest2D(node) + | Op::AvgPool2D { arg: node, .. } | Op::Copy(node) | Op::Broadcast(node) | Op::Cmp(node, _) @@ -163,6 +170,11 @@ impl Tensor { *f_sum_grad = f_sum_grad.add(&f_grad)?; } Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?, + Op::Conv2D { .. } => Err(Error::BackwardNotSupported { op: "conv2d" })?, + Op::AvgPool2D { .. } => Err(Error::BackwardNotSupported { op: "avg-pool2d" })?, + Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported { + op: "upsample-nearest2d", + })?, Op::Gather(arg, indexes, dim) => { let sum_grad = grads.or_insert(arg)?; *sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?; |