author     Laurent Mazare <laurent.mazare@gmail.com>  2023-08-02 14:03:49 +0100
committer  GitHub <noreply@github.com>                2023-08-02 14:03:49 +0100
commit     0902846f25ad35afd532853336f86fff2656e4c0 (patch)
tree       92ca6c6da29c69cf78817d3150309efd825ca177 /candle-nn/tests
parent     e2acbe1e72a300bfb9df1a783cabb34d2cc53c5b (diff)
Add the AdamW optimizer. (#307)
* Add the AdamW optimizer.
* Add an AdamW test validated against PyTorch.
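For reference, the update this optimizer implements is the decoupled-weight-decay variant of Adam from Loshchilov & Hutter, "Decoupled Weight Decay Regularization". Below is a minimal sketch of one per-parameter step, written out in Rust for illustration; the names AdamWState and adamw_step are hypothetical and do not appear in this patch or in candle-nn.

// Minimal sketch of one AdamW update for a single f32 parameter.
// Standard algorithm, not the code added by this commit; struct and
// function names here are illustrative only.
struct AdamWState {
    m: f32, // first-moment (mean) estimate
    v: f32, // second-moment (uncentered variance) estimate
    t: i32, // number of steps taken so far
}

fn adamw_step(
    theta: f32,        // current parameter value
    grad: f32,         // gradient of the loss w.r.t. theta
    state: &mut AdamWState,
    lr: f32,           // learning rate
    beta1: f32,        // decay rate for the first moment
    beta2: f32,        // decay rate for the second moment
    eps: f32,          // numerical-stability epsilon
    weight_decay: f32, // decoupled weight-decay coefficient
) -> f32 {
    state.t += 1;
    state.m = beta1 * state.m + (1.0 - beta1) * grad;
    state.v = beta2 * state.v + (1.0 - beta2) * grad * grad;
    // Bias correction compensates for the zero initialization of m and v.
    let m_hat = state.m / (1.0 - beta1.powi(state.t));
    let v_hat = state.v / (1.0 - beta2.powi(state.t));
    // Unlike Adam with L2 regularization, the weight-decay term is applied
    // directly to the parameter, outside the adaptive rescaling.
    theta - lr * (m_hat / (v_hat.sqrt() + eps) + weight_decay * theta)
}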
Diffstat (limited to 'candle-nn/tests')
-rw-r--r--  candle-nn/tests/ops.rs        | 20
-rw-r--r--  candle-nn/tests/optim.rs      | 56
-rw-r--r--  candle-nn/tests/test_utils.rs | 39
3 files changed, 100 insertions, 15 deletions
diff --git a/candle-nn/tests/ops.rs b/candle-nn/tests/ops.rs
index ca82dd1f..e4f0ff83 100644
--- a/candle-nn/tests/ops.rs
+++ b/candle-nn/tests/ops.rs
@@ -1,18 +1,10 @@
-use candle::{Device, Result, Tensor};
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
 
-pub fn to_vec3_round(t: Tensor, digits: i32) -> Result<Vec<Vec<Vec<f32>>>> {
-    let b = 10f32.powi(digits);
-    let t = t.to_vec3::<f32>()?;
-    let t = t
-        .iter()
-        .map(|t| {
-            t.iter()
-                .map(|t| t.iter().map(|t| f32::round(t * b) / b).collect())
-                .collect()
-        })
-        .collect();
-    Ok(t)
-}
+mod test_utils;
+use test_utils::to_vec3_round;
+
+use candle::{Device, Result, Tensor};
 
 #[test]
 fn softmax() -> Result<()> {
diff --git a/candle-nn/tests/optim.rs b/candle-nn/tests/optim.rs
index 54c378cc..1327ae91 100644
--- a/candle-nn/tests/optim.rs
+++ b/candle-nn/tests/optim.rs
@@ -1,9 +1,12 @@
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;
 
+mod test_utils;
+use test_utils::{to_vec0_round, to_vec2_round};
+
 use anyhow::Result;
 use candle::{Device, Tensor, Var};
-use candle_nn::{Linear, SGD};
+use candle_nn::{AdamW, Linear, ParamsAdamW, SGD};
 
 #[test]
 fn sgd_optim() -> Result<()> {
@@ -65,3 +68,54 @@ fn sgd_linear_regression() -> Result<()> {
     assert_eq!(b.to_scalar::<f32>()?, -1.9796902);
     Ok(())
 }
+
+/* The following test returns the same values as the PyTorch code below.
+import torch
+from torch import optim
+
+w_gen = torch.tensor([[3., 1.]])
+b_gen = torch.tensor([-2.])
+
+sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]])
+sample_ys = sample_xs.matmul(w_gen.t()) + b_gen
+
+m = torch.nn.Linear(2, 1)
+with torch.no_grad():
+    m.weight.zero_()
+    m.bias.zero_()
+optimizer = optim.AdamW(m.parameters(), lr=0.1)
+for _step in range(100):
+    optimizer.zero_grad()
+    ys = m(sample_xs)
+    loss = ((ys - sample_ys)**2).sum()
+    loss.backward()
+    optimizer.step()
+print(m.weight)
+print(m.bias)
+*/
+#[test]
+fn adamw_linear_regression() -> Result<()> {
+    let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
+    let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
+    let gen = Linear::new(w_gen, Some(b_gen));
+    let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
+    let sample_ys = gen.forward(&sample_xs)?;
+
+    // Now use backprop to run a linear regression between samples and get the coefficients back.
+    let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
+    let b = Var::new(0f32, &Device::Cpu)?;
+    let params = ParamsAdamW {
+        lr: 0.1,
+        ..Default::default()
+    };
+    let mut opt = AdamW::new(vec![w.clone(), b.clone()], params)?;
+    let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
+    for _step in 0..100 {
+        let ys = lin.forward(&sample_xs)?;
+        let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
+        opt.backward_step(&loss)?;
+    }
+    assert_eq!(to_vec2_round(w.as_tensor(), 4)?, &[[2.7257, 0.7097]]);
+    assert_eq!(to_vec0_round(b.as_tensor(), 4)?, 0.7873);
+    Ok(())
+}
diff --git a/candle-nn/tests/test_utils.rs b/candle-nn/tests/test_utils.rs
new file mode 100644
index 00000000..bb422cd9
--- /dev/null
+++ b/candle-nn/tests/test_utils.rs
@@ -0,0 +1,39 @@
+#![allow(dead_code)]
+use candle::{Result, Tensor};
+
+pub fn to_vec0_round(t: &Tensor, digits: i32) -> Result<f32> {
+    let b = 10f32.powi(digits);
+    let t = t.to_vec0::<f32>()?;
+    Ok(f32::round(t * b) / b)
+}
+
+pub fn to_vec1_round(t: &Tensor, digits: i32) -> Result<Vec<f32>> {
+    let b = 10f32.powi(digits);
+    let t = t.to_vec1::<f32>()?;
+    let t = t.iter().map(|t| f32::round(t * b) / b).collect();
+    Ok(t)
+}
+
+pub fn to_vec2_round(t: &Tensor, digits: i32) -> Result<Vec<Vec<f32>>> {
+    let b = 10f32.powi(digits);
+    let t = t.to_vec2::<f32>()?;
+    let t = t
+        .iter()
+        .map(|t| t.iter().map(|t| f32::round(t * b) / b).collect())
+        .collect();
+    Ok(t)
+}
+
+pub fn to_vec3_round(t: Tensor, digits: i32) -> Result<Vec<Vec<Vec<f32>>>> {
+    let b = 10f32.powi(digits);
+    let t = t.to_vec3::<f32>()?;
+    let t = t
+        .iter()
+        .map(|t| {
+            t.iter()
+                .map(|t| t.iter().map(|t| f32::round(t * b) / b).collect())
+                .collect()
+        })
+        .collect();
+    Ok(t)
+}
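As a usage note, the optimizer can also be driven on its own, outside a full regression test. The sketch below uses only the API exercised by the diff above (Var::new, ParamsAdamW, AdamW::new, backward_step); the scalar objective, target value, and step count are made up for illustration, and the final comment assumes ParamsAdamW's defaults mirror PyTorch's AdamW in including a small nonzero weight decay.

// Minimal sketch: minimize (x - 3)^2 for a single trainable scalar.
use candle::{Device, Result, Tensor, Var};
use candle_nn::{AdamW, ParamsAdamW};

fn main() -> Result<()> {
    // A Var wraps a tensor so that gradients are tracked for it.
    let x = Var::new(0f32, &Device::Cpu)?;
    let params = ParamsAdamW {
        lr: 0.1,
        ..Default::default()
    };
    let mut opt = AdamW::new(vec![x.clone()], params)?;
    let target = Tensor::new(3f32, &Device::Cpu)?;
    for _step in 0..200 {
        let loss = x.as_tensor().sub(&target)?.sqr()?;
        // Backpropagates through the loss and updates every tracked Var.
        opt.backward_step(&loss)?;
    }
    // Should print a value close to 3.0; with a nonzero default weight
    // decay it will sit slightly below the exact minimum.
    println!("{}", x.as_tensor().to_scalar::<f32>()?);
    Ok(())
}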