summaryrefslogtreecommitdiff
path: root/candle-metal-kernels/src/tests.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-metal-kernels/src/tests.rs')
-rw-r--r--candle-metal-kernels/src/tests.rs99
1 files changed, 70 insertions, 29 deletions
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs
index d0ca8330..067dece8 100644
--- a/candle-metal-kernels/src/tests.rs
+++ b/candle-metal-kernels/src/tests.rs
@@ -806,28 +806,43 @@ fn gemm() {
);
}
-fn run_random<T: Clone>(seed: u64, shape: &[usize], name: &'static str, min: f32, max: f32) -> Vec<T> {
+fn run_random<T: Clone>(name: &'static str, seed: u64, length: usize, a: f32, b: f32) -> Vec<T> {
let device = device();
let fence = device.new_fence();
let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged;
- let length = shape.iter().product::<usize>();
let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
- call_random_uniform(
- &device,
- command_buffer,
- &kernels,
- name,
- seed,
- min,
- max,
- length,
- &output,
- )
- .unwrap();
+ if name.starts_with("rand_uniform") {
+ call_random_uniform(
+ &device,
+ command_buffer,
+ &kernels,
+ name,
+ seed,
+ a,
+ b,
+ length,
+ &output,
+ )
+ .unwrap();
+ } else {
+ call_random_normal(
+ &device,
+ command_buffer,
+ &kernels,
+ name,
+ seed,
+ a,
+ b,
+ length,
+ &output,
+ )
+ .unwrap();
+ }
+
command_buffer.commit();
command_buffer.wait_until_completed();
@@ -837,24 +852,50 @@ fn run_random<T: Clone>(seed: u64, shape: &[usize], name: &'static str, min: f32
#[test]
fn random() {
- use std::fs::File;
- use std::io::prelude::*;
- let shape = vec![1024, 4];
+ fn calc_mean(data: &[f32]) -> f32 {
+ let sum = data.iter().sum::<f32>() as f32;
+ let count = data.len();
+ assert!(count > 0);
+ sum / count as f32
+ }
+
+ fn calc_stddev(data: &[f32]) -> f32 {
+ let mean = calc_mean(data);
+ let count = data.len();
+ assert!(count > 0);
+
+ let variance = data.iter().map(|value| {
+ let diff = mean - (*value as f32);
+ diff * diff
+ }).sum::<f32>() / count as f32;
+
+ variance.sqrt()
+ }
+
+ let shape = vec![1024, 10];
+
+ let length = shape.iter().product::<usize>();
let seed = 299792458;
+
let min = -30.0;
let max = 30.0;
- let results = run_random::<f32>(seed, &shape, "rand_uniform_f32", min, max);
- for &v in &results {
- assert!(v >= min && v <= max);
+ let mean = 100.0;
+ let stddev = 50.0;
+
+ macro_rules! validate_random {
+ ($type:ty) => {
+ let results: Vec<f32> = run_random::<$type>(concat!("rand_uniform_", stringify!($type)), seed, length, min, max).into_iter().map(f32::from).collect();
+ results.iter().for_each(|v| assert!(*v >= min && *v <= max));
+ assert!(calc_mean(&results) > -1.0 && calc_mean(&results) < 1.0);
+
+ let results: Vec<f32> = run_random::<$type>(concat!("rand_normal_", stringify!($type)), seed, length, mean, stddev).into_iter().map(f32::from).collect();
+ assert!((calc_mean(&results) - mean).abs() < mean / 10.0);
+ assert!((calc_stddev(&results) - stddev).abs() < stddev / 10.0);
+ };
}
- // Writing bytes to file for testing with ENT
- // https://www.fourmilab.ch/random/
- // TODO: Remove before merge
- let (head, body, tail) = unsafe { results.align_to::<u8>() };
- assert!(head.is_empty());
- assert!(tail.is_empty());
- let mut file = File::create("test").unwrap();
- file.write_all(body).unwrap();
-}
+ validate_random!(f32);
+ validate_random!(f16);
+ validate_random!(bf16);
+} \ No newline at end of file