author    Jeffrey Dallatezza <jeffreydallatezza@gmail.com>    2024-04-29 02:21:53 -0700
committer GitHub <noreply@github.com>    2024-04-29 11:21:53 +0200
commit    a0d03aded1b8c4cfe96f7d6490f5c709c31b76f0 (patch)
tree      d9c1020da5b3f9fb2faff97edfa60251c1c4952e /candle-nn
parent    3bbb88fcb463a6bdbb0e71c7b2d211dd02681493 (diff)
Bug Fix: When converting a tensor to a variable, clone if the tensor is already a variable. (#2124)
* When converting a tensor to a variable, clone if the tensor is already a variable.
* Add a test to ensure training a batch norm works with VarMaps
---------
Co-authored-by: Jeffrey Dallatezza <jeffreydallatezza@Jeffreys-Laptop.local>
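Note that this view is limited to candle-nn, so only the regression test appears below; the clone-on-conversion change itself lands elsewhere in the tree. As a rough, hypothetical usage sketch of the behaviour the test guards (not part of this patch): a batch norm built through a VarBuilder keeps its running statistics registered in the VarMap, so updates made by forward_train are visible through both handles. The sketch follows the test's `candle` crate alias for candle-core; the input values, the `main` wrapper, the `VarMap::save` call and the `model.safetensors` path are illustrative assumptions.

use anyhow::Result;
use candle::{DType, Device, Tensor};
use candle_nn::{batch_norm, BatchNormConfig, VarBuilder, VarMap};

fn main() -> Result<()> {
    // Build a one-channel batch norm whose weights and running statistics
    // are registered in a VarMap via the VarBuilder.
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
    let bn = batch_norm(1, BatchNormConfig::default(), vb)?;

    // A (batch, channels) input; forward_train updates the running statistics.
    let x = Tensor::new(&[[2.0f32], [4.0f32]], &Device::Cpu)?;
    let _normed = bn.forward_train(&x)?;

    // The updated running mean is visible through the layer...
    println!("running mean: {:?}", bn.running_mean().to_vec1::<f32>()?);
    // ...and, because the statistics live in the VarMap, persisting the map
    // also persists them (path is illustrative).
    varmap.save("model.safetensors")?;
    Ok(())
}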
Diffstat (limited to 'candle-nn')
-rw-r--r--  candle-nn/tests/batch_norm.rs | 46
1 file changed, 44 insertions(+), 2 deletions(-)
diff --git a/candle-nn/tests/batch_norm.rs b/candle-nn/tests/batch_norm.rs
index 6fd7361a..8ce49c92 100644
--- a/candle-nn/tests/batch_norm.rs
+++ b/candle-nn/tests/batch_norm.rs
@@ -6,7 +6,7 @@ extern crate accelerate_src;
use anyhow::Result;
use candle::{test_utils, DType, Device, Tensor};
-use candle_nn::BatchNorm;
+use candle_nn::{batch_norm, BatchNorm, BatchNormConfig, VarBuilder, VarMap};
/* The test below has been generated using the following PyTorch code:
import torch
@@ -20,7 +20,7 @@ print(m.running_mean)
print(m.running_var)
*/
#[test]
-fn batch_norm() -> Result<()> {
+fn batch_norm_test() -> Result<()> {
let running_mean = Tensor::zeros(5, DType::F32, &Device::Cpu)?;
let running_var = Tensor::ones(5, DType::F32, &Device::Cpu)?;
let bn = BatchNorm::new_no_bias(5, running_mean.clone(), running_var.clone(), 1e-8)?;
@@ -84,3 +84,45 @@ fn batch_norm() -> Result<()> {
);
Ok(())
}
+
+// This test makes sure that we can train a batch norm layer using a VarMap.
+#[test]
+fn train_batch_norm() -> Result<()> {
+ let vm = VarMap::new();
+ let vb = VarBuilder::from_varmap(&vm, DType::F32, &Device::Cpu);
+ let bn = batch_norm(1, BatchNormConfig::default(), vb)?;
+ // Keep a copy of the original mean so we can later check that it gets updated.
+ let original_mean = bn.running_mean().detach().copy()?;
+ let var_map_mean = {
+ vm.data()
+ .lock()
+ .unwrap()
+ .get("running_mean")
+ .unwrap()
+ .clone()
+ };
+ // Ensure the var map mean is the same as the running mean.
+ assert_eq!(
+ test_utils::to_vec1_round(bn.running_mean(), 4)?,
+ test_utils::to_vec1_round(var_map_mean.as_tensor(), 4)?,
+ );
+ // Train with something guaranteed to be different from the running mean.
+ let mean_plus_one = {
+ let one = original_mean.ones_like()?;
+ original_mean.add(&one)?.reshape((1, 1))?
+ };
+
+ bn.forward_train(&mean_plus_one)?;
+ // Assert that the running mean has been updated.
+ assert_ne!(
+ test_utils::to_vec1_round(bn.running_mean(), 4)?,
+ test_utils::to_vec1_round(&original_mean, 4)?,
+ );
+
+ // Assert that the var map mean has been updated.
+ assert_eq!(
+ test_utils::to_vec1_round(bn.running_mean(), 4)?,
+ test_utils::to_vec1_round(var_map_mean.as_tensor(), 4)?,
+ );
+ Ok(())
+}