author | Jeffrey Dallatezza <jeffreydallatezza@gmail.com> | 2024-04-29 02:21:53 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-29 11:21:53 +0200 |
commit | a0d03aded1b8c4cfe96f7d6490f5c709c31b76f0 (patch) | |
tree | d9c1020da5b3f9fb2faff97edfa60251c1c4952e /candle-nn | |
parent | 3bbb88fcb463a6bdbb0e71c7b2d211dd02681493 (diff) | |
download | candle-a0d03aded1b8c4cfe96f7d6490f5c709c31b76f0.tar.gz candle-a0d03aded1b8c4cfe96f7d6490f5c709c31b76f0.tar.bz2 candle-a0d03aded1b8c4cfe96f7d6490f5c709c31b76f0.zip |
Bug Fix: When converting a tensor to a variable, clone if the tensor is already a variable. (#2124)
* When converting a tensor to a variable, clone if the tensor is already a variable.
* Add a test to ensure training a batch norm works with VarMaps
---------
Co-authored-by: Jeffrey Dallatezza <jeffreydallatezza@Jeffreys-Laptop.local>
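In user terms, the fix means that training a batch-norm layer built from a `VarMap`-backed `VarBuilder` updates the variables held by the `VarMap` itself rather than a detached copy (that reading is inferred from the commit message and the test added below). A minimal sketch of that scenario; the layer size, input, and printout are illustrative, and the imports mirror the test file, where `candle-core` is exposed under the name `candle`:

```rust
use anyhow::Result;
use candle::{DType, Device, Tensor};
use candle_nn::{batch_norm, BatchNormConfig, VarBuilder, VarMap};

fn main() -> Result<()> {
    // Build a batch norm layer whose weights and running stats are owned by a VarMap.
    let vm = VarMap::new();
    let vb = VarBuilder::from_varmap(&vm, DType::F32, &Device::Cpu);
    let bn = batch_norm(1, BatchNormConfig::default(), vb)?;

    // One training step with a non-zero batch shifts the running mean away from zero.
    let x = Tensor::ones((4, 1), DType::F32, &Device::Cpu)?;
    bn.forward_train(&x)?;

    // With the fix, the update is also visible through the VarMap entry, so code that
    // later saves or optimizes the VarMap sees the trained statistics.
    let running_mean = vm.data().lock().unwrap().get("running_mean").unwrap().clone();
    println!("{:?}", running_mean.as_tensor().to_vec1::<f32>()?);
    Ok(())
}
```

The `assert_eq!` checks in the new test below verify exactly this behaviour: the `VarMap` entry and `bn.running_mean()` stay in sync after `forward_train`.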
Diffstat (limited to 'candle-nn')
-rw-r--r-- | candle-nn/tests/batch_norm.rs | 46 |
1 file changed, 44 insertions, 2 deletions
```diff
diff --git a/candle-nn/tests/batch_norm.rs b/candle-nn/tests/batch_norm.rs
index 6fd7361a..8ce49c92 100644
--- a/candle-nn/tests/batch_norm.rs
+++ b/candle-nn/tests/batch_norm.rs
@@ -6,7 +6,7 @@ extern crate accelerate_src;
 
 use anyhow::Result;
 use candle::{test_utils, DType, Device, Tensor};
-use candle_nn::BatchNorm;
+use candle_nn::{batch_norm, BatchNorm, BatchNormConfig, VarBuilder, VarMap};
 
 /* The test below has been generated using the following PyTorch code:
 import torch
@@ -20,7 +20,7 @@ print(m.running_mean)
 print(m.running_var)
 */
 #[test]
-fn batch_norm() -> Result<()> {
+fn batch_norm_test() -> Result<()> {
     let running_mean = Tensor::zeros(5, DType::F32, &Device::Cpu)?;
     let running_var = Tensor::ones(5, DType::F32, &Device::Cpu)?;
     let bn = BatchNorm::new_no_bias(5, running_mean.clone(), running_var.clone(), 1e-8)?;
@@ -84,3 +84,45 @@ fn batch_norm() -> Result<()> {
     );
     Ok(())
 }
+
+// This test makes sure that we can train a batch norm layer using a VarMap.
+#[test]
+fn train_batch_norm() -> Result<()> {
+    let vm = VarMap::new();
+    let vb = VarBuilder::from_varmap(&vm, DType::F32, &Device::Cpu);
+    let bn = batch_norm(1, BatchNormConfig::default(), vb)?;
+    // Get a copy of the original mean to ensure it is being updated.
+    let original_mean = bn.running_mean().detach().copy()?;
+    let var_map_mean = {
+        vm.data()
+            .lock()
+            .unwrap()
+            .get("running_mean")
+            .unwrap()
+            .clone()
+    };
+    // Ensure the var map mean is the same as the running mean.
+    assert_eq!(
+        test_utils::to_vec1_round(bn.running_mean(), 4)?,
+        test_utils::to_vec1_round(var_map_mean.as_tensor(), 4)?,
+    );
+    // Train with a something guaranteed to be different from the running mean.
+    let mean_plus_one = {
+        let one = original_mean.ones_like()?;
+        original_mean.add(&one)?.reshape((1, 1))?
+    };
+
+    bn.forward_train(&mean_plus_one)?;
+    // Assert that the running mean has been updated.
+    assert_ne!(
+        test_utils::to_vec1_round(bn.running_mean(), 4)?,
+        test_utils::to_vec1_round(&original_mean, 4)?,
+    );
+
+    // Assert that the var map mean has been updated.
+    assert_eq!(
+        test_utils::to_vec1_round(bn.running_mean(), 4)?,
+        test_utils::to_vec1_round(var_map_mean.as_tensor(), 4)?,
+    );
+    Ok(())
+}
```
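To run only the new test locally, standard cargo filtering along the lines of `cargo test -p candle-nn --test batch_norm train_batch_norm` should work; the package and test-target names are taken from the diffstat above.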