author	Laurent Mazare <laurent.mazare@gmail.com>	2023-07-13 20:14:10 +0100
committer	GitHub <noreply@github.com>	2023-07-13 20:14:10 +0100
commit	23e105cd941bdd1aaa1ae45816a9910fb829e74f (patch)
tree	85bf768748676f0cc0a223845d691279b544afb0 /candle-nn/src
parent	3c02ea56b0dcc39c30ad0d41d942384cc28f65c2 (diff)
Add the gradient for reduce-sum. (#162)
* Add the gradient for reduce-sum.
* Add the gradient for the broadcast ops.
* Add some backprop tests.
* Add a linear regression example.
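The linear regression example mentioned above exercises the new gradients end to end. Below is a minimal sketch of such a training loop against the SGD optimizer touched by this diff; the constructor and step-method names (SGD::new, step) and tensor methods such as sum_all are assumptions inferred from the diff context, not necessarily the exact API at this commit.

use candle::{Device, Result, Tensor, Var};
use candle_nn::SGD;

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Synthetic data for y = 3x + 1.
    let xs = Tensor::new(&[[1f32], [2.], [3.], [4.]], &dev)?;
    let ys = Tensor::new(&[[4f32], [7.], [10.], [13.]], &dev)?;
    // Trainable parameters, starting from zero.
    let w = Var::new(&[[0f32]], &dev)?;
    let b = Var::new(0f32, &dev)?;
    let sgd = SGD::new(vec![w.clone(), b.clone()], 0.01);
    for _ in 0..100 {
        // broadcast_add relies on the broadcast-op gradients added here.
        let pred = xs.matmul(w.as_tensor())?.broadcast_add(b.as_tensor())?;
        // Squared-error loss; sum_all needs the new reduce-sum gradient
        // for loss.backward() to reach w and b.
        let loss = pred.sub(&ys)?.sqr()?.sum_all()?;
        sgd.step(&loss)?;
    }
    Ok(())
}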
Diffstat (limited to 'candle-nn/src')
-rw-r--r--	candle-nn/src/optim.rs	2
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/candle-nn/src/optim.rs b/candle-nn/src/optim.rs
index 741c51dc..a8b5b370 100644
--- a/candle-nn/src/optim.rs
+++ b/candle-nn/src/optim.rs
@@ -39,7 +39,7 @@ impl SGD {
let grads = loss.backward()?;
for var in self.vars.iter() {
if let Some(grad) = grads.get(var) {
- var.set(&var.sub(&(grad * self.learning_rate)?)?)?
+ var.set(&var.sub(&(grad * self.learning_rate)?)?)?;
}
}
Ok(())
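The change above only adds a trailing semicolon (the update expression already evaluates to (), so behavior is unchanged); the surrounding loop is the plain SGD rule var <- var - learning_rate * grad. As a companion, here is a minimal sketch of the reduce-sum gradient this commit introduces, again assuming the op is reachable as sum_all and that the gradient store accepts a Var as in the hunk above.

use candle::{Device, Result, Var};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let x = Var::new(&[1f32, 2., 3.], &dev)?;
    // y = sum(x): dy/dx_i = 1 for every element, so the backward pass
    // broadcasts the scalar output gradient back to x's shape.
    let y = x.as_tensor().sum_all()?;
    let grads = y.backward()?;
    let gx = grads.get(&x).expect("no grad for x");
    println!("{gx}"); // expected: [1., 1., 1.]
    Ok(())
}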