author     Laurent Mazare <laurent.mazare@gmail.com>   2024-02-25 20:50:08 +0100
committer  GitHub <noreply@github.com>                 2024-02-25 20:50:08 +0100
commit     1a6043af5123bf9e189063d3baf110b39cf47617 (patch)
tree       3400ac112e92d7d83a0b98a1c66ae046fbbf82df /candle-nn
parent     2f22afd80ef6bc3e0ac7f6d55e4a4dc4dd480190 (diff)
Tweak the VarMap set type. (#1758)
Diffstat (limited to 'candle-nn')
-rw-r--r--  candle-nn/src/var_map.rs  |  2
-rw-r--r--  candle-nn/tests/optim.rs  | 39
2 files changed, 39 insertions, 2 deletions
diff --git a/candle-nn/src/var_map.rs b/candle-nn/src/var_map.rs
index d34cee78..3cb27c63 100644
--- a/candle-nn/src/var_map.rs
+++ b/candle-nn/src/var_map.rs
@@ -70,7 +70,7 @@ impl VarMap {
     ///
     /// If an error is returned, some of the variables might have already been set to their new
     /// values.
-    pub fn set<I: Iterator<Item = (K, V)>, K: AsRef<String>, V: AsRef<Tensor>>(
+    pub fn set<I: Iterator<Item = (K, V)>, K: AsRef<str>, V: AsRef<Tensor>>(
         &mut self,
         iter: I,
     ) -> Result<()> {
diff --git a/candle-nn/tests/optim.rs b/candle-nn/tests/optim.rs
index 841f65c8..4eb14ed8 100644
--- a/candle-nn/tests/optim.rs
+++ b/candle-nn/tests/optim.rs
@@ -7,7 +7,7 @@ extern crate accelerate_src;
 use candle::test_utils::{to_vec0_round, to_vec2_round};
 
 use anyhow::Result;
-use candle::{Device, Tensor, Var};
+use candle::{DType, Device, Tensor, Var};
 use candle_nn::{AdamW, Linear, Module, Optimizer, ParamsAdamW, SGD};
 
 #[test]
@@ -121,3 +121,40 @@ fn adamw_linear_regression() -> Result<()> {
     assert_eq!(to_vec0_round(b.as_tensor(), 4)?, 0.7873);
     Ok(())
 }
+
+#[test]
+fn adamw_linear_regression_varmap() -> Result<()> {
+    use candle_nn::Init::Const;
+
+    // Similar as the previous test but using a VarMap.
+    let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
+    let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
+    let gen = Linear::new(w_gen, Some(b_gen));
+    let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
+    let sample_ys = gen.forward(&sample_xs)?;
+
+    let mut var_map = candle_nn::VarMap::new();
+
+    let w = var_map.get((1, 2), "w", Const(0.), DType::F32, &Device::Cpu)?;
+    let b = var_map.get((), "b", Const(0.), DType::F32, &Device::Cpu)?;
+    let params = ParamsAdamW {
+        lr: 0.1,
+        ..Default::default()
+    };
+    let mut opt = AdamW::new(var_map.all_vars(), params)?;
+    let lin = Linear::new(w, Some(b));
+    for _step in 0..100 {
+        let ys = lin.forward(&sample_xs)?;
+        let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
+        opt.backward_step(&loss)?;
+    }
+    assert_eq!(to_vec2_round(lin.weight(), 4)?, &[[2.7257, 0.7097]]);
+    assert_eq!(to_vec0_round(lin.bias().unwrap(), 4)?, 0.7873);
+
+    var_map.set([("w", Tensor::zeros((1, 2), DType::F32, &Device::Cpu)?)].into_iter())?;
+    var_map.set([("b", Tensor::ones((), DType::F32, &Device::Cpu)?)].into_iter())?;
+
+    assert_eq!(to_vec2_round(lin.weight(), 4)?, &[[0., 0.]]);
+    assert_eq!(to_vec0_round(lin.bias().unwrap(), 4)?, 1.);
+    Ok(())
+}
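
The practical effect of relaxing the key bound from K: AsRef<String> to K: AsRef<str> is that plain string slices can be passed as keys, which the old bound rejected (since &str does not implement AsRef<String>). Below is a minimal, self-contained sketch, not part of this commit, of calling VarMap::set after the change; the variable names and shapes are illustrative only.

// Minimal sketch (not part of the patch): with K: AsRef<str>, &str keys can
// be handed to VarMap::set directly; under the old K: AsRef<String> bound
// this call would not compile with plain string-literal keys.
use anyhow::Result;
use candle::{DType, Device, Tensor};
use candle_nn::{Init, VarMap};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let mut var_map = VarMap::new();

    // Register two variables under string paths (shapes are illustrative).
    let _w = var_map.get((1, 2), "w", Init::Const(0.), DType::F32, &dev)?;
    let _b = var_map.get((), "b", Init::Const(0.), DType::F32, &dev)?;

    // Overwrite them by key; &str literals now satisfy the K: AsRef<str> bound.
    var_map.set([("w", Tensor::zeros((1, 2), DType::F32, &dev)?)].into_iter())?;
    var_map.set([("b", Tensor::ones((), DType::F32, &dev)?)].into_iter())?;
    Ok(())
}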