summaryrefslogtreecommitdiff
path: root/candle-core/src
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-06-28 11:46:13 +0100
committerGitHub <noreply@github.com>2023-06-28 11:46:13 +0100
commitd461d9d751df67f262f108300dbc4e433d6062f5 (patch)
tree89c0b1dd27d226781638978158192ae88638a81b /candle-core/src
parent2998ff6ef7e4926bb91d4caffac92661c3241b68 (diff)
parent666d6dbcac760d40be157fa1bbd9643c5f085cb1 (diff)
downloadcandle-d461d9d751df67f262f108300dbc4e433d6062f5.tar.gz
candle-d461d9d751df67f262f108300dbc4e433d6062f5.tar.bz2
candle-d461d9d751df67f262f108300dbc4e433d6062f5.zip
Merge pull request #26 from LaurentMazare/narrow-grad
Add the grad for narrow.
Diffstat (limited to 'candle-core/src')
-rw-r--r--candle-core/src/backprop.rs27
1 files changed, 25 insertions, 2 deletions
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index ef15e65f..7801b878 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -208,8 +208,31 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.sub(&grad)?
}
- Op::Narrow(_arg, _, _, _) => {
- return Err(Error::BackwardNotSupported { op: "narrow" })
+ &Op::Narrow(ref arg, dim, start_idx, len) => {
+ let arg_dims = arg.dims();
+ let left_pad = if start_idx == 0 {
+ None
+ } else {
+ let mut dims = arg_dims.to_vec();
+ dims[dim] = start_idx;
+ Some(Tensor::zeros(dims, grad.dtype(), &grad.device())?)
+ };
+ let right_pad = arg_dims[dim] - start_idx - len;
+ let right_pad = if right_pad == 0 {
+ None
+ } else {
+ let mut dims = arg_dims.to_vec();
+ dims[dim] = right_pad;
+ Some(Tensor::zeros(dims, grad.dtype(), &grad.device())?)
+ };
+ let arg_grad = match (left_pad, right_pad) {
+ (None, None) => grad,
+ (Some(l), None) => Tensor::cat(&[&l, &grad], dim)?,
+ (None, Some(r)) => Tensor::cat(&[&grad, &r], dim)?,
+ (Some(l), Some(r)) => Tensor::cat(&[&l, &grad, &r], dim)?,
+ };
+ let sum_grad = grads.or_insert(arg)?;
+ *sum_grad = sum_grad.add(&arg_grad)?
}
Op::Softmax(_arg, _) => {
return Err(Error::BackwardNotSupported { op: "softmax" })