summary | refs | log | tree | commit | diff
path: root/candle-metal-kernels
diff options
context:
space:
mode:
author: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> 2024-01-14 18:10:54 +0100
committer: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> 2024-01-15 11:58:25 +0100
commit: 79478ff5a1eab89f6e638ad7e7abd587b0f5b167 (patch)
tree: d3f4d77d1364d6a221ef9e15f8982cc0bb09d4d0 /candle-metal-kernels
parent: ecf88a6d381e40c8db1c643dff2753fd877fae92 (diff)
download: candle-79478ff5a1eab89f6e638ad7e7abd587b0f5b167.tar.gz
candle-79478ff5a1eab89f6e638ad7e7abd587b0f5b167.tar.bz2
candle-79478ff5a1eab89f6e638ad7e7abd587b0f5b167.zip
Seed should be updated by random kernel result.
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r-- candle-metal-kernels/src/lib.rs 12
-rw-r--r-- candle-metal-kernels/src/random.metal 36
-rw-r--r-- candle-metal-kernels/src/tests.rs 20
3 files changed, 48 insertions, 20 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index c427a690..6a10c333 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -1587,10 +1587,10 @@ pub fn call_random_uniform(
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: &'static str,
- seed: u64,
min: f32,
max: f32,
length: usize,
+ seed: &Buffer,
buffer: &Buffer,
) -> Result<(), MetalKernelError> {
if min >= max {
@@ -1607,8 +1607,10 @@ pub fn call_random_uniform(
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
- set_params!(encoder, (length, seed, min, max, buffer));
+ set_params!(encoder, (length, min, max, seed, buffer));
+ encoder.use_resource(seed, metal::MTLResourceUsage::Read);
+ encoder.use_resource(seed, metal::MTLResourceUsage::Write);
encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
@@ -1623,10 +1625,10 @@ pub fn call_random_normal(
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: &'static str,
- seed: u64,
mean: f32,
stddev: f32,
length: usize,
+ seed: &Buffer,
buffer: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
@@ -1638,8 +1640,10 @@ pub fn call_random_normal(
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
- set_params!(encoder, (length, seed, mean, stddev, buffer));
+ set_params!(encoder, (length, mean, stddev, seed, buffer));
+ encoder.use_resource(seed, metal::MTLResourceUsage::Read);
+ encoder.use_resource(seed, metal::MTLResourceUsage::Write);
encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
diff --git a/candle-metal-kernels/src/random.metal b/candle-metal-kernels/src/random.metal
index 5369e8e2..5eae2715 100644
--- a/candle-metal-kernels/src/random.metal
+++ b/candle-metal-kernels/src/random.metal
@@ -1,4 +1,7 @@
#include <metal_stdlib>
+#include <metal_integer>
+#include <metal_atomic>
+
using namespace metal;
// Constants
@@ -107,72 +110,85 @@ struct HybridTaus {
}
};
+METAL_FUNC float absdiff(float x, float y) {
+ return abs(x - y);
+}
+
template<typename T> METAL_FUNC void rand_uniform(
constant size_t &size,
- constant ulong &seed,
constant float &min,
constant float &max,
+ device atomic_uint *seed,
device T *out,
uint tid [[thread_position_in_grid]]
) {
if (tid >= size) {
return;
}
- float diff = max - min;
- HybridTaus rng = HybridTaus::init({seed, tid, 1, 1});
+
+ float diff = absdiff(min, max);
+ HybridTaus rng = HybridTaus::init({ulong(seed), tid, 1, 1});
out[tid] = static_cast<T>(rng.rand() * diff + min);
out[size - tid] = static_cast<T>(rng.rand() * diff + min);
+
+ if (tid == 0) {
+ atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed);
+ }
}
// Create Gaussian normal distribution using Box-Muller transform:
// https://en.wikipedia.org/wiki/Box–Muller_transform
template<typename T> METAL_FUNC void normal(
constant size_t &size,
- constant ulong &seed,
constant float &mean,
constant float &stddev,
+ device atomic_uint *seed,
device T *out,
uint tid [[thread_position_in_grid]]
) {
if (tid >= size) {
return;
}
- HybridTaus rng = HybridTaus::init({seed, tid, 1, 1});
+ HybridTaus rng = HybridTaus::init({ulong(seed), tid, 1, 1});
float u1 = rng.rand();
float u2 = rng.rand();
float cosval;
- float sinval = sincos(u1 * TWO_PI, cosval);
+ float sinval = sincos(TWO_PI * u2, cosval);
float mag = stddev * sqrt(-2.0 * log(u1));
float z0 = mag * cosval + mean;
float z1 = mag * sinval + mean;
out[tid] = static_cast<T>(z0);
out[size - tid] = static_cast<T>(z1);
+
+ if (tid == 0) {
+ atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed);
+ }
}
#define UNIFORM_OP(NAME, T) \
kernel void rand_uniform_##NAME( \
constant size_t &size, \
- constant ulong &seed, \
constant float &min, \
constant float &max, \
+ device atomic_uint *seed, \
device T *out, \
uint tid [[thread_position_in_grid]] \
) { \
- rand_uniform<T>(size, seed, min, max, out, tid); \
+ rand_uniform<T>(size, min, max, seed, out, tid); \
} \
#define NORMAL_OP(NAME, T) \
kernel void rand_normal_##NAME( \
constant size_t &size, \
- constant ulong &seed, \
constant float &mean, \
constant float &stddev, \
+ device atomic_uint *seed, \
device T *out, \
uint tid [[thread_position_in_grid]] \
) { \
- normal<T>(size, seed, mean, stddev, out, tid); \
+ normal<T>(size, mean, stddev, seed, out, tid); \
} \
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs
index 775ee0fa..2831a386 100644
--- a/candle-metal-kernels/src/tests.rs
+++ b/candle-metal-kernels/src/tests.rs
@@ -938,14 +938,21 @@ fn gemm() {
);
}
-fn run_random<T: Clone>(name: &'static str, seed: u64, length: usize, a: f32, b: f32) -> Vec<T> {
+fn run_random<T: Clone>(name: &'static str, seed: u32, length: usize, a: f32, b: f32) -> Vec<T> {
let device = device();
let fence = device.new_fence();
let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
+
let options = MTLResourceOptions::StorageModeManaged;
- let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
+ let output = device.new_buffer((length * core::mem::size_of::<T>()) as NSUInteger, options);
+
+ let seed = device.new_buffer_with_data(
+ &seed as *const u32 as *const core::ffi::c_void,
+ std::mem::size_of::<u32>() as NSUInteger,
+ options,
+ );
if name.starts_with("rand_uniform") {
call_random_uniform(
@@ -953,10 +960,10 @@ fn run_random<T: Clone>(name: &'static str, seed: u64, length: usize, a: f32, b:
command_buffer,
&kernels,
name,
- seed,
a,
b,
length,
+ &seed,
&output,
)
.unwrap();
@@ -966,15 +973,14 @@ fn run_random<T: Clone>(name: &'static str, seed: u64, length: usize, a: f32, b:
command_buffer,
&kernels,
name,
- seed,
a,
b,
length,
+ &seed,
&output,
)
.unwrap();
}
-
command_buffer.commit();
command_buffer.wait_until_completed();
@@ -1029,7 +1035,9 @@ fn random() {
.into_iter()
.map(f32::from)
.collect();
- results.iter().for_each(|v| assert!(*v >= min && *v <= max));
+ results.iter().for_each(|v| {
+ assert!(*v >= min && *v <= max);
+ });
assert!(calc_mean(&results) > -1.0 && calc_mean(&results) < 1.0);
let results: Vec<f32> = run_random::<$type>(