| author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-13 20:14:10 +0100 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-07-13 20:14:10 +0100 |
| commit | 23e105cd941bdd1aaa1ae45816a9910fb829e74f | |
| tree | 85bf768748676f0cc0a223845d691279b544afb0 /candle-nn/src | |
| parent | 3c02ea56b0dcc39c30ad0d41d942384cc28f65c2 | |
Add the gradient for reduce-sum. (#162)
* Add the gradient for reduce-sum.
* Add the gradient for the broadcast ops.
* Add some backprop tests.
* Add a linear regression example (see the sketch below).
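
For intuition behind these changes: reduce-sum routes the upstream gradient unchanged to every summed element (d sum_j x_j / d x_i = 1), and broadcasting is its adjoint, so the gradient of a broadcast op sums over the broadcast dimensions. The self-contained sketch below illustrates the reduce-sum rule and the SGD update touched by the diff further down, using plain `f64` arithmetic rather than candle tensors; all names and numbers are illustrative and not taken from the repository's example.

```rust
// Hand-rolled linear regression y = w * x + b trained with SGD.
// Illustrates the reduce-sum gradient rule: d/dx_i sum_j(x_j) = 1, so the
// upstream gradient is broadcast back unchanged to every summed element.
fn main() {
    // Training data for the ground truth y = 3x + 1.
    let xs = [0.0f64, 1.0, 2.0, 3.0];
    let ys = [1.0f64, 4.0, 7.0, 10.0];

    let (mut w, mut b) = (0.0f64, 0.0f64);
    let lr = 0.04;

    for _step in 0..200 {
        // Forward: loss = sum_i (w*x_i + b - y_i)^2.
        // Backward: the reduce-sum broadcasts dloss/dsum = 1.0 to every
        // squared-error term, so each term adds its own contribution.
        let (mut dw, mut db) = (0.0f64, 0.0f64);
        for (&x, &y) in xs.iter().zip(ys.iter()) {
            let err = w * x + b - y;
            dw += 2.0 * err * x; // d/dw of (w*x + b - y)^2
            db += 2.0 * err; // d/db of (w*x + b - y)^2
        }
        // SGD update, mirroring `var - grad * learning_rate` from the diff.
        w -= lr * dw;
        b -= lr * db;
    }
    println!("w ≈ {w:.3}, b ≈ {b:.3}"); // expect w ≈ 3, b ≈ 1
}
```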
Diffstat (limited to 'candle-nn/src')
-rw-r--r-- | candle-nn/src/optim.rs | 2 |
1 file changed, 1 insertion(+), 1 deletion(-)
```diff
diff --git a/candle-nn/src/optim.rs b/candle-nn/src/optim.rs
index 741c51dc..a8b5b370 100644
--- a/candle-nn/src/optim.rs
+++ b/candle-nn/src/optim.rs
@@ -39,7 +39,7 @@ impl SGD {
         let grads = loss.backward()?;
         for var in self.vars.iter() {
             if let Some(grad) = grads.get(var) {
-                var.set(&var.sub(&(grad * self.learning_rate)?)?)?
+                var.set(&var.sub(&(grad * self.learning_rate)?)?)?;
             }
         }
         Ok(())
```
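
The fix itself is purely stylistic: `var.set(...)?` already evaluates to `()` after error propagation, so the trailing semicolon only turns the expression into a statement, the idiomatic form inside a loop body. For context, here is a sketch of how the surrounding method plausibly reads after the patch; the method name, signature, and everything outside the visible hunk are inferred from the hunk, not copied from the repository.

```rust
impl SGD {
    /// Runs one optimization step: backpropagates from `loss`, then applies
    /// the vanilla SGD update `var <- var - learning_rate * grad` to every
    /// tracked variable that received a gradient.
    /// (Sketch reconstructed from the hunk above; the name `backward_step`
    /// and the surrounding code are assumptions.)
    pub fn backward_step(&self, loss: &Tensor) -> Result<()> {
        let grads = loss.backward()?;
        for var in self.vars.iter() {
            if let Some(grad) = grads.get(var) {
                // The trailing semicolon added by this commit makes the
                // unit-valued call a statement; behavior is unchanged.
                var.set(&var.sub(&(grad * self.learning_rate)?)?)?;
            }
        }
        Ok(())
    }
}
```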