summaryrefslogtreecommitdiff
path: root/candle-nn/tests
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-08-02 14:03:49 +0100
committerGitHub <noreply@github.com>2023-08-02 14:03:49 +0100
commit0902846f25ad35afd532853336f86fff2656e4c0 (patch)
tree92ca6c6da29c69cf78817d3150309efd825ca177 /candle-nn/tests
parente2acbe1e72a300bfb9df1a783cabb34d2cc53c5b (diff)
downloadcandle-0902846f25ad35afd532853336f86fff2656e4c0.tar.gz
candle-0902846f25ad35afd532853336f86fff2656e4c0.tar.bz2
candle-0902846f25ad35afd532853336f86fff2656e4c0.zip
Add the AdamW optimizer. (#307)
* Add the AdamW optimizer. * Add some AdamW test validated against PyTorch.
Diffstat (limited to 'candle-nn/tests')
-rw-r--r--candle-nn/tests/ops.rs20
-rw-r--r--candle-nn/tests/optim.rs56
-rw-r--r--candle-nn/tests/test_utils.rs39
3 files changed, 100 insertions, 15 deletions
diff --git a/candle-nn/tests/ops.rs b/candle-nn/tests/ops.rs
index ca82dd1f..e4f0ff83 100644
--- a/candle-nn/tests/ops.rs
+++ b/candle-nn/tests/ops.rs
@@ -1,18 +1,10 @@
-use candle::{Device, Result, Tensor};
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
-pub fn to_vec3_round(t: Tensor, digits: i32) -> Result<Vec<Vec<Vec<f32>>>> {
- let b = 10f32.powi(digits);
- let t = t.to_vec3::<f32>()?;
- let t = t
- .iter()
- .map(|t| {
- t.iter()
- .map(|t| t.iter().map(|t| f32::round(t * b) / b).collect())
- .collect()
- })
- .collect();
- Ok(t)
-}
+mod test_utils;
+use test_utils::to_vec3_round;
+
+use candle::{Device, Result, Tensor};
#[test]
fn softmax() -> Result<()> {
diff --git a/candle-nn/tests/optim.rs b/candle-nn/tests/optim.rs
index 54c378cc..1327ae91 100644
--- a/candle-nn/tests/optim.rs
+++ b/candle-nn/tests/optim.rs
@@ -1,9 +1,12 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
+mod test_utils;
+use test_utils::{to_vec0_round, to_vec2_round};
+
use anyhow::Result;
use candle::{Device, Tensor, Var};
-use candle_nn::{Linear, SGD};
+use candle_nn::{AdamW, Linear, ParamsAdamW, SGD};
#[test]
fn sgd_optim() -> Result<()> {
@@ -65,3 +68,54 @@ fn sgd_linear_regression() -> Result<()> {
assert_eq!(b.to_scalar::<f32>()?, -1.9796902);
Ok(())
}
+
+/* The following test returns the same values as the PyTorch code below.
+import torch
+from torch import optim
+
+w_gen = torch.tensor([[3., 1.]])
+b_gen = torch.tensor([-2.])
+
+sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]])
+sample_ys = sample_xs.matmul(w_gen.t()) + b_gen
+
+m = torch.nn.Linear(2, 1)
+with torch.no_grad():
+ m.weight.zero_()
+ m.bias.zero_()
+optimizer = optim.AdamW(m.parameters(), lr=0.1)
+for _step in range(100):
+ optimizer.zero_grad()
+ ys = m(sample_xs)
+ loss = ((ys - sample_ys)**2).sum()
+ loss.backward()
+ optimizer.step()
+print(m.weight)
+print(m.bias)
+*/
+#[test]
+fn adamw_linear_regression() -> Result<()> {
+ let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
+ let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
+ let gen = Linear::new(w_gen, Some(b_gen));
+ let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
+ let sample_ys = gen.forward(&sample_xs)?;
+
+ // Now use backprop to run a linear regression between samples and get the coefficients back.
+ let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
+ let b = Var::new(0f32, &Device::Cpu)?;
+ let params = ParamsAdamW {
+ lr: 0.1,
+ ..Default::default()
+ };
+ let mut opt = AdamW::new(vec![w.clone(), b.clone()], params)?;
+ let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
+ for _step in 0..100 {
+ let ys = lin.forward(&sample_xs)?;
+ let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
+ opt.backward_step(&loss)?;
+ }
+ assert_eq!(to_vec2_round(w.as_tensor(), 4)?, &[[2.7257, 0.7097]]);
+ assert_eq!(to_vec0_round(b.as_tensor(), 4)?, 0.7873);
+ Ok(())
+}
diff --git a/candle-nn/tests/test_utils.rs b/candle-nn/tests/test_utils.rs
new file mode 100644
index 00000000..bb422cd9
--- /dev/null
+++ b/candle-nn/tests/test_utils.rs
@@ -0,0 +1,39 @@
+#![allow(dead_code)]
+use candle::{Result, Tensor};
+
+pub fn to_vec0_round(t: &Tensor, digits: i32) -> Result<f32> {
+ let b = 10f32.powi(digits);
+ let t = t.to_vec0::<f32>()?;
+ Ok(f32::round(t * b) / b)
+}
+
+pub fn to_vec1_round(t: &Tensor, digits: i32) -> Result<Vec<f32>> {
+ let b = 10f32.powi(digits);
+ let t = t.to_vec1::<f32>()?;
+ let t = t.iter().map(|t| f32::round(t * b) / b).collect();
+ Ok(t)
+}
+
+pub fn to_vec2_round(t: &Tensor, digits: i32) -> Result<Vec<Vec<f32>>> {
+ let b = 10f32.powi(digits);
+ let t = t.to_vec2::<f32>()?;
+ let t = t
+ .iter()
+ .map(|t| t.iter().map(|t| f32::round(t * b) / b).collect())
+ .collect();
+ Ok(t)
+}
+
+pub fn to_vec3_round(t: Tensor, digits: i32) -> Result<Vec<Vec<Vec<f32>>>> {
+ let b = 10f32.powi(digits);
+ let t = t.to_vec3::<f32>()?;
+ let t = t
+ .iter()
+ .map(|t| {
+ t.iter()
+ .map(|t| t.iter().map(|t| f32::round(t * b) / b).collect())
+ .collect()
+ })
+ .collect();
+ Ok(t)
+}