diff options
Diffstat (limited to 'candle-metal-kernels/src/tests.rs')
-rw-r--r-- | candle-metal-kernels/src/tests.rs | 99 |
1 files changed, 70 insertions, 29 deletions
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index d0ca8330..067dece8 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -806,28 +806,43 @@ fn gemm() { ); } -fn run_random<T: Clone>(seed: u64, shape: &[usize], name: &'static str, min: f32, max: f32) -> Vec<T> { +fn run_random<T: Clone>(name: &'static str, seed: u64, length: usize, a: f32, b: f32) -> Vec<T> { let device = device(); let fence = device.new_fence(); let kernels = Kernels::new(fence); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let options = MTLResourceOptions::StorageModeManaged; - let length = shape.iter().product::<usize>(); let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options); - call_random_uniform( - &device, - command_buffer, - &kernels, - name, - seed, - min, - max, - length, - &output, - ) - .unwrap(); + if name.starts_with("rand_uniform") { + call_random_uniform( + &device, + command_buffer, + &kernels, + name, + seed, + a, + b, + length, + &output, + ) + .unwrap(); + } else { + call_random_normal( + &device, + command_buffer, + &kernels, + name, + seed, + a, + b, + length, + &output, + ) + .unwrap(); + } + command_buffer.commit(); command_buffer.wait_until_completed(); @@ -837,24 +852,50 @@ fn run_random<T: Clone>(seed: u64, shape: &[usize], name: &'static str, min: f32 #[test] fn random() { - use std::fs::File; - use std::io::prelude::*; - let shape = vec![1024, 4]; + fn calc_mean(data: &[f32]) -> f32 { + let sum = data.iter().sum::<f32>() as f32; + let count = data.len(); + assert!(count > 0); + sum / count as f32 + } + + fn calc_stddev(data: &[f32]) -> f32 { + let mean = calc_mean(data); + let count = data.len(); + assert!(count > 0); + + let variance = data.iter().map(|value| { + let diff = mean - (*value as f32); + diff * diff + }).sum::<f32>() / count as f32; + + variance.sqrt() + } + + let shape = vec![1024, 10]; + + let length = shape.iter().product::<usize>(); let seed = 299792458; + let min = -30.0; let max = 30.0; - let results = run_random::<f32>(seed, &shape, "rand_uniform_f32", min, max); - for &v in &results { - assert!(v >= min && v <= max); + let mean = 100.0; + let stddev = 50.0; + + macro_rules! validate_random { + ($type:ty) => { + let results: Vec<f32> = run_random::<$type>(concat!("rand_uniform_", stringify!($type)), seed, length, min, max).into_iter().map(f32::from).collect(); + results.iter().for_each(|v| assert!(*v >= min && *v <= max)); + assert!(calc_mean(&results) > -1.0 && calc_mean(&results) < 1.0); + + let results: Vec<f32> = run_random::<$type>(concat!("rand_normal_", stringify!($type)), seed, length, mean, stddev).into_iter().map(f32::from).collect(); + assert!((calc_mean(&results) - mean).abs() < mean / 10.0); + assert!((calc_stddev(&results) - stddev).abs() < stddev / 10.0); + }; } - // Writing bytes to file for testing with ENT - // https://www.fourmilab.ch/random/ - // TODO: Remove before merge - let (head, body, tail) = unsafe { results.align_to::<u8>() }; - assert!(head.is_empty()); - assert!(tail.is_empty()); - let mut file = File::create("test").unwrap(); - file.write_all(body).unwrap(); -} + validate_random!(f32); + validate_random!(f16); + validate_random!(bf16); +}
\ No newline at end of file |