diff options
author | Niklas Hallqvist <niklas+github@appli.se> | 2024-03-08 16:11:50 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-08 16:11:50 +0100 |
commit | be5b68cd0ba49424b0b100c0ea48ad35b2bd67b9 (patch) | |
tree | f69eca3ba185b357c85a97bd63b6571cfd42c43e /candle-metal-kernels | |
parent | ea984d04210cf882953d1149a5bbc6b66f4157fb (diff) | |
download | candle-be5b68cd0ba49424b0b100c0ea48ad35b2bd67b9.tar.gz candle-be5b68cd0ba49424b0b100c0ea48ad35b2bd67b9.tar.bz2 candle-be5b68cd0ba49424b0b100c0ea48ad35b2bd67b9.zip |
Metal random-generation bug fixes (#1811)
* use_resource API misunderstood. It is not additive. Several usages must be bit-ORed together.
* The seeding was incorrect and used the address instead of the value of the passed in seed.
* Add a check that likely exhibits failure to update the seed between generation of random tensors.
* Buffer overrun, the length given to the std::ptr::copy call was in bytes, and not 32-bit units.
* By default seed the RNG with a time-based value, so that different runs may produce different output, just like the CPU engine.
Use device.set_seed if determinism is warranted.
* Revert "By default seed the RNG with a time-based value, so that different runs may produce different output, just like the CPU engine. Use device.set_seed if determinism is warranted."
This reverts commit d7302de9
Discussion in https://github.com/huggingface/candle/pull/1811#issuecomment-1983079119
* The Metal random kernel failed to set element N/2 of tensors with N elements, N being even. The reason was that all threads but thread 0 all created 2 random samples, but thread 0 only one, i.e. an odd number. In order to produce an even number of samples, the early termination of thread 0 should only everr occur for odd sized tensors.
* Add a test catching any deterministic tensor element in rand and randn output.
---------
Co-authored-by: niklas <niklas@appli.se>
Co-authored-by: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com>
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r-- | candle-metal-kernels/src/lib.rs | 12 | ||||
-rw-r--r-- | candle-metal-kernels/src/random.metal | 24 |
2 files changed, 24 insertions, 12 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 33bc3453..47ce7e96 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1558,8 +1558,10 @@ pub fn call_random_uniform( set_params!(encoder, (length, min, max, seed, buffer)); - encoder.use_resource(seed, metal::MTLResourceUsage::Read); - encoder.use_resource(seed, metal::MTLResourceUsage::Write); + encoder.use_resource( + seed, + metal::MTLResourceUsage::Read | metal::MTLResourceUsage::Write, + ); encoder.use_resource(buffer, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -1589,8 +1591,10 @@ pub fn call_random_normal( set_params!(encoder, (length, mean, stddev, seed, buffer)); - encoder.use_resource(seed, metal::MTLResourceUsage::Read); - encoder.use_resource(seed, metal::MTLResourceUsage::Write); + encoder.use_resource( + seed, + metal::MTLResourceUsage::Read | metal::MTLResourceUsage::Write, + ); encoder.use_resource(buffer, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); diff --git a/candle-metal-kernels/src/random.metal b/candle-metal-kernels/src/random.metal index a7e48393..c1a94199 100644 --- a/candle-metal-kernels/src/random.metal +++ b/candle-metal-kernels/src/random.metal @@ -123,16 +123,20 @@ template<typename T> METAL_FUNC void rand_uniform( return; } + // Evenly sized vectors need an offset when writing the mirror element. + uint off = 1 - size % 2; float diff = abs(min - max); - HybridTaus rng = HybridTaus::init({ulong(seed), tid, 1, 1}); + uint s = atomic_load_explicit(seed, memory_order_relaxed); + HybridTaus rng = HybridTaus::init({ulong(s), tid, 1, 1}); out[tid] = static_cast<T>(rng.rand() * diff + min); if (tid == 0) { atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed); - // Return early if tid == 0, otherwise we will write to out[size]. - return; + // Return early if tid == 0 && off == 0, otherwise we will write to out[size]. + if (off == 0) + return; } // Use symmetry to fill the other half of the array. - out[size - tid] = static_cast<T>(rng.rand() * diff + min); + out[size - off - tid] = static_cast<T>(rng.rand() * diff + min); } // Create Gaussian normal distribution using Box-Muller transform: @@ -148,7 +152,10 @@ template<typename T> METAL_FUNC void normal( if (tid >= size) { return; } - HybridTaus rng = HybridTaus::init({ulong(seed), tid, 1, 1}); + // Evenly sized vectors need an offset when writing the mirror element. + uint off = 1 - size % 2; + uint s = atomic_load_explicit(seed, memory_order_relaxed); + HybridTaus rng = HybridTaus::init({ulong(s), tid, 1, 1}); float u1 = rng.rand(); float u2 = rng.rand(); @@ -162,11 +169,12 @@ template<typename T> METAL_FUNC void normal( if (tid == 0) { atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed); - // Return early if tid == 0, otherwise we will write to out[size]. - return; + // Return early if tid == 0 && off == 0, otherwise we will write to out[size]. + if (off == 0) + return; } // Use symmetry to fill the other half of the array. - out[size - tid] = static_cast<T>(z1); + out[size - off - tid] = static_cast<T>(z1); } #define UNIFORM_OP(NAME, T) \ |