diff options
author | Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> | 2024-01-16 19:11:31 +0100 |
---|---|---|
committer | Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> | 2024-01-17 09:12:44 +0100 |
commit | 86a8e58897f012445de2f35318b19a89ebfaa327 (patch) | |
tree | bc241576d03ff07ea38ef145ddc07344838cd1b8 /candle-metal-kernels | |
parent | 79478ff5a1eab89f6e638ad7e7abd587b0f5b167 (diff) | |
download | candle-86a8e58897f012445de2f35318b19a89ebfaa327.tar.gz candle-86a8e58897f012445de2f35318b19a89ebfaa327.tar.bz2 candle-86a8e58897f012445de2f35318b19a89ebfaa327.zip |
Update metal random kernel and set_seed method
* set_seed via buffer content pointer copy + did_modify_range
* ensure random.metal kernel does not write outside of buffer range when tid==0
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r-- | candle-metal-kernels/src/random.metal | 18 |
1 file changed, 10 insertions, 8 deletions
diff --git a/candle-metal-kernels/src/random.metal b/candle-metal-kernels/src/random.metal index 5eae2715..a7e48393 100644 --- a/candle-metal-kernels/src/random.metal +++ b/candle-metal-kernels/src/random.metal @@ -14,6 +14,7 @@ static constexpr constant int3 S1 = {13, 19, 12}; static constexpr constant int3 S2 = {2, 25, 4}; static constexpr constant int3 S3 = {3, 11, 17}; +// Used to prevent bad seeds. static constexpr constant uint64_t PHI[16] = { 0x9E3779B97F4A7C15, 0xF39CC0605CEDC834, @@ -110,10 +111,6 @@ struct HybridTaus { } }; -METAL_FUNC float absdiff(float x, float y) { - return abs(x - y); -} - template<typename T> METAL_FUNC void rand_uniform( constant size_t &size, constant float &min, @@ -126,14 +123,16 @@ template<typename T> METAL_FUNC void rand_uniform( return; } - float diff = absdiff(min, max); + float diff = abs(min - max); HybridTaus rng = HybridTaus::init({ulong(seed), tid, 1, 1}); out[tid] = static_cast<T>(rng.rand() * diff + min); - out[size - tid] = static_cast<T>(rng.rand() * diff + min); - if (tid == 0) { atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed); + // Return early if tid == 0, otherwise we will write to out[size]. + return; } + // Use symmetry to fill the other half of the array. + out[size - tid] = static_cast<T>(rng.rand() * diff + min); } // Create Gaussian normal distribution using Box-Muller transform: @@ -160,11 +159,14 @@ template<typename T> METAL_FUNC void normal( float z1 = mag * sinval + mean; out[tid] = static_cast<T>(z0); - out[size - tid] = static_cast<T>(z1); if (tid == 0) { atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed); + // Return early if tid == 0, otherwise we will write to out[size]. + return; } + // Use symmetry to fill the other half of the array. + out[size - tid] = static_cast<T>(z1); } #define UNIFORM_OP(NAME, T) \ |