diff options
author | Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> | 2024-01-12 07:26:42 +0100 |
---|---|---|
committer | Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> | 2024-01-12 07:26:42 +0100 |
commit | e06e8d0dbea3a052195f4ca27fb5ddcdbf1cd30c (patch) | |
tree | 9158095aea69b5a9c50299358a3dbc4e5035a758 /candle-metal-kernels | |
parent | e63bb8661beb4ea139f4e7f1d85f56907d918b2b (diff) | |
download | candle-e06e8d0dbea3a052195f4ca27fb5ddcdbf1cd30c.tar.gz candle-e06e8d0dbea3a052195f4ca27fb5ddcdbf1cd30c.tar.bz2 candle-e06e8d0dbea3a052195f4ca27fb5ddcdbf1cd30c.zip |
fmt
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r-- | candle-metal-kernels/src/tests.rs | 38 |
1 files changed, 29 insertions, 9 deletions
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index b15505f7..775ee0fa 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -975,7 +975,6 @@ fn run_random<T: Clone>(name: &'static str, seed: u64, length: usize, a: f32, b: .unwrap(); } - command_buffer.commit(); command_buffer.wait_until_completed(); @@ -984,7 +983,6 @@ fn run_random<T: Clone>(name: &'static str, seed: u64, length: usize, a: f32, b: #[test] fn random() { - fn calc_mean(data: &[f32]) -> f32 { let sum = data.iter().sum::<f32>() as f32; let count = data.len(); @@ -997,10 +995,14 @@ fn random() { let count = data.len(); assert!(count > 0); - let variance = data.iter().map(|value| { - let diff = mean - (*value as f32); - diff * diff - }).sum::<f32>() / count as f32; + let variance = data + .iter() + .map(|value| { + let diff = mean - (*value as f32); + diff * diff + }) + .sum::<f32>() + / count as f32; variance.sqrt() } @@ -1017,11 +1019,29 @@ fn random() { macro_rules! validate_random { ($type:ty) => { - let results: Vec<f32> = run_random::<$type>(concat!("rand_uniform_", stringify!($type)), seed, length, min, max).into_iter().map(f32::from).collect(); + let results: Vec<f32> = run_random::<$type>( + concat!("rand_uniform_", stringify!($type)), + seed, + length, + min, + max, + ) + .into_iter() + .map(f32::from) + .collect(); results.iter().for_each(|v| assert!(*v >= min && *v <= max)); assert!(calc_mean(&results) > -1.0 && calc_mean(&results) < 1.0); - let results: Vec<f32> = run_random::<$type>(concat!("rand_normal_", stringify!($type)), seed, length, mean, stddev).into_iter().map(f32::from).collect(); + let results: Vec<f32> = run_random::<$type>( + concat!("rand_normal_", stringify!($type)), + seed, + length, + mean, + stddev, + ) + .into_iter() + .map(f32::from) + .collect(); assert!((calc_mean(&results) - mean).abs() < mean / 10.0); assert!((calc_stddev(&results) - stddev).abs() < stddev / 10.0); }; @@ -1030,4 +1050,4 @@ fn random() { validate_random!(f32); validate_random!(f16); validate_random!(bf16); -}
\ No newline at end of file +} |