summaryrefslogtreecommitdiff
path: root/candle-metal-kernels
diff options
context:
space:
mode:
authorIvar Flakstad <69173633+ivarflakstad@users.noreply.github.com>2024-01-12 07:26:42 +0100
committerIvar Flakstad <69173633+ivarflakstad@users.noreply.github.com>2024-01-12 07:26:42 +0100
commite06e8d0dbea3a052195f4ca27fb5ddcdbf1cd30c (patch)
tree9158095aea69b5a9c50299358a3dbc4e5035a758 /candle-metal-kernels
parente63bb8661beb4ea139f4e7f1d85f56907d918b2b (diff)
downloadcandle-e06e8d0dbea3a052195f4ca27fb5ddcdbf1cd30c.tar.gz
candle-e06e8d0dbea3a052195f4ca27fb5ddcdbf1cd30c.tar.bz2
candle-e06e8d0dbea3a052195f4ca27fb5ddcdbf1cd30c.zip
fmt
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r--candle-metal-kernels/src/tests.rs38
1 files changed, 29 insertions, 9 deletions
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs
index b15505f7..775ee0fa 100644
--- a/candle-metal-kernels/src/tests.rs
+++ b/candle-metal-kernels/src/tests.rs
@@ -975,7 +975,6 @@ fn run_random<T: Clone>(name: &'static str, seed: u64, length: usize, a: f32, b:
.unwrap();
}
-
command_buffer.commit();
command_buffer.wait_until_completed();
@@ -984,7 +983,6 @@ fn run_random<T: Clone>(name: &'static str, seed: u64, length: usize, a: f32, b:
#[test]
fn random() {
-
fn calc_mean(data: &[f32]) -> f32 {
let sum = data.iter().sum::<f32>() as f32;
let count = data.len();
@@ -997,10 +995,14 @@ fn random() {
let count = data.len();
assert!(count > 0);
- let variance = data.iter().map(|value| {
- let diff = mean - (*value as f32);
- diff * diff
- }).sum::<f32>() / count as f32;
+ let variance = data
+ .iter()
+ .map(|value| {
+ let diff = mean - (*value as f32);
+ diff * diff
+ })
+ .sum::<f32>()
+ / count as f32;
variance.sqrt()
}
@@ -1017,11 +1019,29 @@ fn random() {
macro_rules! validate_random {
($type:ty) => {
- let results: Vec<f32> = run_random::<$type>(concat!("rand_uniform_", stringify!($type)), seed, length, min, max).into_iter().map(f32::from).collect();
+ let results: Vec<f32> = run_random::<$type>(
+ concat!("rand_uniform_", stringify!($type)),
+ seed,
+ length,
+ min,
+ max,
+ )
+ .into_iter()
+ .map(f32::from)
+ .collect();
results.iter().for_each(|v| assert!(*v >= min && *v <= max));
assert!(calc_mean(&results) > -1.0 && calc_mean(&results) < 1.0);
- let results: Vec<f32> = run_random::<$type>(concat!("rand_normal_", stringify!($type)), seed, length, mean, stddev).into_iter().map(f32::from).collect();
+ let results: Vec<f32> = run_random::<$type>(
+ concat!("rand_normal_", stringify!($type)),
+ seed,
+ length,
+ mean,
+ stddev,
+ )
+ .into_iter()
+ .map(f32::from)
+ .collect();
assert!((calc_mean(&results) - mean).abs() < mean / 10.0);
assert!((calc_stddev(&results) - stddev).abs() < stddev / 10.0);
};
@@ -1030,4 +1050,4 @@ fn random() {
validate_random!(f32);
validate_random!(f16);
validate_random!(bf16);
-} \ No newline at end of file
+}