summary refs log tree commit diff
path: root/candle-core/src/backprop.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/backprop.rs')
-rw-r--r-- candle-core/src/backprop.rs 30
1 file changed, 26 insertions, 4 deletions
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 61a81be0..2711da85 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -179,11 +179,33 @@ impl Tensor {
start_idx += len;
}
}
- Op::Broadcast(_arg) => {
- return Err(Error::BackwardNotSupported { op: "broadcast" })
+ Op::Broadcast(arg) => {
+ let arg_dims = arg.dims();
+ let node_dims = node.dims();
+ // The number of dims that have been inserted on the left.
+ let left_dims = node_dims.len() - arg_dims.len();
+ let mut sum_dims: Vec<usize> = (0..left_dims).collect();
+ for (dim, (node_dim, arg_dim)) in node_dims[left_dims..]
+ .iter()
+ .zip(arg_dims.iter())
+ .enumerate()
+ {
+ if node_dim != arg_dim {
+ sum_dims.push(dim + left_dims)
+ }
+ }
+
+ let mut arg_grad = grad.sum(sum_dims.as_slice())?;
+ // sum_dims has increasing values.
+ for &dim in sum_dims.iter().rev() {
+ arg_grad = arg_grad.squeeze(dim)?
+ }
+ let sum_grad = grads.or_insert(arg)?;
+ *sum_grad = sum_grad.broadcast_add(&arg_grad)?
}
- Op::Sum(_arg, _sum_dims) => {
- return Err(Error::BackwardNotSupported { op: "sum" })
+ Op::Sum(arg, _sum_dims) => {
+ let sum_grad = grads.or_insert(arg)?;
+ *sum_grad = sum_grad.broadcast_add(&grad)?
}
Op::ToDType(arg) => {
let sum_grad = grads.or_insert(arg)?;