summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/lib.rs2
-rw-r--r--candle-nn/src/lib.rs2
-rw-r--r--candle-nn/src/optim.rs116
-rw-r--r--candle-nn/tests/ops.rs20
-rw-r--r--candle-nn/tests/optim.rs56
-rw-r--r--candle-nn/tests/test_utils.rs39
6 files changed, 216 insertions, 19 deletions
diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index 4fa0c5d3..c374d245 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -34,7 +34,7 @@
//! Rust is cool, and a lot of the HF ecosystem already has Rust crates [safetensors](https://github.com/huggingface/safetensors) and [tokenizers](https://github.com/huggingface/tokenizers)
pub mod backend;
-mod backprop;
+pub mod backprop;
mod conv;
mod convert;
pub mod cpu_backend;
diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs
index 45edfc46..46a83800 100644
--- a/candle-nn/src/lib.rs
+++ b/candle-nn/src/lib.rs
@@ -19,5 +19,5 @@ pub use embedding::{embedding, Embedding};
pub use init::Init;
pub use layer_norm::{layer_norm, LayerNorm};
pub use linear::{linear, linear_no_bias, Linear};
-pub use optim::SGD;
+pub use optim::{AdamW, ParamsAdamW, SGD};
pub use var_builder::{VarBuilder, VarMap};
diff --git a/candle-nn/src/optim.rs b/candle-nn/src/optim.rs
index d20ef284..39f7b34e 100644
--- a/candle-nn/src/optim.rs
+++ b/candle-nn/src/optim.rs
@@ -1,6 +1,9 @@
//! Various optimization algorithms.
use candle::{Result, Tensor, Var};
+/// Optimizer for Stochastic Gradient Descent.
+///
+/// Contrary to the PyTorch implementation of SGD, this version does not support momentum.
#[derive(Debug)]
pub struct SGD {
vars: Vec<Var>,
@@ -42,8 +45,7 @@ impl SGD {
self.vars.push(var.clone())
}
- pub fn backward_step(&self, loss: &Tensor) -> Result<()> {
- let grads = loss.backward()?;
+ pub fn step(&self, grads: &candle::backprop::GradStore) -> Result<()> {
for var in self.vars.iter() {
if let Some(grad) = grads.get(var) {
var.set(&var.sub(&(grad * self.learning_rate)?)?)?;
@@ -51,4 +53,114 @@ impl SGD {
}
Ok(())
}
+
+ pub fn backward_step(&self, loss: &Tensor) -> Result<()> {
+ let grads = loss.backward()?;
+ self.step(&grads)
+ }
+}
+
+#[derive(Clone, Debug)]
+pub struct ParamsAdamW {
+ pub lr: f64,
+ pub beta1: f64,
+ pub beta2: f64,
+ pub eps: f64,
+ pub weight_decay: f64,
+}
+
+impl Default for ParamsAdamW {
+ fn default() -> Self {
+ Self {
+ lr: 0.001,
+ beta1: 0.9,
+ beta2: 0.999,
+ eps: 1e-8,
+ weight_decay: 0.01,
+ }
+ }
+}
+
+#[derive(Debug)]
+struct VarAdamW {
+ var: Var,
+ first_moment: Var,
+ second_moment: Var,
+}
+
+#[derive(Debug)]
+pub struct AdamW {
+ vars: Vec<VarAdamW>,
+ step_t: usize,
+ params: ParamsAdamW,
+}
+
+impl AdamW {
+ pub fn new(vars: Vec<Var>, params: ParamsAdamW) -> Result<Self> {
+ let vars = vars
+ .into_iter()
+ .map(|var| {
+ let dtype = var.dtype();
+ let shape = var.shape();
+ let device = var.device();
+ let first_moment = Var::zeros(shape, dtype, device)?;
+ let second_moment = Var::zeros(shape, dtype, device)?;
+ Ok(VarAdamW {
+ var,
+ first_moment,
+ second_moment,
+ })
+ })
+ .collect::<Result<Vec<_>>>()?;
+ Ok(Self {
+ vars,
+ params,
+ step_t: 0,
+ })
+ }
+
+ pub fn new_lr(vars: Vec<Var>, learning_rate: f64) -> Result<Self> {
+ let params = ParamsAdamW {
+ lr: learning_rate,
+ ..ParamsAdamW::default()
+ };
+ Self::new(vars, params)
+ }
+
+ pub fn step(&mut self, grads: &candle::backprop::GradStore) -> Result<()> {
+ self.step_t += 1;
+ let lr = self.params.lr;
+ let lambda = self.params.weight_decay;
+ let lr_lambda = lr * lambda;
+ let beta1 = self.params.beta1;
+ let beta2 = self.params.beta2;
+ let scale_m = 1f64 / (1f64 - beta1.powi(self.step_t as i32));
+ let scale_v = 1f64 / (1f64 - beta2.powi(self.step_t as i32));
+ for var in self.vars.iter() {
+ let theta = &var.var;
+ let m = &var.first_moment;
+ let v = &var.second_moment;
+ if let Some(g) = grads.get(theta) {
+ // This involves locking 3 RWLocks per params, if the parameters are large this
+ // should not be an issue but this may be problematic with models with lots of
+ // small parameters.
+ let next_m = ((m.as_tensor() * beta1)? + (g * (1.0 - beta1))?)?;
+ let next_v = ((v.as_tensor() * beta2)? + (g.sqr()? * (1.0 - beta2))?)?;
+ let m_hat = (&next_m * scale_m)?;
+ let v_hat = (&next_v * scale_v)?;
+ let next_theta = (theta.as_tensor() * (1f64 - lr_lambda))?;
+ let adjusted_grad = (m_hat / (v_hat.sqrt()? + self.params.eps)?)?;
+ let next_theta = (next_theta - (adjusted_grad * lr)?)?;
+ m.set(&next_m)?;
+ v.set(&next_v)?;
+ theta.set(&next_theta)?;
+ }
+ }
+ Ok(())
+ }
+
+ pub fn backward_step(&mut self, loss: &Tensor) -> Result<()> {
+ let grads = loss.backward()?;
+ self.step(&grads)
+ }
}
diff --git a/candle-nn/tests/ops.rs b/candle-nn/tests/ops.rs
index ca82dd1f..e4f0ff83 100644
--- a/candle-nn/tests/ops.rs
+++ b/candle-nn/tests/ops.rs
@@ -1,18 +1,10 @@
-use candle::{Device, Result, Tensor};
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
-pub fn to_vec3_round(t: Tensor, digits: i32) -> Result<Vec<Vec<Vec<f32>>>> {
- let b = 10f32.powi(digits);
- let t = t.to_vec3::<f32>()?;
- let t = t
- .iter()
- .map(|t| {
- t.iter()
- .map(|t| t.iter().map(|t| f32::round(t * b) / b).collect())
- .collect()
- })
- .collect();
- Ok(t)
-}
+mod test_utils;
+use test_utils::to_vec3_round;
+
+use candle::{Device, Result, Tensor};
#[test]
fn softmax() -> Result<()> {
diff --git a/candle-nn/tests/optim.rs b/candle-nn/tests/optim.rs
index 54c378cc..1327ae91 100644
--- a/candle-nn/tests/optim.rs
+++ b/candle-nn/tests/optim.rs
@@ -1,9 +1,12 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
+mod test_utils;
+use test_utils::{to_vec0_round, to_vec2_round};
+
use anyhow::Result;
use candle::{Device, Tensor, Var};
-use candle_nn::{Linear, SGD};
+use candle_nn::{AdamW, Linear, ParamsAdamW, SGD};
#[test]
fn sgd_optim() -> Result<()> {
@@ -65,3 +68,54 @@ fn sgd_linear_regression() -> Result<()> {
assert_eq!(b.to_scalar::<f32>()?, -1.9796902);
Ok(())
}
+
+/* The following test returns the same values as the PyTorch code below.
+import torch
+from torch import optim
+
+w_gen = torch.tensor([[3., 1.]])
+b_gen = torch.tensor([-2.])
+
+sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]])
+sample_ys = sample_xs.matmul(w_gen.t()) + b_gen
+
+m = torch.nn.Linear(2, 1)
+with torch.no_grad():
+ m.weight.zero_()
+ m.bias.zero_()
+optimizer = optim.AdamW(m.parameters(), lr=0.1)
+for _step in range(100):
+ optimizer.zero_grad()
+ ys = m(sample_xs)
+ loss = ((ys - sample_ys)**2).sum()
+ loss.backward()
+ optimizer.step()
+print(m.weight)
+print(m.bias)
+*/
+#[test]
+fn adamw_linear_regression() -> Result<()> {
+ let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
+ let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
+ let gen = Linear::new(w_gen, Some(b_gen));
+ let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
+ let sample_ys = gen.forward(&sample_xs)?;
+
+ // Now use backprop to run a linear regression between samples and get the coefficients back.
+ let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
+ let b = Var::new(0f32, &Device::Cpu)?;
+ let params = ParamsAdamW {
+ lr: 0.1,
+ ..Default::default()
+ };
+ let mut opt = AdamW::new(vec![w.clone(), b.clone()], params)?;
+ let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
+ for _step in 0..100 {
+ let ys = lin.forward(&sample_xs)?;
+ let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
+ opt.backward_step(&loss)?;
+ }
+ assert_eq!(to_vec2_round(w.as_tensor(), 4)?, &[[2.7257, 0.7097]]);
+ assert_eq!(to_vec0_round(b.as_tensor(), 4)?, 0.7873);
+ Ok(())
+}
diff --git a/candle-nn/tests/test_utils.rs b/candle-nn/tests/test_utils.rs
new file mode 100644
index 00000000..bb422cd9
--- /dev/null
+++ b/candle-nn/tests/test_utils.rs
@@ -0,0 +1,39 @@
+#![allow(dead_code)]
+use candle::{Result, Tensor};
+
+pub fn to_vec0_round(t: &Tensor, digits: i32) -> Result<f32> {
+ let b = 10f32.powi(digits);
+ let t = t.to_vec0::<f32>()?;
+ Ok(f32::round(t * b) / b)
+}
+
+pub fn to_vec1_round(t: &Tensor, digits: i32) -> Result<Vec<f32>> {
+ let b = 10f32.powi(digits);
+ let t = t.to_vec1::<f32>()?;
+ let t = t.iter().map(|t| f32::round(t * b) / b).collect();
+ Ok(t)
+}
+
+pub fn to_vec2_round(t: &Tensor, digits: i32) -> Result<Vec<Vec<f32>>> {
+ let b = 10f32.powi(digits);
+ let t = t.to_vec2::<f32>()?;
+ let t = t
+ .iter()
+ .map(|t| t.iter().map(|t| f32::round(t * b) / b).collect())
+ .collect();
+ Ok(t)
+}
+
+pub fn to_vec3_round(t: Tensor, digits: i32) -> Result<Vec<Vec<Vec<f32>>>> {
+ let b = 10f32.powi(digits);
+ let t = t.to_vec3::<f32>()?;
+ let t = t
+ .iter()
+ .map(|t| {
+ t.iter()
+ .map(|t| t.iter().map(|t| f32::round(t * b) / b).collect())
+ .collect()
+ })
+ .collect();
+ Ok(t)
+}