diff options
author | Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> | 2024-01-05 21:18:12 +0100 |
---|---|---|
committer | Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> | 2024-01-07 11:39:46 +0100 |
commit | 6bf52b9fdf82ad775611e82924d73172660a605e (patch) | |
tree | 00b898aeb6c097d9c996053b5485056196214cb8 /candle-metal-kernels | |
parent | 955e63c8033af247c51b7ada1ab2c12fa7170cf5 (diff) | |
download | candle-6bf52b9fdf82ad775611e82924d73172660a605e.tar.gz candle-6bf52b9fdf82ad775611e82924d73172660a605e.tar.bz2 candle-6bf52b9fdf82ad775611e82924d73172660a605e.zip |
Gaussian normal distribution of PRNG via Box-Muller transform
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r-- | candle-metal-kernels/src/lib.rs | 58 | ||||
-rw-r--r-- | candle-metal-kernels/src/random.metal | 107 | ||||
-rw-r--r-- | candle-metal-kernels/src/tests.rs | 99 |
3 files changed, 178 insertions, 86 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 04442c8a..e2603b3b 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1415,7 +1415,6 @@ pub fn call_gemm( height: 1, depth: 1, }; - // println!("grid size {grid_size:?} group size {group_size:?}"); encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read); encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); @@ -1588,44 +1587,47 @@ pub fn call_random_uniform( "min must be less than max".to_string(), )); } + let pipeline = kernels.load_pipeline(device, Source::Random, name)?; + let encoder = command_buffer.new_compute_command_encoder(); - let size: usize = match name { - "rand_uniform_f32" => 4, - "rand_uniform_f16" | "rand_uniform_bf16" => 2, - _ => Err(MetalKernelError::LoadLibraryError(format!( - "{name} is not a valid kernel for random" - )))?, - }; + let odd = (length % 2 != 0) as usize; + let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd); - let elems_per_key = length; - let bytes_per_key = size * elems_per_key; + encoder.wait_for_fence(&kernels.fence); + encoder.set_compute_pipeline_state(&pipeline); - let out_per_key = (bytes_per_key + 4 - 1) / 4; - let half_size = out_per_key / 2; - let odd = length % 2 != 0; + set_params!(encoder, (length, seed, min, max, buffer)); + encoder.use_resource(buffer, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); + encoder.end_encoding(); + + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_random_normal( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + seed: u64, + mean: f32, + stddev: f32, + length: usize, + buffer: &Buffer, +) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Random, name)?; let encoder = command_buffer.new_compute_command_encoder(); - let thread_group_count = MTLSize { - width: length as u64, - height: half_size as u64 + odd as u64, - depth: 1, - }; - let threads = std::cmp::min( - (half_size + odd as usize) as NSUInteger, - pipeline.max_total_threads_per_threadgroup(), - ); - let thread_group_size = MTLSize { - width: threads, - height: 1, - depth: 1, - }; + let odd = (length % 2 != 0) as usize; + let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd); encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); - set_params!(encoder, (length, seed, min, max, buffer)); + set_params!(encoder, (length, seed, mean, stddev, buffer)); encoder.use_resource(buffer, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); diff --git a/candle-metal-kernels/src/random.metal b/candle-metal-kernels/src/random.metal index 1604123d..5369e8e2 100644 --- a/candle-metal-kernels/src/random.metal +++ b/candle-metal-kernels/src/random.metal @@ -33,29 +33,34 @@ static constexpr constant uint64_t PHI[16] = { // Combined Tausworthe and LCG Random Number Generator. // https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-37-efficient-random-number-generation-and-application // https://indico.cern.ch/event/93877/contributions/2118070/attachments/1104200/1575343/acat3_revised_final.pdf -class HybridTaus { -private: - thread float seed; +struct HybridTaus { + + float state; + + HybridTaus() thread = default; + HybridTaus() threadgroup = default; + HybridTaus() device = default; + HybridTaus() constant = default; // Generate seeds for each thread. - thread uint4 seed_per_thread(const ulong4 seeds) { + METAL_FUNC static uint4 seed_per_thread(const ulong4 seeds) { return uint4(ulong4(seeds) * ulong4(PHI[0], PHI[1], PHI[2], PHI[3]) * ulong4(1099087573UL)); } // Tausworthe generator. - thread uint taus(const uint z, const int3 s, const uint M) { + METAL_FUNC static uint taus(const uint z, const int3 s, const uint M) { uint b = (((z << s.x) ^ z) >> s.y); return (((z & M) << s.z) ^ b); } // LCG generator. - thread uint lcg(const uint z) { + METAL_FUNC static uint lcg(const uint z) { return (1664525 * z + 1013904223UL); } -public: - thread HybridTaus(const ulong4 seeds) { - uint4 seed = this->seed_per_thread(seeds); + // Initialize the RNG state. + METAL_FUNC static HybridTaus init(const ulong4 seeds) { + uint4 seed = seed_per_thread(seeds); // Seed #1 uint z1 = taus(seed.x, S1, 4294967294UL); @@ -84,52 +89,96 @@ public: z3 = taus(r1, S3, 429496280UL); z4 = lcg(r1); - this->seed = (z1^z2^z3^z4) * UNIF01_INV32; + HybridTaus rng; + rng.state = (z1^z2^z3^z4) * UNIF01_INV32; + return rng; } - thread float rand() { - uint seed = this->seed * UNIF01_NORM32; + METAL_FUNC float rand() { + uint seed = this->state * UNIF01_NORM32; uint z1 = taus(seed, S1, 429496729UL); uint z2 = taus(seed, S2, 4294967288UL); uint z3 = taus(seed, S3, 429496280UL); uint z4 = lcg(seed); - thread float old_seed = this->seed; - this->seed = (z1^z2^z3^z4) * UNIF01_INV32; - return old_seed; + thread float result = this->state; + this->state = (z1^z2^z3^z4) * UNIF01_INV32; + return result; } }; template<typename T> METAL_FUNC void rand_uniform( - constant size_t &elem_count, + constant size_t &size, constant ulong &seed, constant float &min, constant float &max, device T *out, uint tid [[thread_position_in_grid]] ) { - if (tid >= elem_count) { + if (tid >= size) { return; } float diff = max - min; - HybridTaus rng = HybridTaus({seed, tid, 1, 1}); + HybridTaus rng = HybridTaus::init({seed, tid, 1, 1}); out[tid] = static_cast<T>(rng.rand() * diff + min); + out[size - tid] = static_cast<T>(rng.rand() * diff + min); } -#define UNIFORM_OP(NAME, T) \ -kernel void rand_uniform_##NAME( \ - constant size_t &elem_count, \ - constant ulong &seed, \ - constant float &min, \ - constant float &max, \ - device T *out, \ - uint tid [[thread_position_in_grid]] \ -) { \ - rand_uniform<T>(elem_count, seed, min, max, out, tid); \ -} \ +// Create Gaussian normal distribution using Box-Muller transform: +// https://en.wikipedia.org/wiki/Box–Muller_transform +template<typename T> METAL_FUNC void normal( + constant size_t &size, + constant ulong &seed, + constant float &mean, + constant float &stddev, + device T *out, + uint tid [[thread_position_in_grid]] +) { + if (tid >= size) { + return; + } + HybridTaus rng = HybridTaus::init({seed, tid, 1, 1}); + float u1 = rng.rand(); + float u2 = rng.rand(); + + float cosval; + float sinval = sincos(u1 * TWO_PI, cosval); + float mag = stddev * sqrt(-2.0 * log(u1)); + float z0 = mag * cosval + mean; + float z1 = mag * sinval + mean; + + out[tid] = static_cast<T>(z0); + out[size - tid] = static_cast<T>(z1); +} + +#define UNIFORM_OP(NAME, T) \ +kernel void rand_uniform_##NAME( \ + constant size_t &size, \ + constant ulong &seed, \ + constant float &min, \ + constant float &max, \ + device T *out, \ + uint tid [[thread_position_in_grid]] \ +) { \ + rand_uniform<T>(size, seed, min, max, out, tid); \ +} \ + +#define NORMAL_OP(NAME, T) \ +kernel void rand_normal_##NAME( \ + constant size_t &size, \ + constant ulong &seed, \ + constant float &mean, \ + constant float &stddev, \ + device T *out, \ + uint tid [[thread_position_in_grid]] \ +) { \ + normal<T>(size, seed, mean, stddev, out, tid); \ +} \ + #define RANDOM_OPS(NAME, T) \ UNIFORM_OP(NAME, T) \ +NORMAL_OP(NAME, T) \ RANDOM_OPS(f32, float) RANDOM_OPS(f16, half) diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index d0ca8330..067dece8 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -806,28 +806,43 @@ fn gemm() { ); } -fn run_random<T: Clone>(seed: u64, shape: &[usize], name: &'static str, min: f32, max: f32) -> Vec<T> { +fn run_random<T: Clone>(name: &'static str, seed: u64, length: usize, a: f32, b: f32) -> Vec<T> { let device = device(); let fence = device.new_fence(); let kernels = Kernels::new(fence); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let options = MTLResourceOptions::StorageModeManaged; - let length = shape.iter().product::<usize>(); let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options); - call_random_uniform( - &device, - command_buffer, - &kernels, - name, - seed, - min, - max, - length, - &output, - ) - .unwrap(); + if name.starts_with("rand_uniform") { + call_random_uniform( + &device, + command_buffer, + &kernels, + name, + seed, + a, + b, + length, + &output, + ) + .unwrap(); + } else { + call_random_normal( + &device, + command_buffer, + &kernels, + name, + seed, + a, + b, + length, + &output, + ) + .unwrap(); + } + command_buffer.commit(); command_buffer.wait_until_completed(); @@ -837,24 +852,50 @@ fn run_random<T: Clone>(seed: u64, shape: &[usize], name: &'static str, min: f32 #[test] fn random() { - use std::fs::File; - use std::io::prelude::*; - let shape = vec![1024, 4]; + fn calc_mean(data: &[f32]) -> f32 { + let sum = data.iter().sum::<f32>() as f32; + let count = data.len(); + assert!(count > 0); + sum / count as f32 + } + + fn calc_stddev(data: &[f32]) -> f32 { + let mean = calc_mean(data); + let count = data.len(); + assert!(count > 0); + + let variance = data.iter().map(|value| { + let diff = mean - (*value as f32); + diff * diff + }).sum::<f32>() / count as f32; + + variance.sqrt() + } + + let shape = vec![1024, 10]; + + let length = shape.iter().product::<usize>(); let seed = 299792458; + let min = -30.0; let max = 30.0; - let results = run_random::<f32>(seed, &shape, "rand_uniform_f32", min, max); - for &v in &results { - assert!(v >= min && v <= max); + let mean = 100.0; + let stddev = 50.0; + + macro_rules! validate_random { + ($type:ty) => { + let results: Vec<f32> = run_random::<$type>(concat!("rand_uniform_", stringify!($type)), seed, length, min, max).into_iter().map(f32::from).collect(); + results.iter().for_each(|v| assert!(*v >= min && *v <= max)); + assert!(calc_mean(&results) > -1.0 && calc_mean(&results) < 1.0); + + let results: Vec<f32> = run_random::<$type>(concat!("rand_normal_", stringify!($type)), seed, length, mean, stddev).into_iter().map(f32::from).collect(); + assert!((calc_mean(&results) - mean).abs() < mean / 10.0); + assert!((calc_stddev(&results) - stddev).abs() < stddev / 10.0); + }; } - // Writing bytes to file for testing with ENT - // https://www.fourmilab.ch/random/ - // TODO: Remove before merge - let (head, body, tail) = unsafe { results.align_to::<u8>() }; - assert!(head.is_empty()); - assert!(tail.is_empty()); - let mut file = File::create("test").unwrap(); - file.write_all(body).unwrap(); -} + validate_random!(f32); + validate_random!(f16); + validate_random!(bf16); +}
\ No newline at end of file |