summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/Cargo.toml4
-rw-r--r--candle-core/benches/random.rs66
-rw-r--r--candle-core/src/metal_backend.rs70
-rw-r--r--candle-metal-kernels/src/lib.rs77
-rw-r--r--candle-metal-kernels/src/random.metal188
-rw-r--r--candle-metal-kernels/src/tests.rs97
6 files changed, 484 insertions, 18 deletions
diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml
index 93b718a3..3fae7f07 100644
--- a/candle-core/Cargo.toml
+++ b/candle-core/Cargo.toml
@@ -49,3 +49,7 @@ metal = ["dep:metal", "dep:candle-metal-kernels"]
name = "bench_main"
harness = false
+[[bench]]
+name = "random"
+harness = false
+
diff --git a/candle-core/benches/random.rs b/candle-core/benches/random.rs
new file mode 100644
index 00000000..781d8b39
--- /dev/null
+++ b/candle-core/benches/random.rs
@@ -0,0 +1,66 @@
+use candle_core::{DType, Device, Tensor};
+use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput};
+use std::time::Instant;
+
+fn rand_uniform(a: &Tensor) {
+ a.rand_like(0.0, 1.0).unwrap();
+}
+
+fn rand_normal(a: &Tensor) {
+ a.randn_like(100.0, 15.0).unwrap();
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+ let b = 1;
+
+ let rows = 2048;
+ let cols = 2048;
+
+ let device = Device::new_metal(0).unwrap();
+ let device2 = device.clone();
+ let dtype = DType::F32;
+ let tensor = Tensor::zeros((b, rows, cols), dtype, &device).unwrap();
+
+ let flops = b * rows * cols;
+
+ let mut group = c.benchmark_group("metal_random_uniform");
+ group.throughput(Throughput::Bytes(flops as u64));
+ group.bench_function("iter", move |benches| {
+ benches.iter_custom(|iters| {
+ let start = Instant::now();
+ for _i in 0..iters {
+ rand_uniform(black_box(&tensor));
+ }
+ if let Device::Metal(device) = &device {
+ device.wait_until_completed().unwrap();
+ } else {
+ panic!("Expected metal device");
+ }
+ start.elapsed()
+ })
+ });
+ group.finish();
+
+ let tensor = Tensor::zeros((b, rows, cols), dtype, &device2).unwrap();
+
+ let mut group = c.benchmark_group("metal_random_normal");
+ group.throughput(Throughput::Bytes(flops as u64));
+ group.bench_function("iter", move |benches| {
+ benches.iter_custom(|iters| {
+ let start = Instant::now();
+ for _i in 0..iters {
+ rand_normal(black_box(&tensor));
+ }
+ if let Device::Metal(device) = &device2 {
+ device.wait_until_completed().unwrap();
+ } else {
+ panic!("Expected metal device");
+ }
+ start.elapsed()
+ })
+ });
+ group.finish();
+}
+
+criterion_group!(benches, criterion_benchmark);
+criterion_main!(benches);
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index c1c4aa4b..24beeb7a 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -8,7 +8,7 @@ use metal;
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
use std::collections::HashMap;
use std::path::Path;
-use std::sync::{Arc, RwLock, TryLockError};
+use std::sync::{Arc, Mutex, RwLock, TryLockError};
/// Simple way to catch lock error without
/// depending on T
@@ -106,6 +106,8 @@ pub struct MetalDevice {
/// Whenever we actually allocate a new buffer, we make a full sweep to cleanup unused buffers
/// (strong_count = 1).
buffers: AllocatedBuffers,
+ /// Seed for random number generation.
+ seed: Arc<Mutex<u64>>,
}
impl std::fmt::Debug for MetalDevice {
@@ -1483,6 +1485,7 @@ impl BackendDevice for MetalDevice {
Ok(val) => val.parse()?,
_ => 20,
};
+ let seed = Arc::new(Mutex::new(299792458));
Ok(Self {
device,
fence,
@@ -1492,13 +1495,10 @@ impl BackendDevice for MetalDevice {
compute_per_buffer,
buffers,
kernels,
+ seed,
})
}
- fn set_seed(&self, _seed: u64) -> Result<()> {
- crate::bail!("Metal set_seed not implemented")
- }
-
fn location(&self) -> crate::DeviceLocation {
crate::DeviceLocation::Metal {
gpu_id: self.registry_id() as usize,
@@ -1551,12 +1551,31 @@ impl BackendDevice for MetalDevice {
&self,
shape: &Shape,
dtype: DType,
- mean: f64,
- stddev: f64,
+ min: f64,
+ max: f64,
) -> Result<Self::Storage> {
- // TODO is there a better way ?
- let cpu_storage = crate::cpu_backend::CpuDevice.rand_uniform(shape, dtype, mean, stddev)?;
- self.storage_from_cpu_storage(&cpu_storage)
+ let name = match dtype {
+ DType::F32 => "rand_uniform_f32",
+ DType::F16 => "rand_uniform_f16",
+ DType::BF16 => "rand_uniform_bf16",
+ dtype => crate::bail!("rand_uniform not implemented for {dtype:?}"),
+ };
+ let buffer = self.new_buffer(shape.elem_count(), dtype, "rand_uniform")?;
+ let command_buffer = self.command_buffer()?;
+ candle_metal_kernels::call_random_uniform(
+ &self.device,
+ &command_buffer,
+ &self.kernels,
+ name,
+ *self.seed.lock().unwrap(),
+ min as f32,
+ max as f32,
+ shape.elem_count(),
+ &buffer,
+ )
+ .map_err(MetalError::from)?;
+
+ Ok(Self::Storage::new(buffer, self.clone(), dtype))
}
fn rand_normal(
@@ -1566,9 +1585,34 @@ impl BackendDevice for MetalDevice {
mean: f64,
stddev: f64,
) -> Result<Self::Storage> {
- // TODO is there a better way ?
- let cpu_storage = crate::cpu_backend::CpuDevice.rand_normal(shape, dtype, mean, stddev)?;
- self.storage_from_cpu_storage(&cpu_storage)
+ let name = match dtype {
+ DType::F32 => "rand_normal_f32",
+ DType::F16 => "rand_normal_f16",
+ DType::BF16 => "rand_normal_bf16",
+ dtype => crate::bail!("rand_uniform not implemented for {dtype:?}"),
+ };
+ let buffer = self.new_buffer(shape.elem_count(), dtype, "rand_normal")?;
+ let command_buffer = self.command_buffer()?;
+ candle_metal_kernels::call_random_normal(
+ &self.device,
+ &command_buffer,
+ &self.kernels,
+ name,
+ *self.seed.lock().unwrap(),
+ mean as f32,
+ stddev as f32,
+ shape.elem_count(),
+ &buffer,
+ )
+ .map_err(MetalError::from)?;
+
+ Ok(Self::Storage::new(buffer, self.clone(), dtype))
+ }
+
+ fn set_seed(&self, seed: u64) -> Result<()> {
+ let mut s = self.seed.try_lock().map_err(MetalError::from)?;
+ *s = seed;
+ Ok(())
}
}
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 5d34f61a..75f0286d 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -12,8 +12,9 @@ const UNARY: &str = include_str!("unary.metal");
const BINARY: &str = include_str!("binary.metal");
const TERNARY: &str = include_str!("ternary.metal");
const CAST: &str = include_str!("cast.metal");
-const REDUCE: &str = include_str!("reduce.metal");
const CONV: &str = include_str!("conv.metal");
+const REDUCE: &str = include_str!("reduce.metal");
+const RANDOM: &str = include_str!("random.metal");
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
/// Most kernels apply similarly across the tensors
@@ -45,7 +46,7 @@ fn set_param<P: EncoderParam>(encoder: &ComputeCommandEncoderRef, position: u64,
/// Helper functions to create the various objects on the compute command encoder
/// on a single line.
/// Prevents getting wrong some arguments number and mixing length and size in bytes.
-trait EncoderParam {
+pub trait EncoderParam {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self);
}
macro_rules! primitive {
@@ -61,8 +62,10 @@ macro_rules! primitive {
}
};
}
+primitive!(bool);
primitive!(usize);
primitive!(u32);
+primitive!(u64);
primitive!(f32);
impl<T> EncoderParam for &[T] {
@@ -117,6 +120,7 @@ pub enum Source {
Reduce,
Mfa,
Conv,
+ Random,
}
macro_rules! ops{
@@ -239,6 +243,7 @@ impl Kernels {
Source::Cast => CAST,
Source::Reduce => REDUCE,
Source::Conv => CONV,
+ Source::Random => RANDOM,
Source::Mfa => panic!("Invalid lib"),
}
}
@@ -1421,7 +1426,6 @@ pub fn call_gemm(
height: 1,
depth: 1,
};
- // println!("grid size {grid_size:?} group size {group_size:?}");
encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
@@ -1577,5 +1581,72 @@ fn divide(m: usize, b: usize) -> NSUInteger {
((m + b - 1) / b) as NSUInteger
}
+#[allow(clippy::too_many_arguments)]
+pub fn call_random_uniform(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ name: &'static str,
+ seed: u64,
+ min: f32,
+ max: f32,
+ length: usize,
+ buffer: &Buffer,
+) -> Result<(), MetalKernelError> {
+ if min >= max {
+ return Err(MetalKernelError::LoadLibraryError(
+ "min must be less than max".to_string(),
+ ));
+ }
+ let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
+ let encoder = command_buffer.new_compute_command_encoder();
+
+ let odd = (length % 2 != 0) as usize;
+ let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd);
+
+ encoder.wait_for_fence(&kernels.fence);
+ encoder.set_compute_pipeline_state(&pipeline);
+
+ set_params!(encoder, (length, seed, min, max, buffer));
+
+ encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
+ encoder.end_encoding();
+
+ Ok(())
+}
+
+#[allow(clippy::too_many_arguments)]
+pub fn call_random_normal(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ name: &'static str,
+ seed: u64,
+ mean: f32,
+ stddev: f32,
+ length: usize,
+ buffer: &Buffer,
+) -> Result<(), MetalKernelError> {
+ let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
+ let encoder = command_buffer.new_compute_command_encoder();
+
+ let odd = (length % 2 != 0) as usize;
+ let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd);
+
+ encoder.wait_for_fence(&kernels.fence);
+ encoder.set_compute_pipeline_state(&pipeline);
+
+ set_params!(encoder, (length, seed, mean, stddev, buffer));
+
+ encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
+ encoder.end_encoding();
+
+ Ok(())
+}
+
#[cfg(test)]
mod tests;
diff --git a/candle-metal-kernels/src/random.metal b/candle-metal-kernels/src/random.metal
new file mode 100644
index 00000000..5369e8e2
--- /dev/null
+++ b/candle-metal-kernels/src/random.metal
@@ -0,0 +1,188 @@
+#include <metal_stdlib>
+using namespace metal;
+
+// Constants
+// 2^32 and 1/2^32. Useful for converting between float and uint.
+static constexpr constant ulong UNIF01_NORM32 = 4294967296;
+static constexpr constant float UNIF01_INV32 = 2.328306436538696289e-10;
+// 2 * pi
+static constexpr constant float TWO_PI = 2.0 * M_PI_F;
+static constexpr constant int3 S1 = {13, 19, 12};
+static constexpr constant int3 S2 = {2, 25, 4};
+static constexpr constant int3 S3 = {3, 11, 17};
+
+static constexpr constant uint64_t PHI[16] = {
+ 0x9E3779B97F4A7C15,
+ 0xF39CC0605CEDC834,
+ 0x1082276BF3A27251,
+ 0xF86C6A11D0C18E95,
+ 0x2767F0B153D27B7F,
+ 0x0347045B5BF1827F,
+ 0x01886F0928403002,
+ 0xC1D64BA40F335E36,
+ 0xF06AD7AE9717877E,
+ 0x85839D6EFFBD7DC6,
+ 0x64D325D1C5371682,
+ 0xCADD0CCCFDFFBBE1,
+ 0x626E33B8D04B4331,
+ 0xBBF73C790D94F79D,
+ 0x471C4AB3ED3D82A5,
+ 0xFEC507705E4AE6E5,
+};
+
+// Combined Tausworthe and LCG Random Number Generator.
+// https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-37-efficient-random-number-generation-and-application
+// https://indico.cern.ch/event/93877/contributions/2118070/attachments/1104200/1575343/acat3_revised_final.pdf
+struct HybridTaus {
+
+ float state;
+
+ HybridTaus() thread = default;
+ HybridTaus() threadgroup = default;
+ HybridTaus() device = default;
+ HybridTaus() constant = default;
+
+ // Generate seeds for each thread.
+ METAL_FUNC static uint4 seed_per_thread(const ulong4 seeds) {
+ return uint4(ulong4(seeds) * ulong4(PHI[0], PHI[1], PHI[2], PHI[3]) * ulong4(1099087573UL));
+ }
+
+ // Tausworthe generator.
+ METAL_FUNC static uint taus(const uint z, const int3 s, const uint M) {
+ uint b = (((z << s.x) ^ z) >> s.y);
+ return (((z & M) << s.z) ^ b);
+ }
+
+ // LCG generator.
+ METAL_FUNC static uint lcg(const uint z) {
+ return (1664525 * z + 1013904223UL);
+ }
+
+ // Initialize the RNG state.
+ METAL_FUNC static HybridTaus init(const ulong4 seeds) {
+ uint4 seed = seed_per_thread(seeds);
+
+ // Seed #1
+ uint z1 = taus(seed.x, S1, 4294967294UL);
+ uint z2 = taus(seed.y, S2, 4294967288UL);
+ uint z3 = taus(seed.z, S3, 4294967280UL);
+ uint z4 = lcg(seed.x);
+
+ // Seed #2
+ uint r1 = (z1^z2^z3^z4^seed.y);
+ z1 = taus(r1, S1, 429496729UL);
+ z2 = taus(r1, S2, 4294967288UL);
+ z3 = taus(r1, S3, 429496280UL);
+ z4 = lcg(r1);
+
+ // Seed #3
+ r1 = (z1^z2^z3^z4^seed.z);
+ z1 = taus(r1, S1, 429496729UL);
+ z2 = taus(r1, S2, 4294967288UL);
+ z3 = taus(r1, S3, 429496280UL);
+ z4 = lcg(r1);
+
+ // Seed #4
+ r1 = (z1^z2^z3^z4^seed.w);
+ z1 = taus(r1, S1, 429496729UL);
+ z2 = taus(r1, S2, 4294967288UL);
+ z3 = taus(r1, S3, 429496280UL);
+ z4 = lcg(r1);
+
+ HybridTaus rng;
+ rng.state = (z1^z2^z3^z4) * UNIF01_INV32;
+ return rng;
+ }
+
+ METAL_FUNC float rand() {
+ uint seed = this->state * UNIF01_NORM32;
+ uint z1 = taus(seed, S1, 429496729UL);
+ uint z2 = taus(seed, S2, 4294967288UL);
+ uint z3 = taus(seed, S3, 429496280UL);
+ uint z4 = lcg(seed);
+
+ thread float result = this->state;
+ this->state = (z1^z2^z3^z4) * UNIF01_INV32;
+ return result;
+ }
+};
+
+template<typename T> METAL_FUNC void rand_uniform(
+ constant size_t &size,
+ constant ulong &seed,
+ constant float &min,
+ constant float &max,
+ device T *out,
+ uint tid [[thread_position_in_grid]]
+) {
+ if (tid >= size) {
+ return;
+ }
+ float diff = max - min;
+ HybridTaus rng = HybridTaus::init({seed, tid, 1, 1});
+ out[tid] = static_cast<T>(rng.rand() * diff + min);
+ out[size - tid] = static_cast<T>(rng.rand() * diff + min);
+}
+
+// Create Gaussian normal distribution using Box-Muller transform:
+// https://en.wikipedia.org/wiki/Box–Muller_transform
+template<typename T> METAL_FUNC void normal(
+ constant size_t &size,
+ constant ulong &seed,
+ constant float &mean,
+ constant float &stddev,
+ device T *out,
+ uint tid [[thread_position_in_grid]]
+) {
+ if (tid >= size) {
+ return;
+ }
+ HybridTaus rng = HybridTaus::init({seed, tid, 1, 1});
+ float u1 = rng.rand();
+ float u2 = rng.rand();
+
+ float cosval;
+ float sinval = sincos(u1 * TWO_PI, cosval);
+ float mag = stddev * sqrt(-2.0 * log(u1));
+ float z0 = mag * cosval + mean;
+ float z1 = mag * sinval + mean;
+
+ out[tid] = static_cast<T>(z0);
+ out[size - tid] = static_cast<T>(z1);
+}
+
+#define UNIFORM_OP(NAME, T) \
+kernel void rand_uniform_##NAME( \
+ constant size_t &size, \
+ constant ulong &seed, \
+ constant float &min, \
+ constant float &max, \
+ device T *out, \
+ uint tid [[thread_position_in_grid]] \
+) { \
+ rand_uniform<T>(size, seed, min, max, out, tid); \
+} \
+
+#define NORMAL_OP(NAME, T) \
+kernel void rand_normal_##NAME( \
+ constant size_t &size, \
+ constant ulong &seed, \
+ constant float &mean, \
+ constant float &stddev, \
+ device T *out, \
+ uint tid [[thread_position_in_grid]] \
+) { \
+ normal<T>(size, seed, mean, stddev, out, tid); \
+} \
+
+
+#define RANDOM_OPS(NAME, T) \
+UNIFORM_OP(NAME, T) \
+NORMAL_OP(NAME, T) \
+
+RANDOM_OPS(f32, float)
+RANDOM_OPS(f16, half)
+
+#if __METAL_VERSION__ >= 310
+RANDOM_OPS(bf16, bfloat)
+#endif
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs
index c955abca..067dece8 100644
--- a/candle-metal-kernels/src/tests.rs
+++ b/candle-metal-kernels/src/tests.rs
@@ -11,7 +11,7 @@ fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
fn new_buffer<T>(device: &Device, data: &[T]) -> Buffer {
let options = MTLResourceOptions::StorageModeManaged;
- let ptr = data.as_ptr() as *const core::ffi::c_void;
+ let ptr = data.as_ptr() as *const c_void;
let size = (data.len() * std::mem::size_of::<T>()) as u64;
device.new_buffer_with_data(ptr, size, options)
}
@@ -590,7 +590,6 @@ fn softmax() {
}
let results = run_softmax(&v, last_dim, "softmax_f32");
let results = approx(results, 4);
- println!("{results:?}");
assert_eq!(
results.iter().map(|&s| s.round() as usize).sum::<usize>(),
n
@@ -806,3 +805,97 @@ fn gemm() {
vec![56.0, 59.0, 62.0, 65.0, 200.0, 212.0, 224.0, 236.0]
);
}
+
+fn run_random<T: Clone>(name: &'static str, seed: u64, length: usize, a: f32, b: f32) -> Vec<T> {
+ let device = device();
+ let fence = device.new_fence();
+ let kernels = Kernels::new(fence);
+ let command_queue = device.new_command_queue();
+ let command_buffer = command_queue.new_command_buffer();
+ let options = MTLResourceOptions::StorageModeManaged;
+ let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
+
+ if name.starts_with("rand_uniform") {
+ call_random_uniform(
+ &device,
+ command_buffer,
+ &kernels,
+ name,
+ seed,
+ a,
+ b,
+ length,
+ &output,
+ )
+ .unwrap();
+ } else {
+ call_random_normal(
+ &device,
+ command_buffer,
+ &kernels,
+ name,
+ seed,
+ a,
+ b,
+ length,
+ &output,
+ )
+ .unwrap();
+ }
+
+
+ command_buffer.commit();
+ command_buffer.wait_until_completed();
+
+ read_to_vec(&output, length)
+}
+
+#[test]
+fn random() {
+
+ fn calc_mean(data: &[f32]) -> f32 {
+ let sum = data.iter().sum::<f32>() as f32;
+ let count = data.len();
+ assert!(count > 0);
+ sum / count as f32
+ }
+
+ fn calc_stddev(data: &[f32]) -> f32 {
+ let mean = calc_mean(data);
+ let count = data.len();
+ assert!(count > 0);
+
+ let variance = data.iter().map(|value| {
+ let diff = mean - (*value as f32);
+ diff * diff
+ }).sum::<f32>() / count as f32;
+
+ variance.sqrt()
+ }
+
+ let shape = vec![1024, 10];
+
+ let length = shape.iter().product::<usize>();
+ let seed = 299792458;
+
+ let min = -30.0;
+ let max = 30.0;
+ let mean = 100.0;
+ let stddev = 50.0;
+
+ macro_rules! validate_random {
+ ($type:ty) => {
+ let results: Vec<f32> = run_random::<$type>(concat!("rand_uniform_", stringify!($type)), seed, length, min, max).into_iter().map(f32::from).collect();
+ results.iter().for_each(|v| assert!(*v >= min && *v <= max));
+ assert!(calc_mean(&results) > -1.0 && calc_mean(&results) < 1.0);
+
+ let results: Vec<f32> = run_random::<$type>(concat!("rand_normal_", stringify!($type)), seed, length, mean, stddev).into_iter().map(f32::from).collect();
+ assert!((calc_mean(&results) - mean).abs() < mean / 10.0);
+ assert!((calc_stddev(&results) - stddev).abs() < stddev / 10.0);
+ };
+ }
+
+ validate_random!(f32);
+ validate_random!(f16);
+ validate_random!(bf16);
+} \ No newline at end of file