author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-06-28 10:04:51 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-06-28 10:04:51 +0100 |
commit | d0ff3b2d130c6474676a19000c81396fc8e6a2bf (patch) | |
tree | 7c4044496ff51867a6401377e5a02ec334ed3984 /candle-core/src | |
parent | 50eff0005b1e822a5514d37f5afe41a10bbb4b22 (diff) | |
parent | 615196e7be243e21c96707cfb543ff79de07e461 (diff) | |
Merge pull request #24 from LaurentMazare/more-grads
Support gradients for reshape and where_cond.
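In effect, the new backward rules are: for `y = pred.where_cond(t, f)`, the upstream gradient flows to `t` where `pred` holds and to `f` elsewhere; for `y = x.reshape(s)`, the gradient is simply reshaped back to `x`'s shape. A minimal standalone sketch of those rules on plain slices (illustrative Rust, not the candle implementation; the helper names are made up):

```rust
/// Backward rule for `out = pred.where_cond(t, f)`, written out on plain slices:
/// the upstream gradient goes to `t` where `pred` is true and to `f` elsewhere.
fn where_cond_backward(pred: &[bool], grad_out: &[f32]) -> (Vec<f32>, Vec<f32>) {
    let grad_t = pred
        .iter()
        .zip(grad_out)
        .map(|(&p, &g)| if p { g } else { 0.0 })
        .collect();
    let grad_f = pred
        .iter()
        .zip(grad_out)
        .map(|(&p, &g)| if p { 0.0 } else { g })
        .collect();
    (grad_t, grad_f)
}

/// Backward rule for `out = x.reshape(shape)`: the values are unchanged, only the
/// shape metadata differs, so the gradient is the upstream gradient viewed with
/// `x`'s original shape.
fn reshape_backward(grad_out: &[f32], _original_shape: &[usize]) -> Vec<f32> {
    grad_out.to_vec()
}

fn main() {
    let pred = [true, false, true, false];
    let grad_out = [1.0f32, 2.0, 3.0, 4.0];
    let (grad_t, grad_f) = where_cond_backward(&pred, &grad_out);
    assert_eq!(grad_t, vec![1.0, 0.0, 3.0, 0.0]);
    assert_eq!(grad_f, vec![0.0, 2.0, 0.0, 4.0]);
    // For reshape, e.g. (2, 2) -> (4,), the gradient values carry over unchanged.
    assert_eq!(reshape_backward(&grad_out, &[2, 2]), grad_out.to_vec());
}
```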
Diffstat (limited to 'candle-core/src')
-rw-r--r-- | candle-core/src/backprop.rs | 21 |
-rw-r--r-- | candle-core/src/tensor.rs | 2 |
2 files changed, 17 insertions, 6 deletions
```diff
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index bc6740cf..ef15e65f 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -106,9 +106,8 @@ impl Tensor {
             }
             let grad = grads.remove(node).unwrap();
             // TODO: We should perform all these operations in place (or at least not track the
-            // whole graph).
-            // The only drawback would be if we wanted to support grad of grad but this is out of
-            // scope.
+            // whole graph). The only drawback would be if we wanted to support grad of grad but
+            // this is out of scope.
             if let Some(op) = node.op() {
                 match op {
                     Op::Add(lhs, rhs) => {
@@ -139,8 +138,14 @@ impl Tensor {
                         let rhs_sum_grad = grads.or_insert(rhs)?;
                         *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                     }
-                    Op::WhereCond(_pred, _t, _f) => {
-                        return Err(Error::BackwardNotSupported { op: "where_cond" })
+                    Op::WhereCond(pred, t, f) => {
+                        let zeros = grad.zeros_like()?;
+                        let t_sum_grad = grads.or_insert(t)?;
+                        let t_grad = pred.where_cond(&grad, &zeros)?;
+                        *t_sum_grad = t_sum_grad.add(&t_grad)?;
+                        let f_sum_grad = grads.or_insert(f)?;
+                        let f_grad = pred.where_cond(&zeros, &grad)?;
+                        *f_sum_grad = f_sum_grad.add(&f_grad)?;
                     }
                     Op::Embedding(_lhs, _rhs) => {
                         return Err(Error::BackwardNotSupported { op: "embedding" })
@@ -209,7 +214,11 @@ impl Tensor {
                     Op::Softmax(_arg, _) => {
                         return Err(Error::BackwardNotSupported { op: "softmax" })
                     }
-                    Op::Reshape(_arg) => return Err(Error::BackwardNotSupported { op: "reshape" }),
+                    Op::Reshape(arg) => {
+                        let arg_grad = grad.reshape(arg.dims())?;
+                        let sum_grad = grads.or_insert(arg)?;
+                        *sum_grad = sum_grad.add(&arg_grad)?
+                    }
                     Op::Gelu(_) => return Err(Error::BackwardNotSupported { op: "gelu" }),
                     Op::Relu(_) => return Err(Error::BackwardNotSupported { op: "relu" }),
                     Op::Sqr(arg) => {
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 4cff4efc..feb59d3c 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -121,6 +121,7 @@ fn from_storage<S: Into<Shape>>(
 }
 
 impl Tensor {
+    // TODO: Maybe this should be a broadcast rather than actually creating the full tensor.
     fn ones_impl<S: Into<Shape>>(
         shape: S,
         dtype: DType,
@@ -144,6 +145,7 @@ impl Tensor {
         Tensor::ones(self.shape(), self.dtype(), &self.device())
     }
 
+    // TODO: Maybe this should be a broadcast rather than actually creating the full tensor.
     fn zeros_impl<S: Into<Shape>>(
         shape: S,
         dtype: DType,
```
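As a rough usage sketch of what this merge enables, the snippet below takes gradients through `where_cond` and `reshape`. It is written against candle's later public API (`Var`, `sum_all`, `GradStore::get`), which may not match this early revision exactly, so treat it as an illustration of the new backward rules rather than code tied to this commit:

```rust
use candle_core::{Device, Result, Tensor, Var};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // NOTE: `Var` and `GradStore::get` are candle's current API and are assumed here;
    // the equivalent constructors at this 2023 revision may differ.
    let x = Var::new(&[[1f32, 2., 3.], [4., 5., 6.]], &dev)?;
    // A fixed selection mask; `where_cond` expects an integer predicate tensor.
    let pred = Tensor::new(&[[1u8, 0, 1], [0, 1, 0]], &dev)?;

    // Keep the selected entries, zero the rest, then reshape (2, 3) -> (3, 2).
    let zeros = x.zeros_like()?;
    let y = pred.where_cond(x.as_tensor(), &zeros)?.reshape((3, 2))?;

    // d(sum(y))/dx should equal the mask: 1 where selected, 0 elsewhere.
    let grads = y.sum_all()?.backward()?;
    let dx = grads.get(&x).expect("no gradient for x");
    println!("{dx}");
    Ok(())
}
```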