summaryrefslogtreecommitdiff
path: root/candle-metal-kernels
diff options
context:
space:
mode:
authorIvar Flakstad <69173633+ivarflakstad@users.noreply.github.com>2024-01-16 19:11:31 +0100
committerIvar Flakstad <69173633+ivarflakstad@users.noreply.github.com>2024-01-17 09:12:44 +0100
commit86a8e58897f012445de2f35318b19a89ebfaa327 (patch)
treebc241576d03ff07ea38ef145ddc07344838cd1b8 /candle-metal-kernels
parent79478ff5a1eab89f6e638ad7e7abd587b0f5b167 (diff)
downloadcandle-86a8e58897f012445de2f35318b19a89ebfaa327.tar.gz
candle-86a8e58897f012445de2f35318b19a89ebfaa327.tar.bz2
candle-86a8e58897f012445de2f35318b19a89ebfaa327.zip
Update metal random kernel and set_seed method
* set_seed via buffer content pointer copy + did_modify_range * ensure random.metal kernel does not write outside of buffer range when tid==0
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r--candle-metal-kernels/src/random.metal18
1 files changed, 10 insertions, 8 deletions
diff --git a/candle-metal-kernels/src/random.metal b/candle-metal-kernels/src/random.metal
index 5eae2715..a7e48393 100644
--- a/candle-metal-kernels/src/random.metal
+++ b/candle-metal-kernels/src/random.metal
@@ -14,6 +14,7 @@ static constexpr constant int3 S1 = {13, 19, 12};
static constexpr constant int3 S2 = {2, 25, 4};
static constexpr constant int3 S3 = {3, 11, 17};
+// Used to prevent bad seeds.
static constexpr constant uint64_t PHI[16] = {
0x9E3779B97F4A7C15,
0xF39CC0605CEDC834,
@@ -110,10 +111,6 @@ struct HybridTaus {
}
};
-METAL_FUNC float absdiff(float x, float y) {
- return abs(x - y);
-}
-
template<typename T> METAL_FUNC void rand_uniform(
constant size_t &size,
constant float &min,
@@ -126,14 +123,16 @@ template<typename T> METAL_FUNC void rand_uniform(
return;
}
- float diff = absdiff(min, max);
+ float diff = abs(min - max);
HybridTaus rng = HybridTaus::init({ulong(seed), tid, 1, 1});
out[tid] = static_cast<T>(rng.rand() * diff + min);
- out[size - tid] = static_cast<T>(rng.rand() * diff + min);
-
if (tid == 0) {
atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed);
+ // Return early if tid == 0, otherwise we will write to out[size].
+ return;
}
+ // Use symmetry to fill the other half of the array.
+ out[size - tid] = static_cast<T>(rng.rand() * diff + min);
}
// Create Gaussian normal distribution using Box-Muller transform:
@@ -160,11 +159,14 @@ template<typename T> METAL_FUNC void normal(
float z1 = mag * sinval + mean;
out[tid] = static_cast<T>(z0);
- out[size - tid] = static_cast<T>(z1);
if (tid == 0) {
atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed);
+ // Return early if tid == 0, otherwise we will write to out[size].
+ return;
}
+ // Use symmetry to fill the other half of the array.
+ out[size - tid] = static_cast<T>(z1);
}
#define UNIFORM_OP(NAME, T) \