diff options
author | Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> | 2024-01-14 18:10:54 +0100 |
---|---|---|
committer | Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> | 2024-01-15 11:58:25 +0100 |
commit | 79478ff5a1eab89f6e638ad7e7abd587b0f5b167 (patch) | |
tree | d3f4d77d1364d6a221ef9e15f8982cc0bb09d4d0 /candle-metal-kernels | |
parent | ecf88a6d381e40c8db1c643dff2753fd877fae92 (diff) | |
download | candle-79478ff5a1eab89f6e638ad7e7abd587b0f5b167.tar.gz candle-79478ff5a1eab89f6e638ad7e7abd587b0f5b167.tar.bz2 candle-79478ff5a1eab89f6e638ad7e7abd587b0f5b167.zip |
Seed should be updated by random kernel result.
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r-- | candle-metal-kernels/src/lib.rs | 12 |
-rw-r--r-- | candle-metal-kernels/src/random.metal | 36 |
-rw-r--r-- | candle-metal-kernels/src/tests.rs | 20 |
3 files changed, 48 insertions, 20 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index c427a690..6a10c333 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1587,10 +1587,10 @@ pub fn call_random_uniform( command_buffer: &CommandBufferRef, kernels: &Kernels, name: &'static str, - seed: u64, min: f32, max: f32, length: usize, + seed: &Buffer, buffer: &Buffer, ) -> Result<(), MetalKernelError> { if min >= max { @@ -1607,8 +1607,10 @@ pub fn call_random_uniform( encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); - set_params!(encoder, (length, seed, min, max, buffer)); + set_params!(encoder, (length, min, max, seed, buffer)); + encoder.use_resource(seed, metal::MTLResourceUsage::Read); + encoder.use_resource(seed, metal::MTLResourceUsage::Write); encoder.use_resource(buffer, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.update_fence(&kernels.fence); @@ -1623,10 +1625,10 @@ pub fn call_random_normal( command_buffer: &CommandBufferRef, kernels: &Kernels, name: &'static str, - seed: u64, mean: f32, stddev: f32, length: usize, + seed: &Buffer, buffer: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Random, name)?; @@ -1638,8 +1640,10 @@ pub fn call_random_normal( encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); - set_params!(encoder, (length, seed, mean, stddev, buffer)); + set_params!(encoder, (length, mean, stddev, seed, buffer)); + encoder.use_resource(seed, metal::MTLResourceUsage::Read); + encoder.use_resource(seed, metal::MTLResourceUsage::Write); encoder.use_resource(buffer, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.update_fence(&kernels.fence); diff --git a/candle-metal-kernels/src/random.metal b/candle-metal-kernels/src/random.metal index 5369e8e2..5eae2715 100644 --- 
a/candle-metal-kernels/src/random.metal +++ b/candle-metal-kernels/src/random.metal @@ -1,4 +1,7 @@ #include <metal_stdlib> +#include <metal_integer> +#include <metal_atomic> + using namespace metal; // Constants @@ -107,72 +110,85 @@ struct HybridTaus { } }; +METAL_FUNC float absdiff(float x, float y) { + return abs(x - y); +} + template<typename T> METAL_FUNC void rand_uniform( constant size_t &size, - constant ulong &seed, constant float &min, constant float &max, + device atomic_uint *seed, device T *out, uint tid [[thread_position_in_grid]] ) { if (tid >= size) { return; } - float diff = max - min; - HybridTaus rng = HybridTaus::init({seed, tid, 1, 1}); + + float diff = absdiff(min, max); + HybridTaus rng = HybridTaus::init({ulong(seed), tid, 1, 1}); out[tid] = static_cast<T>(rng.rand() * diff + min); out[size - tid] = static_cast<T>(rng.rand() * diff + min); + + if (tid == 0) { + atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed); + } } // Create Gaussian normal distribution using Box-Muller transform: // https://en.wikipedia.org/wiki/Box–Muller_transform template<typename T> METAL_FUNC void normal( constant size_t &size, - constant ulong &seed, constant float &mean, constant float &stddev, + device atomic_uint *seed, device T *out, uint tid [[thread_position_in_grid]] ) { if (tid >= size) { return; } - HybridTaus rng = HybridTaus::init({seed, tid, 1, 1}); + HybridTaus rng = HybridTaus::init({ulong(seed), tid, 1, 1}); float u1 = rng.rand(); float u2 = rng.rand(); float cosval; - float sinval = sincos(u1 * TWO_PI, cosval); + float sinval = sincos(TWO_PI * u2, cosval); float mag = stddev * sqrt(-2.0 * log(u1)); float z0 = mag * cosval + mean; float z1 = mag * sinval + mean; out[tid] = static_cast<T>(z0); out[size - tid] = static_cast<T>(z1); + + if (tid == 0) { + atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed); + } } #define UNIFORM_OP(NAME, T) \ kernel void rand_uniform_##NAME( \ constant 
size_t &size, \ - constant ulong &seed, \ constant float &min, \ constant float &max, \ + device atomic_uint *seed, \ device T *out, \ uint tid [[thread_position_in_grid]] \ ) { \ - rand_uniform<T>(size, seed, min, max, out, tid); \ + rand_uniform<T>(size, min, max, seed, out, tid); \ } \ #define NORMAL_OP(NAME, T) \ kernel void rand_normal_##NAME( \ constant size_t &size, \ - constant ulong &seed, \ constant float &mean, \ constant float &stddev, \ + device atomic_uint *seed, \ device T *out, \ uint tid [[thread_position_in_grid]] \ ) { \ - normal<T>(size, seed, mean, stddev, out, tid); \ + normal<T>(size, mean, stddev, seed, out, tid); \ } \ diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 775ee0fa..2831a386 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -938,14 +938,21 @@ fn gemm() { ); } -fn run_random<T: Clone>(name: &'static str, seed: u64, length: usize, a: f32, b: f32) -> Vec<T> { +fn run_random<T: Clone>(name: &'static str, seed: u32, length: usize, a: f32, b: f32) -> Vec<T> { let device = device(); let fence = device.new_fence(); let kernels = Kernels::new(fence); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); + let options = MTLResourceOptions::StorageModeManaged; - let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options); + let output = device.new_buffer((length * core::mem::size_of::<T>()) as NSUInteger, options); + + let seed = device.new_buffer_with_data( + &seed as *const u32 as *const core::ffi::c_void, + std::mem::size_of::<u32>() as NSUInteger, + options, + ); if name.starts_with("rand_uniform") { call_random_uniform( @@ -953,10 +960,10 @@ fn run_random<T: Clone>(name: &'static str, seed: u64, length: usize, a: f32, b: command_buffer, &kernels, name, - seed, a, b, length, + &seed, &output, ) .unwrap(); @@ -966,15 +973,14 @@ fn run_random<T: Clone>(name: &'static str, seed: u64, 
length: usize, a: f32, b: command_buffer, &kernels, name, - seed, a, b, length, + &seed, &output, ) .unwrap(); } - command_buffer.commit(); command_buffer.wait_until_completed(); @@ -1029,7 +1035,9 @@ fn random() { .into_iter() .map(f32::from) .collect(); - results.iter().for_each(|v| assert!(*v >= min && *v <= max)); + results.iter().for_each(|v| { + assert!(*v >= min && *v <= max); + }); assert!(calc_mean(&results) > -1.0 && calc_mean(&results) < 1.0); let results: Vec<f32> = run_random::<$type>( |