diff options
Diffstat (limited to 'candle-core/src/backprop.rs')
-rw-r--r-- | candle-core/src/backprop.rs | 30 |
1 file changed, 26 insertions(+), 4 deletions(-)
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs index 61a81be0..2711da85 100644 --- a/candle-core/src/backprop.rs +++ b/candle-core/src/backprop.rs @@ -179,11 +179,33 @@ impl Tensor { start_idx += len; } } - Op::Broadcast(_arg) => { - return Err(Error::BackwardNotSupported { op: "broadcast" }) + Op::Broadcast(arg) => { + let arg_dims = arg.dims(); + let node_dims = node.dims(); + // The number of dims that have been inserted on the left. + let left_dims = node_dims.len() - arg_dims.len(); + let mut sum_dims: Vec<usize> = (0..left_dims).collect(); + for (dim, (node_dim, arg_dim)) in node_dims[left_dims..] + .iter() + .zip(arg_dims.iter()) + .enumerate() + { + if node_dim != arg_dim { + sum_dims.push(dim + left_dims) + } + } + + let mut arg_grad = grad.sum(sum_dims.as_slice())?; + // sum_dims has increasing values. + for &dim in sum_dims.iter().rev() { + arg_grad = arg_grad.squeeze(dim)? + } + let sum_grad = grads.or_insert(arg)?; + *sum_grad = sum_grad.broadcast_add(&arg_grad)? } - Op::Sum(_arg, _sum_dims) => { - return Err(Error::BackwardNotSupported { op: "sum" }) + Op::Sum(arg, _sum_dims) => { + let sum_grad = grads.or_insert(arg)?; + *sum_grad = sum_grad.broadcast_add(&grad)? } Op::ToDType(arg) => { let sum_grad = grads.or_insert(arg)?; |