commit    7938d2b84888cf14f7a719eb4b35e680097fc5c6
tree      f564748ad45a494a5f3a8c0798d62eb6ada6e599
parent    d0ff3b2d130c6474676a19000c81396fc8e6a2bf
author    laurent <laurent.mazare@gmail.com>  2023-06-28 10:46:00 +0100
committer laurent <laurent.mazare@gmail.com>  2023-06-28 10:46:00 +0100
Add the grad for narrow.
 candle-core/src/backprop.rs | 27 +++++++++++++++++++++++++--
 1 file changed, 25 insertions(+), 2 deletions(-)
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index ef15e65f..7801b878 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -208,8 +208,31 @@ impl Tensor {
                     let sum_grad = grads.or_insert(arg)?;
                     *sum_grad = sum_grad.sub(&grad)?
                 }
-                Op::Narrow(_arg, _, _, _) => {
-                    return Err(Error::BackwardNotSupported { op: "narrow" })
+                &Op::Narrow(ref arg, dim, start_idx, len) => {
+                    let arg_dims = arg.dims();
+                    let left_pad = if start_idx == 0 {
+                        None
+                    } else {
+                        let mut dims = arg_dims.to_vec();
+                        dims[dim] = start_idx;
+                        Some(Tensor::zeros(dims, grad.dtype(), &grad.device())?)
+                    };
+                    let right_pad = arg_dims[dim] - start_idx - len;
+                    let right_pad = if right_pad == 0 {
+                        None
+                    } else {
+                        let mut dims = arg_dims.to_vec();
+                        dims[dim] = right_pad;
+                        Some(Tensor::zeros(dims, grad.dtype(), &grad.device())?)
+                    };
+                    let arg_grad = match (left_pad, right_pad) {
+                        (None, None) => grad,
+                        (Some(l), None) => Tensor::cat(&[&l, &grad], dim)?,
+                        (None, Some(r)) => Tensor::cat(&[&grad, &r], dim)?,
+                        (Some(l), Some(r)) => Tensor::cat(&[&l, &grad, &r], dim)?,
+                    };
+                    let sum_grad = grads.or_insert(arg)?;
+                    *sum_grad = sum_grad.add(&arg_grad)?
                 }
                 Op::Softmax(_arg, _) => {
                     return Err(Error::BackwardNotSupported { op: "softmax" })
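The backward rule implemented above is: the gradient of `narrow(dim, start_idx, len)` with respect to its input is the upstream gradient placed back at `[start_idx, start_idx + len)` along `dim`, with zeros everywhere else. The patch builds this by concatenating zero tensors on the left and right of `grad`. The sketch below is a minimal standalone illustration of that same left-pad / right-pad construction on a 1-D buffer of plain `f32` values; `narrow_grad` is a hypothetical helper written only for this note and is not part of the candle API.

// Illustrative sketch of the narrow backward rule on a 1-D buffer.
// `narrow_grad` mirrors the patch: pad the upstream gradient with zeros
// so that it lines up with the original (un-narrowed) input.
// Hypothetical helper, not candle code.
fn narrow_grad(input_len: usize, start_idx: usize, len: usize, grad: &[f32]) -> Vec<f32> {
    assert_eq!(grad.len(), len);
    let left_pad = vec![0.0f32; start_idx];                    // zeros before the slice
    let right_pad = vec![0.0f32; input_len - start_idx - len]; // zeros after the slice
    // Equivalent to Tensor::cat(&[&l, &grad, &r], dim) in the patch.
    [left_pad.as_slice(), grad, right_pad.as_slice()].concat()
}

fn main() {
    // The input had 5 elements; narrow(0, 1, 2) selected elements 1..3.
    let upstream = [10.0f32, 20.0]; // d(loss)/d(narrowed output)
    let dinput = narrow_grad(5, 1, 2, &upstream);
    assert_eq!(dinput, vec![0.0, 10.0, 20.0, 0.0, 0.0]);
    println!("{dinput:?}");
}

The `(None, None)` arm in the patch covers the case where the narrow spans the whole dimension, in which case the gradient passes through unchanged and no concatenation is needed.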