diff options
25 files changed, 2775 insertions, 776 deletions
@@ -32,6 +32,7 @@ accelerate-src = { version = "0.3.2" } anyhow = { version = "1", features = ["backtrace"] } byteorder = "1.4.3" clap = { version = "4.2.4", features = ["derive"] } +criterion = { version = "0.5.1", default-features=false } cudarc = { version = "0.9.14", features = ["f16"] } gemm = { version = "0.16.6", features = ["wasm-simd128-enable"] } hf-hub = "0.3.0" @@ -61,7 +62,7 @@ tracing-subscriber = "0.3.7" wav = "1.0.0" yoke = { version = "0.7.2", features = ["derive"] } zip = { version = "0.6.6", default-features = false } -metal = { version = "0.27.1", features = ["mps"], package="candle-metal" } +metal = { version = "0.27.0", features = ["mps"]} [profile.release-with-debug] inherits = "release" diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml index 42e5be2a..52e79a5a 100644 --- a/candle-core/Cargo.toml +++ b/candle-core/Cargo.toml @@ -34,6 +34,8 @@ zip = { workspace = true } [dev-dependencies] anyhow = { workspace = true } clap = { workspace = true } +criterion = { workspace = true } + [features] default = [] @@ -42,3 +44,8 @@ cudnn = ["cuda", "cudarc/cudnn"] mkl = ["dep:libc", "dep:intel-mkl-src"] accelerate = ["dep:libc", "dep:accelerate-src"] metal = ["dep:metal", "dep:candle-metal-kernels"] + +[[bench]] +name = "matmul" +harness = false + diff --git a/candle-core/benches/matmul.rs b/candle-core/benches/matmul.rs new file mode 100644 index 00000000..8732f451 --- /dev/null +++ b/candle-core/benches/matmul.rs @@ -0,0 +1,43 @@ +use candle_core::{DType, Device, Tensor}; +use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput}; +use std::time::Instant; + +fn run(a: &Tensor, b: &Tensor) { + a.matmul(&b.t().unwrap()).unwrap(); +} + +fn criterion_benchmark(c: &mut Criterion) { + let b = 1; + let m = 1; + let n = 2048; + let k = 2048; + + let device = Device::new_metal(0).unwrap(); + let dtype = DType::F32; + let lhs = Tensor::zeros((b, m, k), dtype, &device).unwrap(); + let rhs = Tensor::zeros((b, n, k), dtype, &device).unwrap(); + + let flops = b * m * n * k; + + let mut group = c.benchmark_group("matmul_metal"); + group.throughput(Throughput::Bytes(flops as u64)); + group.bench_function("iter", move |b| { + b.iter_custom(|iters| { + let start = Instant::now(); + for _i in 0..iters { + run(black_box(&lhs), black_box(&rhs)); + } + if let Device::Metal(device) = &device { + device.wait_until_completed().unwrap(); + } else { + panic!("Expected metal device"); + } + start.elapsed() + }) + }); + group.finish(); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); + diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs index 3eb7f8b7..1e33021b 100644 --- a/candle-core/src/device.rs +++ b/candle-core/src/device.rs @@ -201,10 +201,9 @@ impl Device { Ok(Storage::Cuda(storage)) } } - Device::Metal(_device) => { - // let storage = device.rand_uniform(shape, dtype, lo, up)?; - // Ok(Storage::Metal(storage)) - crate::bail!("Metal rand_uniform not implemented") + Device::Metal(device) => { + let storage = device.rand_uniform(shape, dtype, lo, up)?; + Ok(Storage::Metal(storage)) } } } diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 0b72f080..27b2824f 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -4,11 +4,30 @@ use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; use crate::{CpuStorage, DType, Layout, Result, Shape}; use candle_metal_kernels; use candle_metal_kernels::Kernels; -use core::mem; -use half::{bf16, f16}; use metal; -use metal::{Buffer, CommandQueue, MTLResourceOptions, NSUInteger}; -use std::sync::Arc; +use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger}; +use std::collections::HashMap; +use std::path::Path; +use std::sync::{Arc, RwLock, TryLockError}; + +/// Simple way to catch lock error without +/// depending on T +#[derive(thiserror::Error, Debug)] +pub enum LockError { + #[error("{0}")] + Poisoned(String), + #[error("Would block")] + WouldBlock, +} + +impl<T> From<TryLockError<T>> for MetalError { + fn from(value: TryLockError<T>) -> Self { + match value { + TryLockError::Poisoned(p) => MetalError::LockError(LockError::Poisoned(p.to_string())), + TryLockError::WouldBlock => MetalError::LockError(LockError::WouldBlock), + } + } +} /// Metal related errors #[derive(thiserror::Error, Debug)] @@ -24,6 +43,14 @@ pub enum MetalError { rhs_stride: Vec<usize>, mnk: (usize, usize, usize), }, + #[error("{0:?}")] + LockError(LockError), + #[error("{msg}, expected: {expected:?}, got: {got:?}")] + UnexpectedDType { + msg: &'static str, + expected: DType, + got: DType, + }, } impl From<String> for MetalError { @@ -32,11 +59,53 @@ impl From<String> for MetalError { } } +type AllocatedBuffers = Arc<RwLock<HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>>>; + #[derive(Clone)] pub struct MetalDevice { + /// Raw metal device: <https://developer.apple.com/documentation/metal/mtldevice?language=objc> device: metal::Device, + + /// Single command queue for the entire device. command_queue: metal::CommandQueue, + /// One command buffer at a time. + /// The scheduler works by allowing multiple + /// [ComputeCommandEncoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) + /// on a single command buffer. Using a single command buffer would be fastest on the GPU but + /// prevents overlapping of CPU and GPU commands (because command buffer needs to be committed + /// to start to work). + /// Despite what the documentation says, command buffers are NOT ordered. They are ordered + /// for their START time, but there's no guarantee that command buffer1 will finish before + /// command buffer2 starts (or there are metal bugs there) + command_buffer: Arc<RwLock<metal::CommandBuffer>>, + /// Keeps track of the current amount of compute command encoders on the current + /// command buffer + /// Arc, RwLock because of the interior mutability. + command_buffer_index: Arc<RwLock<usize>>, + /// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc) + compute_per_buffer: usize, + /// Every compute command encoder (and blit encoders) are defended with this Fence, forcing the + /// execution order to be linear. + /// It could be relaxed in some circumstances, by managing ourselves the dependencies in the + /// compute graph. + fence: metal::Fence, + /// Simple keeper struct to keep track of the already compiled kernels so we can reuse them. + /// Heavily used by [`candle_metal_kernels`], both fences need to match kernels: Arc<candle_metal_kernels::Kernels>, + /// Simple allocator struct. + /// The buffers are stored in size buckets since ML tends to use similar shapes over and over. + /// We store the buffers in [`Arc`] because it's much faster than Obj-c internal ref counting + /// (could be linked to FFI communication overhead). + /// + /// Whenever a buffer has a strong_count==1, we can reuse it, it means it was dropped in the + /// graph calculation, and only we the allocator kept a reference to it, therefore it's free + /// to be reused. However, in order for this to work, we need to guarantee the order of + /// operation, so that this buffer is not being used by another kernel at the same time. + /// Arc is the CPU reference count, it doesn't mean anything on the GPU side of things. + /// + /// Whenever we actually allocate a new buffer, we make a full sweep to cleanup unused buffers + /// (strong_count = 1). + buffers: AllocatedBuffers, } impl std::fmt::Debug for MetalDevice { @@ -58,10 +127,47 @@ impl MetalDevice { self.registry_id() } + pub fn metal_device(&self) -> &metal::Device { + &self.device + } + pub fn command_queue(&self) -> &CommandQueue { &self.command_queue } + pub fn command_buffer(&self) -> Result<CommandBuffer> { + let mut command_buffer_lock = self.command_buffer.try_write().map_err(MetalError::from)?; + let mut command_buffer = command_buffer_lock.to_owned(); + let mut index = self + .command_buffer_index + .try_write() + .map_err(MetalError::from)?; + if *index > self.compute_per_buffer { + command_buffer.commit(); + command_buffer = self.command_queue.new_command_buffer().to_owned(); + *command_buffer_lock = command_buffer.clone(); + *index = 0; + } + *index += 1; + Ok(command_buffer) + } + + pub fn wait_until_completed(&self) -> Result<()> { + let mut command_buffer = self.command_buffer.try_write().map_err(MetalError::from)?; + match command_buffer.status() { + metal::MTLCommandBufferStatus::Committed + | metal::MTLCommandBufferStatus::Scheduled + | metal::MTLCommandBufferStatus::Completed => { + panic!("Already committed"); + } + _ => {} + } + command_buffer.commit(); + command_buffer.wait_until_completed(); + *command_buffer = self.command_queue.new_command_buffer().to_owned(); + Ok(()) + } + pub fn kernels(&self) -> &Kernels { &self.kernels } @@ -70,17 +176,119 @@ impl MetalDevice { &self.device } - pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer { + /// Creates a new buffer (not necessarily zeroed). + /// The buffer is [MTLPrivate](https://developer.apple.com/documentation/metal/mtlstoragemode) + /// This means the buffer data cannot be read on the CPU directly. + /// + /// [`name`] is only used to keep track of the resource origin in case of bugs + pub fn new_buffer( + &self, + element_count: usize, + dtype: DType, + name: &str, + ) -> Result<Arc<Buffer>> { let size = (element_count * dtype.size_in_bytes()) as NSUInteger; - self.device - .new_buffer(size, MTLResourceOptions::StorageModeManaged) + self.allocate_buffer(size, MTLResourceOptions::StorageModePrivate, name) + } + + /// Creates a new buffer (not necessarily zeroed). + /// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode) + /// This means the buffer can be read on the CPU but will require manual + /// synchronization when the CPU memory is modified + /// Used as a bridge to gather data back from the GPU + pub fn new_buffer_managed(&self, size: NSUInteger) -> Result<Arc<Buffer>> { + self.allocate_buffer(size, MTLResourceOptions::StorageModeManaged, "managed") + } + + /// Creates a new buffer from data. + /// The buffer is [MTLPrivate](https://developer.apple.com/documentation/metal/mtlstoragemode) + /// + /// This method will block the computation because of the + /// lack of lifetime management through the GPU. + /// Internal comment for technical details. + pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Result<Arc<Buffer>> { + let size = core::mem::size_of_val(data) as NSUInteger; + let tmp = self.device.new_buffer_with_data( + data.as_ptr() as *const core::ffi::c_void, + size, + metal::MTLResourceOptions::StorageModeManaged, + ); + let real = self.allocate_buffer( + size, + metal::MTLResourceOptions::StorageModePrivate, + "with_data", + )?; + let command_buffer = self.command_buffer()?; + command_buffer.set_label("with_data"); + let blit = command_buffer.new_blit_command_encoder(); + blit.wait_for_fence(&self.fence); + blit.set_label("with_data_blit"); + blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length()); + blit.update_fence(&self.fence); + blit.end_encoding(); + + // This is necessary, for mmaped safetensors + // Because of the unsafe slice cast we're doing. + // The slice might not live long enough for metal + // To actually fill the GPU buffer. + // Putting this wait forces the GPU buffer to be filled + // with the actual data allowing the CPU storage todo + // deallocate properly. + self.wait_until_completed()?; + Ok(real) + } + + /// The critical allocator algorithm + fn allocate_buffer( + &self, + size: NSUInteger, + option: MTLResourceOptions, + _name: &str, + ) -> Result<Arc<Buffer>> { + let mut buffers = self.buffers.try_write().map_err(MetalError::from)?; + let subbuffers = buffers.entry((size, option)).or_insert(vec![]); + + for sub in &mut *subbuffers { + if Arc::strong_count(sub) == 1 { + return Ok(sub.clone()); + } + } + let new_buffer = self.device.new_buffer(size as NSUInteger, option); + let new_buffer = Arc::new(new_buffer); + subbuffers.push(new_buffer.clone()); + for subbuffers in buffers.values_mut() { + let newbuffers = subbuffers + .iter() + .filter(|s| Arc::strong_count(s) > 1) + .map(Arc::clone) + .collect(); + *subbuffers = newbuffers; + } + Ok(new_buffer) + } + + /// Create a metal GPU capture trace on [`path`]. + pub fn capture<P: AsRef<Path>>(&self, path: P) -> Result<()> { + let capture = metal::CaptureManager::shared(); + let descriptor = metal::CaptureDescriptor::new(); + descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument); + descriptor.set_capture_device(self); + descriptor.set_output_url(path); + + capture + .start_capture(&descriptor) + .map_err(MetalError::from)?; + Ok(()) } } #[derive(Debug, Clone)] pub struct MetalStorage { - buffer: metal::Buffer, + /// The actual buffer containing the data. + buffer: Arc<metal::Buffer>, + /// a reference to the device owning this buffer device: MetalDevice, + /// The dtype is kept since buffers are untyped. dtype: DType, } @@ -108,14 +316,27 @@ impl BackendStorage for MetalStorage { self.dtype ); } + let buffer = self.device.new_buffer_managed(self.buffer.length())?; + { + let command_buffer = self.device.command_buffer()?; + command_buffer.set_label("to_cpu"); + let blit = command_buffer.new_blit_command_encoder(); + blit.set_label("blit_to_cpu"); + blit.wait_for_fence(&self.device.fence); + blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length()); + blit.update_fence(&self.device.fence); + blit.end_encoding(); + } + self.device.wait_until_completed()?; + match self.dtype { - DType::U8 => Ok(CpuStorage::U8(self.buffer.read_to_vec(length / size))), - DType::U32 => Ok(CpuStorage::U32(self.buffer.read_to_vec(length / size))), - DType::I64 => Ok(CpuStorage::I64(self.buffer.read_to_vec(length / size))), - DType::F16 => Ok(CpuStorage::F16(self.buffer.read_to_vec(length / size))), - DType::BF16 => Ok(CpuStorage::BF16(self.buffer.read_to_vec(length / size))), - DType::F32 => Ok(CpuStorage::F32(self.buffer.read_to_vec(length / size))), - DType::F64 => Ok(CpuStorage::F64(self.buffer.read_to_vec(length / size))), + DType::U8 => Ok(CpuStorage::U8(read_to_vec(&buffer, length / size))), + DType::U32 => Ok(CpuStorage::U32(read_to_vec(&buffer, length / size))), + DType::I64 => Ok(CpuStorage::I64(read_to_vec(&buffer, length / size))), + DType::F16 => Ok(CpuStorage::F16(read_to_vec(&buffer, length / size))), + DType::BF16 => Ok(CpuStorage::BF16(read_to_vec(&buffer, length / size))), + DType::F32 => Ok(CpuStorage::F32(read_to_vec(&buffer, length / size))), + DType::F64 => Ok(CpuStorage::F64(read_to_vec(&buffer, length / size))), } } @@ -126,52 +347,152 @@ impl BackendStorage for MetalStorage { let el = shape.elem_count(); let dtype = self.dtype; - if layout.is_contiguous() || layout.start_offset() != 0 || dtype != DType::F32 { - crate::bail!("Not contiguous, non-f32 affine is not implemented yet."); + let buffer = device.new_buffer(el, self.dtype, "affine")?; + let command_buffer = self.device.command_buffer()?; + if layout.is_contiguous() && layout.start_offset() == 0 { + let name = match self.dtype { + DType::F32 => "affine_f32", + DType::F16 => "affine_f16", + dtype => crate::bail!("Affine {dtype:?}"), + }; + candle_metal_kernels::call_affine( + &device.device, + &command_buffer, + &device.kernels, + name, + el, + &self.buffer, + &buffer, + mul as f32, + add as f32, + ) + .map_err(MetalError::from)?; + } else { + let name = match self.dtype { + DType::F32 => "affine_f32_strided", + DType::F16 => "affine_f16_strided", + dtype => crate::bail!("Affine {dtype:?}"), + }; + candle_metal_kernels::call_affine_strided( + &device.device, + &command_buffer, + &device.kernels, + name, + layout.dims(), + &self.buffer, + layout.stride(), + layout.start_offset() * dtype.size_in_bytes(), + &buffer, + mul as f32, + add as f32, + ) + .map_err(MetalError::from)?; } - - let mut buffer = device.new_buffer(el, self.dtype); - let command_buffer = self.device.command_queue.new_command_buffer(); - candle_metal_kernels::call_affine( - &device.device, - &command_buffer, - &device.kernels, - el, - &self.buffer, - &mut buffer, - mul as f32, - add as f32, - ) - .map_err(MetalError::from)?; - command_buffer.commit(); - command_buffer.wait_until_completed(); - return Ok(Self { - buffer, - device: device.clone(), - dtype, - }); + Ok(Self::new(buffer, device.clone(), dtype)) } - fn powf(&self, _: &Layout, _: f64) -> Result<Self> { - crate::bail!("powf metal") + fn powf(&self, layout: &Layout, pow: f64) -> Result<Self> { + let device = self.device().clone(); + + let shape = layout.shape(); + let el = shape.elem_count(); + let dtype = self.dtype; + + let buffer = device.new_buffer(el, self.dtype, "powf")?; + let command_buffer = self.device.command_buffer()?; + if layout.is_contiguous() && layout.start_offset() == 0 { + let name = match self.dtype { + DType::F32 => "powf_f32", + DType::F16 => "powf_f16", + dtype => crate::bail!("Powf {dtype:?}"), + }; + candle_metal_kernels::call_powf( + &device.device, + &command_buffer, + &device.kernels, + name, + el, + &self.buffer, + &buffer, + pow as f32, + ) + .map_err(MetalError::from)?; + } else { + let name = match self.dtype { + DType::F32 => "powf_f32_strided", + DType::F16 => "powf_f16_strided", + dtype => crate::bail!("Powf {dtype:?}"), + }; + candle_metal_kernels::call_powf_strided( + &device.device, + &command_buffer, + &device.kernels, + name, + layout.dims(), + &self.buffer, + layout.stride(), + layout.start_offset() * dtype.size_in_bytes(), + &buffer, + pow as f32, + ) + .map_err(MetalError::from)?; + } + Ok(Self::new(buffer, device.clone(), dtype)) } - fn elu(&self, _: &Layout, _: f64) -> Result<Self> { - crate::bail!("elu metal") + fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> { + let device = self.device().clone(); + + let shape = layout.shape(); + let el = shape.elem_count(); + let dtype = self.dtype; + + let buffer = device.new_buffer(el, self.dtype, "elu")?; + let command_buffer = self.device.command_buffer()?; + if layout.is_contiguous() && layout.start_offset() == 0 { + let name = match self.dtype { + DType::F32 => "elu_f32", + DType::F16 => "elu_f16", + dtype => crate::bail!("Powf {dtype:?}"), + }; + candle_metal_kernels::call_elu( + &device.device, + &command_buffer, + &device.kernels, + name, + el, + &self.buffer, + &buffer, + alpha as f32, + ) + .map_err(MetalError::from)?; + } else { + let name = match self.dtype { + DType::F32 => "elu_f32_strided", + DType::F16 => "elu_f16_strided", + dtype => crate::bail!("Powf {dtype:?}"), + }; + candle_metal_kernels::call_elu_strided( + &device.device, + &command_buffer, + &device.kernels, + name, + layout.dims(), + &self.buffer, + layout.stride(), + layout.start_offset() * dtype.size_in_bytes(), + &buffer, + alpha as f32, + ) + .map_err(MetalError::from)?; + } + Ok(Self::new(buffer, device.clone(), dtype)) } fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> { - if !(sum_dims.len() == 1 - && sum_dims[0] == layout.shape().rank() - 1 - && layout.is_contiguous() - && layout.start_offset() == 0) - { - crate::bail!("Non contiguous reduce op not supported yet"); - } let device = self.device.clone(); let src_stride = layout.stride(); let src_dims = layout.shape().dims(); - let src_el: usize = src_dims.iter().product(); // Source dims and strides with the sum dims at the end. let mut dims = vec![]; let mut stride = vec![]; @@ -191,53 +512,77 @@ impl BackendStorage for MetalStorage { // The reduction loop requires the shared array to be properly initialized and for // this we want the number of threads to be a power of two. let (name, check_empty, return_index) = match (op, self.dtype) { - (ReduceOp::Sum, DType::F32) => ("fast_sum_float", false, false), - (ReduceOp::Min, DType::F32) => ("fast_min_float", true, false), - (ReduceOp::Max, DType::F32) => ("fast_max_float", true, false), - (ReduceOp::ArgMin, DType::F32) => ("fast_argmin_float", true, true), - (ReduceOp::ArgMax, DType::F32) => ("fast_argmax_float", true, true), - _ => crate::bail!("Reduce op for non float"), + (ReduceOp::Sum, DType::F32) => ("fast_sum_f32_strided", false, false), + (ReduceOp::Min, DType::F32) => ("fast_min_f32_strided", true, false), + (ReduceOp::Max, DType::F32) => ("fast_max_f32_strided", true, false), + (ReduceOp::ArgMin, DType::F32) => ("fast_argmin_f32_strided", true, true), + (ReduceOp::ArgMax, DType::F32) => ("fast_argmax_f32_strided", true, true), + (ReduceOp::Sum, DType::U32) => ("fast_sum_u32_strided", false, false), + (ReduceOp::Min, DType::U32) => ("fast_min_u32_strided", true, false), + (ReduceOp::Max, DType::U32) => ("fast_max_u32_strided", true, false), + (ReduceOp::ArgMin, DType::U32) => ("fast_argmin_u32_strided", true, true), + (ReduceOp::ArgMax, DType::U32) => ("fast_argmax_u32_strided", true, true), + (ReduceOp::Sum, DType::F16) => ("fast_sum_f16_strided", false, false), + (ReduceOp::Min, DType::F16) => ("fast_min_f16_strided", true, false), + (ReduceOp::Max, DType::F16) => ("fast_max_f16_strided", true, false), + (ReduceOp::ArgMin, DType::F16) => ("fast_argmin_f16_strided", true, true), + (ReduceOp::ArgMax, DType::F16) => ("fast_argmax_f16_strided", true, true), + (ReduceOp::Sum, DType::BF16) => ("fast_sum_bf16_strided", false, false), + (ReduceOp::Min, DType::BF16) => ("fast_min_bf16_strided", true, false), + (ReduceOp::Max, DType::BF16) => ("fast_max_bf16_strided", true, false), + (ReduceOp::ArgMin, DType::BF16) => ("fast_argmin_bf16_strided", true, true), + (ReduceOp::ArgMax, DType::BF16) => ("fast_argmax_bf16_strided", true, true), + (k, dtype) => crate::bail!("Reduce op for non float {k:?} {dtype:?}"), }; if check_empty && layout.shape().elem_count() == 0 { Err(crate::Error::EmptyTensor { op: "reduce" }.bt())? } let dtype = if return_index { DType::U32 } else { self.dtype }; - let mut buffer = device.new_buffer(dst_el, dtype); - let command_buffer = self.device.command_queue.new_command_buffer(); - candle_metal_kernels::call_reduce_contiguous( + let buffer = device.new_buffer(dst_el, dtype, "reduce")?; + let command_buffer = self.device.command_buffer()?; + candle_metal_kernels::call_reduce_strided( &device.device, &command_buffer, &device.kernels, name, - src_el, + &dims, + &stride, dst_el, &self.buffer, - &mut buffer, + layout.start_offset() * self.dtype.size_in_bytes(), + &buffer, ) .map_err(MetalError::from)?; - command_buffer.commit(); - command_buffer.wait_until_completed(); - Ok(Self { - buffer, - device, - dtype, - }) + Ok(Self::new(buffer, device, dtype)) } - fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> { - crate::bail!("cmp metal") + fn cmp(&self, op: CmpOp, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result<Self> { + let name = match op { + CmpOp::Eq => "eq", + CmpOp::Ne => "ne", + CmpOp::Le => "le", + CmpOp::Ge => "ge", + CmpOp::Lt => "lt", + CmpOp::Gt => "gt", + }; + self.binary(name, rhs, lhs_l, rhs_l) } fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> { let device = self.device(); let shape = layout.shape(); let el_count = shape.elem_count(); - let mut buffer = device.new_buffer(el_count, dtype); - let command_buffer = device.command_queue.new_command_buffer(); - if layout.is_contiguous() { + let buffer = device.new_buffer(el_count, dtype, "todtype")?; + let command_buffer = device.command_buffer()?; + if layout.is_contiguous() && layout.start_offset() == 0 { let kernel_name = match (self.dtype, dtype) { (DType::U32, DType::F32) => "cast_u32_f32", + (DType::U32, DType::U8) => "cast_u32_u8", + (DType::U8, DType::U32) => "cast_u8_u32", + (DType::U8, DType::F32) => "cast_u8_f32", + (DType::F32, DType::F16) => "cast_f32_f16", + (DType::F16, DType::F32) => "cast_f16_f32", (left, right) => crate::bail!("to dtype {left:?} - {right:?}"), }; candle_metal_kernels::call_cast_contiguous( @@ -247,24 +592,35 @@ impl BackendStorage for MetalStorage { kernel_name, el_count, &self.buffer, - &mut buffer, + layout.start_offset() * self.dtype.size_in_bytes(), + &buffer, ) .map_err(MetalError::from)?; } else { - crate::bail!( - "TODO Implement the kernel calling cast {:?}-{:?}", - self.dtype, - dtype - ); + let kernel_name = match (self.dtype, dtype) { + (DType::U32, DType::F32) => "cast_u32_f32_strided", + (DType::U32, DType::U8) => "cast_u32_u8_strided", + (DType::U8, DType::U32) => "cast_u8_u32_strided", + (DType::U8, DType::F32) => "cast_u8_f32_strided", + (DType::F32, DType::F16) => "cast_f32_f16_strided", + (DType::F16, DType::F32) => "cast_f16_f32_strided", + (left, right) => crate::bail!("to dtype {left:?} - {right:?}"), + }; + candle_metal_kernels::call_cast_strided( + &device.device, + &command_buffer, + &device.kernels, + kernel_name, + layout.dims(), + &self.buffer, + layout.stride(), + layout.start_offset() * self.dtype.size_in_bytes(), + &buffer, + ) + .map_err(MetalError::from)?; } - - command_buffer.commit(); - command_buffer.wait_until_completed(); - Ok(Self { - buffer, - device: device.clone(), - dtype, - }) + command_buffer.set_label("to_dtype"); + Ok(Self::new(buffer, device.clone(), dtype)) } fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> { @@ -272,8 +628,9 @@ impl BackendStorage for MetalStorage { let dtype = self.dtype; let shape = layout.shape(); let el_count = shape.elem_count(); - let mut buffer = device.new_buffer(el_count, dtype); - let command_buffer = device.command_queue.new_command_buffer(); + let buffer = device.new_buffer(el_count, dtype, B::KERNEL)?; + let command_buffer = device.command_buffer()?; + command_buffer.set_label(B::KERNEL); if layout.is_contiguous() && layout.start_offset() == 0 { use candle_metal_kernels::unary::contiguous; @@ -285,6 +642,27 @@ impl BackendStorage for MetalStorage { ("uneg", DType::F32) => contiguous::neg::FLOAT, ("uexp", DType::F32) => contiguous::exp::FLOAT, ("ulog", DType::F32) => contiguous::log::FLOAT, + ("ugelu", DType::F32) => contiguous::gelu::FLOAT, + ("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT, + ("uerf", DType::F32) => contiguous::erf::FLOAT, + ("uceil", DType::F32) => contiguous::ceil::FLOAT, + ("ufloor", DType::F32) => contiguous::floor::FLOAT, + ("uround", DType::F32) => contiguous::round::FLOAT, + ("utanh", DType::F32) => contiguous::tanh::FLOAT, + ("ucos", DType::F16) => contiguous::cos::HALF, + ("usin", DType::F16) => contiguous::sin::HALF, + ("usqr", DType::F16) => contiguous::sqr::HALF, + ("usqrt", DType::F16) => contiguous::sqrt::HALF, + ("uneg", DType::F16) => contiguous::neg::HALF, + ("uexp", DType::F16) => contiguous::exp::HALF, + ("ulog", DType::F16) => contiguous::log::HALF, + ("ugelu", DType::F16) => contiguous::gelu::HALF, + ("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF, + ("uerf", DType::F16) => contiguous::erf::HALF, + ("uceil", DType::F16) => contiguous::ceil::HALF, + ("ufloor", DType::F16) => contiguous::floor::HALF, + ("uround", DType::F16) => contiguous::round::HALF, + ("utanh", DType::F16) => contiguous::tanh::HALF, (name, dtype) => crate::bail!("Match {name} - {dtype:?}"), }; candle_metal_kernels::call_unary_contiguous( @@ -294,95 +672,64 @@ impl BackendStorage for MetalStorage { kernel_name, el_count, &self.buffer, - &mut buffer, - ) - .map_err(MetalError::from)?; - } else { - crate::bail!("TODO Implement the kernel calling {}", B::KERNEL); - } - command_buffer.commit(); - command_buffer.wait_until_completed(); - - Ok(Self { - buffer, - device: device.clone(), - dtype, - }) - } - - fn binary_impl<B: BinaryOpT>( - &self, - rhs: &Self, - lhs_l: &Layout, - rhs_l: &Layout, - ) -> Result<Self> { - let device = self.device(); - let dtype = self.dtype; - let shape = lhs_l.shape(); - let el_count = shape.elem_count(); - let mut buffer = device.new_buffer(el_count, dtype); - let command_buffer = device.command_queue.new_command_buffer(); - if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0) - && (rhs_l.is_contiguous() && rhs_l.start_offset() == 0) - { - use candle_metal_kernels::binary::contiguous; - - let kernel_name = match (B::KERNEL, dtype) { - ("add", DType::F32) => contiguous::add::FLOAT, - ("badd", DType::F32) => contiguous::add::FLOAT, - ("sub", DType::F32) => contiguous::sub::FLOAT, - ("bsub", DType::F32) => contiguous::sub::FLOAT, - ("mul", DType::F32) => contiguous::mul::FLOAT, - ("bmul", DType::F32) => contiguous::mul::FLOAT, - ("div", DType::F32) => contiguous::div::FLOAT, - ("bdiv", DType::F32) => contiguous::div::FLOAT, - (name, dtype) => crate::bail!("Match {name} - {dtype:?}"), - }; - candle_metal_kernels::call_binary_contiguous( - &device.device, - &command_buffer, - &device.kernels, - kernel_name, - el_count, - &self.buffer, - &rhs.buffer, - &mut buffer, + &buffer, ) .map_err(MetalError::from)?; } else { - use candle_metal_kernels::binary::strided; - + use candle_metal_kernels::unary::strided; let kernel_name = match (B::KERNEL, dtype) { - ("badd", DType::F32) => strided::add::FLOAT, - ("bsub", DType::F32) => strided::sub::FLOAT, - ("bmul", DType::F32) => strided::mul::FLOAT, - ("bdiv", DType::F32) => strided::div::FLOAT, + ("ucos", DType::F32) => strided::cos::FLOAT, + ("usin", DType::F32) => strided::sin::FLOAT, + ("usqr", DType::F32) => strided::sqr::FLOAT, + ("usqrt", DType::F32) => strided::sqrt::FLOAT, + ("uneg", DType::F32) => strided::neg::FLOAT, + ("uexp", DType::F32) => strided::exp::FLOAT, + ("ulog", DType::F32) => strided::log::FLOAT, + ("ugelu", DType::F32) => strided::gelu::FLOAT, + ("ugelu_erf", DType::F32) => strided::gelu_erf::FLOAT, + ("uerf", DType::F32) => strided::erf::FLOAT, + ("uceil", DType::F32) => strided::ceil::FLOAT, + ("ufloor", DType::F32) => strided::floor::FLOAT, + ("uround", DType::F32) => strided::round::FLOAT, + ("ucos", DType::F16) => strided::cos::HALF, + ("usin", DType::F16) => strided::sin::HALF, + ("usqr", DType::F16) => strided::sqr::HALF, + ("usqrt", DType::F16) => strided::sqrt::HALF, + ("uneg", DType::F16) => strided::neg::HALF, + ("uexp", DType::F16) => strided::exp::HALF, + ("ulog", DType::F16) => strided::log::HALF, + ("ugelu", DType::F16) => strided::gelu::HALF, + ("ugelu_erf", DType::F16) => strided::gelu_erf::HALF, + ("uerf", DType::F16) => strided::erf::HALF, + ("uceil", DType::F16) => strided::ceil::HALF, + ("ufloor", DType::F16) => strided::floor::HALF, + ("uround", DType::F16) => strided::round::HALF, (name, dtype) => crate::bail!("Match {name} - {dtype:?}"), }; - candle_metal_kernels::call_binary_strided( + candle_metal_kernels::call_unary_strided( &device.device, &command_buffer, &device.kernels, kernel_name, - lhs_l.dims(), + layout.dims(), &self.buffer, - &lhs_l.stride(), - lhs_l.start_offset() * self.dtype.size_in_bytes(), - &rhs.buffer, - &rhs_l.stride(), - rhs_l.start_offset() * rhs.dtype.size_in_bytes(), - &mut buffer, + layout.stride(), + layout.start_offset() * self.dtype.size_in_bytes(), + &buffer, + 0, ) .map_err(MetalError::from)?; } - command_buffer.commit(); - command_buffer.wait_until_completed(); + Ok(Self::new(buffer, device.clone(), dtype)) + } - Ok(Self { - buffer, - device: device.clone(), - dtype, - }) + fn binary_impl<B: BinaryOpT>( + &self, + rhs: &Self, + lhs_l: &Layout, + rhs_l: &Layout, + ) -> Result<Self> { + self.binary(B::KERNEL, rhs, lhs_l, rhs_l) } fn where_cond( @@ -398,14 +745,26 @@ impl BackendStorage for MetalStorage { let dims = shape.dims(); let el = shape.elem_count(); let dtype = t.dtype; - let mut buffer = self.device.new_buffer(el, dtype); - let command_buffer = self.device.command_queue.new_command_buffer(); + let buffer = self.device.new_buffer(el, dtype, "where")?; + let command_buffer = self.device.command_buffer()?; + if t.dtype() != f.dtype() { + crate::bail!( + "Invalid where: different dtypes for values {:?} != {:?}", + t.dtype(), + f.dtype() + ); + } + let name = match (self.dtype, t.dtype()) { + (DType::U8, DType::F32) => "where_u8_f32", + (DType::U8, DType::F16) => "where_u8_f16", + (left, right) => crate::bail!("where {left:?} - {right:?} not implemented"), + }; candle_metal_kernels::call_where_cond_strided( &device.device, &command_buffer, &device.kernels, - "where_u8_f32", - &dims, + name, + dims, &self.buffer, ( layout.stride(), @@ -415,16 +774,10 @@ impl BackendStorage for MetalStorage { (&t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()), &f.buffer, (&f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()), - &mut buffer, + &buffer, ) .map_err(MetalError::from)?; - command_buffer.commit(); - command_buffer.wait_until_completed(); - Ok(Self { - buffer, - device, - dtype, - }) + Ok(Self::new(buffer, device, dtype)) } fn conv1d( @@ -483,20 +836,84 @@ impl BackendStorage for MetalStorage { crate::bail!("upsample_nearest2d metal") } - fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> { - crate::bail!("gather metal") + fn gather(&self, src_l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> { + let (ids_o1, _) = match ids_l.contiguous_offsets() { + Some(o12) => o12, + None => Err(crate::Error::RequiresContiguous { op: "gather" }.bt())?, + }; + let ids_el = ids_l.dims()[dim]; + let dst_el = ids_l.shape().elem_count(); + let dtype = self.dtype; + let device = self.device(); + let buffer = device.new_buffer(dst_el, dtype, "index_select")?; + let name = match (ids.dtype, self.dtype) { + (DType::U32, DType::F32) => "gather_u32_f32", + (DType::U32, DType::F16) => "gather_u32_f16", + (left, right) => crate::bail!("gather metal {left:?} {right:?} not implemented"), + }; + let command_buffer = self.device.command_buffer()?; + candle_metal_kernels::call_gather( + &device.device, + &command_buffer, + &self.device.kernels, + name, + src_l.dims(), + ids_el, + dim, + &self.buffer, + src_l.start_offset() * dtype.size_in_bytes(), + &ids.buffer, + ids_o1 * ids.dtype.size_in_bytes(), + &buffer, + ) + .map_err(MetalError::from)?; + Ok(Self::new(buffer, device.clone(), dtype)) } fn scatter_add( &self, - _: &Layout, - _: &Self, - _: &Layout, - _: &Self, - _: &Layout, - _: usize, + l: &Layout, + ids: &Self, + ids_l: &Layout, + src: &Self, + src_l: &Layout, + dim: usize, ) -> Result<Self> { - crate::bail!("scatter_add metal") + let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?; + self.copy_strided_src(&mut acc, 0, l)?; + let (ids_offset, _) = match ids_l.contiguous_offsets() { + Some(o12) => o12, + None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?, + }; + let src_offset = match src_l.contiguous_offsets() { + Some((o1, _)) => o1, + None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?, + }; + let name = match (ids.dtype, self.dtype) { + (DType::U32, DType::F32) => "sa_u32_f32", + _ => Err(MetalError::UnexpectedDType { + msg: "scatter-add ids should be u8/u32/i64", + expected: DType::U32, + got: ids.dtype(), + })?, + }; + let command_buffer = self.device.command_buffer()?; + candle_metal_kernels::call_scatter_add( + &self.device.device, + &command_buffer, + &self.device.kernels, + name, + src_l.dims(), + l.dims(), + dim, + &src.buffer, + src_offset * src.dtype.size_in_bytes(), + &ids.buffer, + ids_offset * ids.dtype.size_in_bytes(), + &acc.buffer, + ) + .map_err(MetalError::from)?; + Ok(acc) } fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> { @@ -513,12 +930,13 @@ impl BackendStorage for MetalStorage { let dst_el = ids_el * left_size * right_size; let dtype = self.dtype; let device = self.device(); - let mut buffer = device.new_buffer(dst_el, dtype); + let buffer = device.new_buffer(dst_el, dtype, "index_select")?; let name = match (ids.dtype, self.dtype) { (DType::U32, DType::F32) => "is_u32_f32", + (DType::U32, DType::F16) => "is_u32_f16", (left, right) => crate::bail!("index select metal {left:?} {right:?}"), }; - let command_buffer = self.device.command_queue.new_command_buffer(); + let command_buffer = self.device.command_buffer()?; candle_metal_kernels::call_index_select( &device.device, &command_buffer, @@ -529,30 +947,58 @@ impl BackendStorage for MetalStorage { dim, &self.buffer, &ids.buffer, - &mut buffer, + &buffer, ) .map_err(MetalError::from)?; - command_buffer.commit(); - command_buffer.wait_until_completed(); - Ok(Self { - buffer, - device: device.clone(), - dtype, - }) + Ok(Self::new(buffer, device.clone(), dtype)) } fn index_add( &self, - _: &Layout, - _: &Self, - _: &Layout, - _: &Self, - _: &Layout, - _: usize, + l: &Layout, + ids: &Self, + ids_l: &Layout, + src: &Self, + src_l: &Layout, + dim: usize, ) -> Result<Self> { - crate::bail!("index_add metal") + let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?; + self.copy_strided_src(&mut acc, 0, l)?; + let (ids_offset, _) = match ids_l.contiguous_offsets() { + Some(o12) => o12, + None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?, + }; + let src_offset = match src_l.contiguous_offsets() { + Some((o1, _)) => o1, + None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?, + }; + let name = match (ids.dtype, self.dtype) { + (DType::U32, DType::F32) => "ia_u32_f32", + _ => Err(MetalError::UnexpectedDType { + msg: "index-add ids should be u8/u32/i64", + expected: DType::U32, + got: ids.dtype(), + })?, + }; + let command_buffer = self.device.command_buffer()?; + candle_metal_kernels::call_index_add( + &self.device.device, + &command_buffer, + &self.device.kernels, + name, + src_l.dims(), + l.dims(), + ids_l.dims(), + dim, + &src.buffer, + src_offset * src.dtype.size_in_bytes(), + &ids.buffer, + ids_offset * ids.dtype.size_in_bytes(), + &acc.buffer, + ) + .map_err(MetalError::from)?; + Ok(acc) } - fn matmul( &self, rhs: &Self, @@ -560,147 +1006,81 @@ impl BackendStorage for MetalStorage { lhs_l: &Layout, rhs_l: &Layout, ) -> Result<Self> { - // Create descriptors - use metal::mps::matrix::*; - let type_id = metal::mps::MPS_FLOATBIT_ENCODING | 32; - let size = core::mem::size_of::<f32>() as NSUInteger; - - let elem_count = b * m * n; - - let lhs_stride = lhs_l.stride(); - let rhs_stride = rhs_l.stride(); - let rhs_m1 = rhs_stride[rhs_stride.len() - 1]; - let rhs_m2 = rhs_stride[rhs_stride.len() - 2]; - let lhs_m1 = lhs_stride[lhs_stride.len() - 1]; - let lhs_m2 = lhs_stride[lhs_stride.len() - 2]; - // The a tensor has dims batching, k, n (rhs) - let transpose_left = if lhs_m1 == 1 && lhs_m2 == k { - false - } else if lhs_m1 == m && lhs_m2 == 1 { - true - } else { - Err(MetalError::MatMulNonContiguous { - lhs_stride: lhs_stride.to_vec(), - rhs_stride: rhs_stride.to_vec(), - mnk: (m, n, k), - })? - }; - let transpose_right = if rhs_m1 == 1 && rhs_m2 == n { - false - } else if rhs_m1 == k && rhs_m2 == 1 { - true - } else { - Err(MetalError::MatMulNonContiguous { - lhs_stride: lhs_stride.to_vec(), - rhs_stride: rhs_stride.to_vec(), - mnk: (m, n, k), - })? - }; - - let b = b as NSUInteger; - let m = m as NSUInteger; - let n = n as NSUInteger; - let k = k as NSUInteger; - - let left_descriptor = if transpose_left { - MatrixDescriptor::init_single(k, m, m * size, type_id) - } else { - MatrixDescriptor::init_single(m, k, k * size, type_id) - }; - let right_descriptor = if transpose_right { - MatrixDescriptor::init_single(n, k, k * size, type_id) - } else { - MatrixDescriptor::init_single(k, n, n * size, type_id) + let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul")?; + let name = match self.dtype { + DType::F32 => "sgemm", + DType::F16 => "hgemm", + dtype => { + return Err(MetalError::Message(format!("matmul doesn't support {dtype:?}")).into()) + } }; - let result_descriptor = MatrixDescriptor::init_single(m, n, n * size, type_id); - - // Create matrix objects - let left_matrix = Matrix::init_with_buffer_descriptor(&self.buffer, 0, &left_descriptor) - .ok_or_else(|| { - MetalError::from("Failed to create matrix multiplication kernel".to_string()) - })?; - let right_matrix = Matrix::init_with_buffer_descriptor(&rhs.buffer, 0, &right_descriptor) - .ok_or_else(|| { - MetalError::from("Failed to create matrix multiplication kernel".to_string()) - })?; - - let out_buffer = self.device.new_buffer(elem_count, self.dtype); - let result_matrix = Matrix::init_with_buffer_descriptor(&out_buffer, 0, &result_descriptor) - .ok_or_else(|| { - MetalError::from("Failed to create matrix multiplication kernel".to_string()) - })?; - - let alpha = 1.0f64; - let beta = 0.0f64; - // Create kernel - let matrix_multiplication = MatrixMultiplication::init( - &self.device, - transpose_left, - transpose_right, - m, - n, - k, - alpha, - beta, - ) - .ok_or_else(|| { - MetalError::from("Failed to create matrix multiplication kernel".to_string()) - })?; - - matrix_multiplication.set_batch_size(b); - - // Encode kernel to command buffer - let command_buffer = self.device.command_queue.new_command_buffer(); - matrix_multiplication.encode_to_command_buffer( - command_buffer, - &left_matrix, - &right_matrix, - &result_matrix, - ); - command_buffer.commit(); - command_buffer.wait_until_completed(); - - Ok(Self { - buffer: out_buffer, - device: self.device.clone(), - dtype: self.dtype(), - }) - } - fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> { - let src_shape = src_l.shape(); - let el_count = src_shape.elem_count(); - if el_count == 0 { - return Ok(()); - } - let command_buffer = self.device.command_queue.new_command_buffer(); - let kernel_name = match self.dtype { - DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT, - DType::F16 => candle_metal_kernels::unary::strided::copy::HALF, - DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT, - dtype => crate::bail!("copy_strided not implemented for {dtype:?}"), - }; - candle_metal_kernels::call_unary_strided( + let command_buffer = self.device.command_buffer()?; + command_buffer.set_label("matmul"); + candle_metal_kernels::call_gemm( &self.device.device, &command_buffer, &self.device.kernels, - kernel_name, - src_l.dims(), + name, + (b, m, n, k), + lhs_l.stride(), + lhs_l.start_offset() * self.dtype.size_in_bytes(), &self.buffer, - &src_l.stride(), - src_l.start_offset() * self.dtype.size_in_bytes(), - &mut dst.buffer, - dst_offset, + rhs_l.stride(), + rhs_l.start_offset() * rhs.dtype.size_in_bytes(), + &rhs.buffer, + &buffer, ) .map_err(MetalError::from)?; - command_buffer.commit(); - command_buffer.wait_until_completed(); + Ok(Self::new(buffer, self.device.clone(), self.dtype())) + } + + fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> { + let command_buffer = self.device.command_buffer()?; + if src_l.is_contiguous() && self.dtype == dst.dtype() { + command_buffer.set_label("copy_contiguous"); + let blit = command_buffer.new_blit_command_encoder(); + blit.set_label("copy_contiguous"); + let src_offset = (src_l.start_offset() * self.dtype.size_in_bytes()) as NSUInteger; + let length = (src_l.shape().elem_count() * self.dtype.size_in_bytes()) as NSUInteger; + let dst_offset = (dst_offset * dst.dtype().size_in_bytes()) as NSUInteger; + blit.copy_from_buffer(&self.buffer, src_offset, dst.buffer(), dst_offset, length); + blit.end_encoding(); + } else { + let src_shape = src_l.shape(); + let el_count = src_shape.elem_count(); + if el_count == 0 { + return Ok(()); + } + let kernel_name = match self.dtype { + DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT, + DType::F16 => candle_metal_kernels::unary::strided::copy::HALF, + DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT, + DType::U32 => candle_metal_kernels::unary::strided::copy::U32, + DType::U8 => candle_metal_kernels::unary::strided::copy::U8, + dtype => crate::bail!("copy_strided not implemented for {dtype:?}"), + }; + candle_metal_kernels::call_unary_strided( + &self.device.device, + &command_buffer, + &self.device.kernels, + kernel_name, + src_l.dims(), + &self.buffer, + src_l.stride(), + src_l.start_offset() * self.dtype.size_in_bytes(), + &dst.buffer, + dst_offset * dst.dtype.size_in_bytes(), + ) + .map_err(MetalError::from)?; + command_buffer.set_label("copy_strided"); + } Ok(()) } } impl MetalStorage { - pub fn new(buffer: Buffer, device: MetalDevice, dtype: DType) -> Self { + pub fn new(buffer: Arc<Buffer>, device: MetalDevice, dtype: DType) -> Self { Self { buffer, device, @@ -711,6 +1091,111 @@ impl MetalStorage { pub fn buffer(&self) -> &Buffer { &self.buffer } + + pub fn binary( + &self, + op: &'static str, + rhs: &Self, + lhs_l: &Layout, + rhs_l: &Layout, + ) -> Result<Self> { + let device = self.device(); + let shape = lhs_l.shape(); + let el_count = shape.elem_count(); + let command_buffer = device.command_buffer()?; + let (buffer, dtype) = if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0) + && (rhs_l.is_contiguous() && rhs_l.start_offset() == 0) + && &op[..1] != "b" + { + use candle_metal_kernels::binary::contiguous; + + let (kernel_name, dtype) = match (op, self.dtype) { + ("add", DType::F32) => (contiguous::add::FLOAT, self.dtype), + ("sub", DType::F32) => (contiguous::sub::FLOAT, self.dtype), + ("mul", DType::F32) => (contiguous::mul::FLOAT, self.dtype), + ("div", DType::F32) => (contiguous::div::FLOAT, self.dtype), + ("eq", DType::F32) => (contiguous::eq::FLOAT, DType::U8), + ("ne", DType::F32) => (contiguous::ne::FLOAT, DType::U8), + ("le", DType::F32) => (contiguous::le::FLOAT, DType::U8), + ("lt", DType::F32) => (contiguous::lt::FLOAT, DType::U8), + ("ge", DType::F32) => (contiguous::ge::FLOAT, DType::U8), + ("gt", DType::F32) => (contiguous::gt::FLOAT, DType::U8), + ("add", DType::F16) => (contiguous::add::HALF, self.dtype), + ("sub", DType::F16) => (contiguous::sub::HALF, self.dtype), + ("mul", DType::F16) => (contiguous::mul::HALF, self.dtype), + ("div", DType::F16) => (contiguous::div::HALF, self.dtype), + ("eq", DType::F16) => (contiguous::eq::HALF, DType::U8), + ("ne", DType::F16) => (contiguous::ne::HALF, DType::U8), + ("le", DType::F16) => (contiguous::le::HALF, DType::U8), + ("lt", DType::F16) => (contiguous::lt::HALF, DType::U8), + ("ge", DType::F16) => (contiguous::ge::HALF, DType::U8), + ("gt", DType::F16) => (contiguous::gt::HALF, DType::U8), + (name, dtype) => crate::bail!("Binary {name} - {dtype:?} not implemented"), + }; + let buffer = device.new_buffer(el_count, dtype, op)?; + candle_metal_kernels::call_binary_contiguous( + &device.device, + &command_buffer, + &device.kernels, + kernel_name, + el_count, + &self.buffer, + &rhs.buffer, + &buffer, + ) + .map_err(MetalError::from)?; + (buffer, dtype) + } else { + use candle_metal_kernels::binary::strided; + + let (kernel_name, dtype) = match (op, self.dtype) { + ("badd", DType::F32) => (strided::add::FLOAT, self.dtype), + ("bsub", DType::F32) => (strided::sub::FLOAT, self.dtype), + ("bmul", DType::F32) => (strided::mul::FLOAT, self.dtype), + ("bdiv", DType::F32) => (strided::div::FLOAT, self.dtype), + ("bminimum", DType::F32) => (strided::min::FLOAT, self.dtype), + ("bmaximum", DType::F32) => (strided::max::FLOAT, self.dtype), + ("eq", DType::F32) => (strided::eq::FLOAT, DType::U8), + ("ne", DType::F32) => (strided::ne::FLOAT, DType::U8), + ("le", DType::F32) => (strided::le::FLOAT, DType::U8), + ("lt", DType::F32) => (strided::lt::FLOAT, DType::U8), + ("ge", DType::F32) => (strided::ge::FLOAT, DType::U8), + ("gt", DType::F32) => (strided::gt::FLOAT, DType::U8), + ("badd", DType::F16) => (strided::add::HALF, self.dtype), + ("bsub", DType::F16) => (strided::sub::HALF, self.dtype), + ("bmul", DType::F16) => (strided::mul::HALF, self.dtype), + ("bdiv", DType::F16) => (strided::div::HALF, self.dtype), + ("bminimum", DType::F16) => (strided::min::HALF, self.dtype), + ("bmaximum", DType::F16) => (strided::max::HALF, self.dtype), + ("eq", DType::F16) => (strided::eq::HALF, DType::U8), + ("ne", DType::F16) => (strided::ne::HALF, DType::U8), + ("le", DType::F16) => (strided::le::HALF, DType::U8), + ("lt", DType::F16) => (strided::lt::HALF, DType::U8), + ("ge", DType::F16) => (strided::ge::HALF, DType::U8), + ("gt", DType::F16) => (strided::gt::HALF, DType::U8), + (name, dtype) => crate::bail!("Binary strided {name} - {dtype:?} not implemented"), + }; + let buffer = device.new_buffer(el_count, dtype, op)?; + candle_metal_kernels::call_binary_strided( + &device.device, + &command_buffer, + &device.kernels, + kernel_name, + lhs_l.dims(), + &self.buffer, + lhs_l.stride(), + lhs_l.start_offset() * self.dtype.size_in_bytes(), + &rhs.buffer, + rhs_l.stride(), + rhs_l.start_offset() * rhs.dtype.size_in_bytes(), + &buffer, + ) + .map_err(MetalError::from)?; + (buffer, dtype) + }; + command_buffer.set_label("binary"); + Ok(Self::new(buffer, device.clone(), dtype)) + } } impl BackendDevice for MetalDevice { @@ -718,12 +1203,26 @@ impl BackendDevice for MetalDevice { fn new(ordinal: usize) -> Result<Self> { let device = metal::Device::all().swap_remove(ordinal); - let command_queue = device.new_command_queue(); - let kernels = Arc::new(Kernels::new()); + let command_buffer = command_queue.new_command_buffer().to_owned(); + command_buffer.enqueue(); + let command_buffer = Arc::new(RwLock::new(command_buffer)); + let command_buffer_index = Arc::new(RwLock::new(0)); + let fence = device.new_fence(); + let kernels = Arc::new(Kernels::new(fence.clone())); + let buffers = Arc::new(RwLock::new(HashMap::new())); + let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") { + Ok(val) => val.parse()?, + _ => 20, + }; Ok(Self { device, + fence, command_queue, + command_buffer, + command_buffer_index, + compute_per_buffer, + buffers, kernels, }) } @@ -743,9 +1242,22 @@ impl BackendDevice for MetalDevice { } fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> { - // TODO Is there a faster way ? - let cpu_storage = crate::cpu_backend::CpuDevice.zeros_impl(shape, dtype)?; - self.storage_from_cpu_storage(&cpu_storage) + let buffer = self.new_buffer(shape.elem_count(), dtype, "zeros")?; + let command_buffer = self.command_buffer()?; + command_buffer.set_label("zeros"); + let blit = command_buffer.new_blit_command_encoder(); + blit.wait_for_fence(&self.fence); + blit.fill_buffer( + &buffer, + metal::NSRange { + location: 0, + length: buffer.length(), + }, + 0, + ); + blit.update_fence(&self.fence); + blit.end_encoding(); + Ok(MetalStorage::new(buffer, self.clone(), dtype)) } fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> { @@ -755,49 +1267,16 @@ impl BackendDevice for MetalDevice { } fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> { - let option = metal::MTLResourceOptions::StorageModeManaged; let buffer = match storage { - CpuStorage::U8(storage) => self.device.new_buffer_with_data( - storage.as_ptr() as *const core::ffi::c_void, - (storage.len() * mem::size_of::<u8>()) as NSUInteger, - option, - ), - CpuStorage::U32(storage) => self.device.new_buffer_with_data( - storage.as_ptr() as *const core::ffi::c_void, - (storage.len() * mem::size_of::<u32>()) as NSUInteger, - option, - ), - CpuStorage::I64(storage) => self.device.new_buffer_with_data( - storage.as_ptr() as *const core::ffi::c_void, - (storage.len() * mem::size_of::<i64>()) as NSUInteger, - option, - ), - CpuStorage::BF16(storage) => self.device.new_buffer_with_data( - storage.as_ptr() as *const core::ffi::c_void, - (storage.len() * mem::size_of::<bf16>()) as NSUInteger, - option, - ), - CpuStorage::F16(storage) => self.device.new_buffer_with_data( - storage.as_ptr() as *const core::ffi::c_void, - (storage.len() * mem::size_of::<f16>()) as NSUInteger, - option, - ), - CpuStorage::F32(storage) => self.device.new_buffer_with_data( - storage.as_ptr() as *const core::ffi::c_void, - (storage.len() * mem::size_of::<f32>()) as NSUInteger, - option, - ), - CpuStorage::F64(storage) => self.device.new_buffer_with_data( - storage.as_ptr() as *const core::ffi::c_void, - (storage.len() * mem::size_of::<f64>()) as NSUInteger, - option, - ), - }; - Ok(Self::Storage { - buffer, - device: self.clone(), - dtype: storage.dtype(), - }) + CpuStorage::U8(storage) => self.new_buffer_with_data(storage), + CpuStorage::U32(storage) => self.new_buffer_with_data(storage), + CpuStorage::I64(storage) => self.new_buffer_with_data(storage), + CpuStorage::BF16(storage) => self.new_buffer_with_data(storage), + CpuStorage::F16(storage) => self.new_buffer_with_data(storage), + CpuStorage::F32(storage) => self.new_buffer_with_data(storage), + CpuStorage::F64(storage) => self.new_buffer_with_data(storage), + }?; + Ok(Self::Storage::new(buffer, self.clone(), storage.dtype())) } fn rand_uniform( @@ -824,3 +1303,10 @@ impl BackendDevice for MetalDevice { self.storage_from_cpu_storage(&cpu_storage) } } + +fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> { + let ptr = buffer.contents() as *const T; + assert!(!ptr.is_null()); + let slice = unsafe { std::slice::from_raw_parts(ptr, n) }; + slice.to_vec() +} diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index e6e7b415..f15f8c1c 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -1877,10 +1877,7 @@ impl Tensor { Storage::Metal(metal.storage_from_cpu_storage(storage)?) } (Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?), - (Storage::Metal(storage), Device::Cpu) => { - println!("{storage:?} - {:?}", storage.to_cpu_storage()?); - Storage::Cpu(storage.to_cpu_storage()?) - } + (Storage::Metal(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?), (Storage::Cuda(storage), Device::Cuda(cuda)) => { // TODO: Avoid passing through the cpu storage here, especially if the gpu ids // are the same. diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml index 508f75f5..0c4bf20e 100644 --- a/candle-examples/Cargo.toml +++ b/candle-examples/Cargo.toml @@ -57,6 +57,7 @@ flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"] mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"] nccl = ["cuda", "cudarc/nccl", "dep:half"] onnx = ["candle-onnx"] +metal = ["candle/metal", "candle-nn/metal"] [[example]] name = "llama_multiprocess" diff --git a/candle-metal-kernels/Cargo.toml b/candle-metal-kernels/Cargo.toml index c0e019f4..7ab45a90 100644 --- a/candle-metal-kernels/Cargo.toml +++ b/candle-metal-kernels/Cargo.toml @@ -10,7 +10,7 @@ categories = ["science"] license = "MIT OR Apache-2.0" [dependencies] -metal = { version = "0.27.1", features = ["mps"], package="candle-metal" } +metal = { version = "0.27.0", features = ["mps"]} once_cell = "1.18.0" thiserror = "1" tracing = "0.1.37" diff --git a/candle-metal-kernels/src/affine.metal b/candle-metal-kernels/src/affine.metal index e5f0a841..4166d811 100644 --- a/candle-metal-kernels/src/affine.metal +++ b/candle-metal-kernels/src/affine.metal @@ -29,15 +29,96 @@ kernel void FN_NAME( \ if (id >= dim) { \ return; \ } \ - const TYPENAME m = TYPENAME(mul); \ - const TYPENAME a = TYPENAME(add); \ - output[id] = input[id] * m + a; \ + output[id] = TYPENAME(float(input[id]) * mul + add); \ } \ +kernel void FN_NAME##_strided( \ + constant size_t &dim, \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *strides, \ + constant float &mul, \ + constant float &add, \ + device const TYPENAME *input, \ + device TYPENAME *output, \ + uint id [[ thread_position_in_grid ]] \ +) { \ + if (id >= dim) { \ + return; \ + } \ + output[id] = TYPENAME(float(input[get_strided_index(id, num_dims, dims, strides)]) * mul + add); \ +} + +#define POWF(FN_NAME, TYPENAME) \ +kernel void FN_NAME( \ + constant size_t &dim, \ + constant float &mul, \ + device const TYPENAME *input, \ + device TYPENAME *output, \ + uint id [[ thread_position_in_grid ]] \ +) { \ + if (id >= dim) { \ + return; \ + } \ + output[id] = TYPENAME(pow(input[id], TYPENAME(mul))); \ +} \ +kernel void FN_NAME##_strided( \ + constant size_t &dim, \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *strides, \ + constant float &mul, \ + device const TYPENAME *input, \ + device TYPENAME *output, \ + uint id [[ thread_position_in_grid ]] \ +) { \ + if (id >= dim) { \ + return; \ + } \ + output[id] = TYPENAME(pow(input[get_strided_index(id, num_dims, dims, strides)], TYPENAME(mul))); \ +} + +#define ELU(FN_NAME, TYPENAME) \ +kernel void FN_NAME( \ + constant size_t &dim, \ + constant float &mul, \ + device const TYPENAME *input, \ + device TYPENAME *output, \ + uint id [[ thread_position_in_grid ]] \ +) { \ + if (id >= dim) { \ + return; \ + } \ + const TYPENAME x = input[id]; \ + output[id] = TYPENAME((x > 0)?x: mul * exp(x - 1)); \ +} \ +kernel void FN_NAME##_strided( \ + constant size_t &dim, \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *strides, \ + constant float &mul, \ + device const TYPENAME *input, \ + device TYPENAME *output, \ + uint id [[ thread_position_in_grid ]] \ +) { \ + if (id >= dim) { \ + return; \ + } \ + const TYPENAME x = input[get_strided_index(id, num_dims, dims, strides)]; \ + output[id] = TYPENAME((x > 0)?x: mul * exp(x - 1)); \ +} \ + -AFFINE(affine_float, float) -AFFINE(affine_half, half) +AFFINE(affine_f32, float) +AFFINE(affine_f16, half) +POWF(powf_f32, float) +POWF(powf_f16, half) +ELU(elu_f32, float) +ELU(elu_f16, half) #if __METAL_VERSION__ >= 310 -AFFINE(affine_bfloat, bfloat); +AFFINE(affine_bf16, bfloat); +POWF(powf_bf16, bfloat); +ELU(elu_bf16, bfloat); #endif diff --git a/candle-metal-kernels/src/binary.metal b/candle-metal-kernels/src/binary.metal index f18cdbb0..8c3b4a8c 100644 --- a/candle-metal-kernels/src/binary.metal +++ b/candle-metal-kernels/src/binary.metal @@ -1,5 +1,8 @@ #include <metal_stdlib> +#define MAX(x, y) ((x) > (y) ? (x) : (y)) +#define MIN(x, y) ((x) < (y) ? (x) : (y)) + METAL_FUNC uint get_strided_index( uint idx, constant size_t &num_dims, @@ -22,15 +25,15 @@ kernel void FN_NAME( \ constant size_t &dim, \ device const TYPENAME *left, \ device const TYPENAME *right, \ - device TYPENAME *output, \ - uint thread_position_in_grid [[ thread_position_in_grid ]] \ + device OUT_TYPENAME *output, \ + uint tid [[ thread_position_in_grid ]] \ ) { \ - if (thread_position_in_grid >= dim) { \ + if (tid >= dim) { \ return; \ } \ - TYPENAME x = left[thread_position_in_grid]; \ - TYPENAME y = right[thread_position_in_grid]; \ - output[thread_position_in_grid] = OUT_TYPENAME(FN); \ + TYPENAME x = left[tid]; \ + TYPENAME y = right[tid]; \ + output[tid] = OUT_TYPENAME(FN); \ }\ kernel void FN_NAME_STRIDED( \ constant size_t &dim, \ @@ -40,33 +43,48 @@ kernel void FN_NAME_STRIDED( \ constant size_t *right_strides, \ device const TYPENAME *left, \ device const TYPENAME *right, \ - device TYPENAME *output, \ - uint thread_position_in_grid [[ thread_position_in_grid ]] \ + device OUT_TYPENAME *output, \ + uint tid [[ thread_position_in_grid ]] \ ) { \ - if (thread_position_in_grid >= dim) { \ + if (tid >= dim) { \ return; \ } \ - TYPENAME x = left[get_strided_index(thread_position_in_grid, num_dims, dims, left_strides)]; \ - TYPENAME y = right[get_strided_index(thread_position_in_grid, num_dims, dims, right_strides)]; \ - output[thread_position_in_grid] = OUT_TYPENAME(FN); \ + TYPENAME x = left[get_strided_index(tid, num_dims, dims, left_strides)]; \ + TYPENAME y = right[get_strided_index(tid, num_dims, dims, right_strides)]; \ + output[tid] = OUT_TYPENAME(FN); \ } #define BINARY_OP(FN, NAME) \ -BINARY(FN, float, float, NAME##_float, NAME##_float_strided); \ -BINARY(FN, half, half, NAME##_half, NAME##_half_strided); +BINARY(FN, float, float, NAME##_f32, NAME##_f32_strided); \ +BINARY(FN, half, half, NAME##_f16, NAME##_f16_strided); #define BFLOAT_BINARY_OP(FN, NAME) \ -BINARY(FN, bfloat, bfloat, NAME##_bfloat, NAME##_bfloat_strided); +BINARY(FN, bfloat, bfloat, NAME##_bf16, NAME##_bf16_strided); + +#define BINARY_OP_OUT(NAME, FN) \ +BINARY(FN, float, uint8_t, NAME##_f32, NAME##_f32_strided); \ +BINARY(FN, half, uint8_t, NAME##_f16, NAME##_f16_strided); BINARY_OP(x + y, add) BINARY_OP(x - y, sub) BINARY_OP(x * y, mul) BINARY_OP(x / y, div) +BINARY_OP(MIN(x, y), min) +BINARY_OP(MAX(x, y), max) + +BINARY_OP_OUT(eq, x == y) +BINARY_OP_OUT(ne, x != y) +BINARY_OP_OUT(le, x <= y) +BINARY_OP_OUT(lt, x < y) +BINARY_OP_OUT(ge, x >= y) +BINARY_OP_OUT(gt, x > y) #if __METAL_VERSION__ >= 310 BFLOAT_BINARY_OP(x + y, add) BFLOAT_BINARY_OP(x - y, sub) BFLOAT_BINARY_OP(x * y, mul) BFLOAT_BINARY_OP(x / y, div) +BFLOAT_BINARY_OP(MIN(x, y), min) +BFLOAT_BINARY_OP(MAX(x, y), max) #endif diff --git a/candle-metal-kernels/src/cast.metal b/candle-metal-kernels/src/cast.metal index d1788253..8481389d 100644 --- a/candle-metal-kernels/src/cast.metal +++ b/candle-metal-kernels/src/cast.metal @@ -23,12 +23,12 @@ kernel void FN_NAME( \ constant size_t &dim, \ device const LEFT_TYPENAME *input, \ device RIGHT_TYPENAME *output, \ - uint thread_position_in_grid [[ thread_position_in_grid ]] \ + uint tid [[ thread_position_in_grid ]] \ ) { \ - if (thread_position_in_grid >= dim) { \ + if (tid >= dim) { \ return; \ } \ - output[thread_position_in_grid] = RIGHT_TYPENAME(input[thread_position_in_grid]); \ + output[tid] = RIGHT_TYPENAME(input[tid]); \ } \ kernel void FN_NAME_STRIDED( \ constant size_t &dim, \ @@ -37,15 +37,20 @@ kernel void FN_NAME_STRIDED( \ constant size_t *strides, \ device const LEFT_TYPENAME *input, \ device RIGHT_TYPENAME *output, \ - uint i [[ thread_position_in_grid ]] \ + uint tid [[ thread_position_in_grid ]] \ ) { \ - if (i >= dim) { \ + if (tid >= dim) { \ return; \ } \ - output[i] = RIGHT_TYPENAME(input[get_strided_index(i, num_dims, dims, strides)]); \ + output[tid] = RIGHT_TYPENAME(input[get_strided_index(tid, num_dims, dims, strides)]); \ } \ -CAST(cast_u32_f32, cast_u32_f32_strided, int32_t, float) +CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float) +CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t) +CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t) +CAST(cast_u8_f32, cast_u8_f32_strided, uint8_t, float) +CAST(cast_f16_f32, cast_f16_f32_strided, half, float) +CAST(cast_f32_f16, cast_f32_f16_strided, float, half) #if __METAL_VERSION__ >= 310 #endif diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal index 444fa322..63357428 100644 --- a/candle-metal-kernels/src/indexing.metal +++ b/candle-metal-kernels/src/indexing.metal @@ -1,6 +1,34 @@ #include <metal_stdlib> using namespace metal; +template<typename TYPENAME, typename INDEX_TYPENAME> +METAL_FUNC void index( + constant size_t &dst_size, + constant size_t &left_size, + constant size_t &src_dim_size, + constant size_t &right_size, + constant size_t &ids_size, + const device TYPENAME *input, + const device INDEX_TYPENAME *input_ids, + device TYPENAME *output, + uint tid [[ thread_position_in_grid ]] +) { + if (tid >= dst_size) { + return; + } + const size_t id_i = (tid / right_size) % ids_size; + const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); + const size_t right_rank_i = tid % right_size; + const size_t left_rank_i = tid / right_size / ids_size; + /* + // Force prevent out of bounds indexing + // since there doesn't seem to be a good way to force crash + // No need to check for zero we're only allowing unsized. + */ + const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i; + output[tid] = input[src_i]; +} + # define INDEX_OP(NAME, INDEX_TYPENAME, TYPENAME) \ kernel void NAME( \ constant size_t &dst_size, \ @@ -11,92 +39,160 @@ kernel void NAME( \ const device TYPENAME *input, \ const device INDEX_TYPENAME *input_ids, \ device TYPENAME *output, \ - uint gid [[ thread_position_in_grid ]] \ + uint tid [[ thread_position_in_grid ]] \ ) { \ - if (gid >= dst_size) { \ - return; \ - } \ - const size_t id_i = gid / right_size / left_size; \ - const size_t right_rank_i = gid % right_size; \ - const size_t left_rank_i = gid % left_size; \ - /* \ - // Force prevent out of bounds indexing \ - // since there doesn't seem to be a good way to force crash \ - // No need to check for zero we're only allowing unsized. \ - */ \ - const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); \ - const size_t src_i = ((input_i * right_size) + right_rank_i) * left_size + left_rank_i; \ - output[gid] = input[src_i]; \ + index<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, ids_size, input, input_ids, output, tid); \ } +template<typename TYPENAME, typename INDEX_TYPENAME> +METAL_FUNC void gather( + constant size_t &dst_size, + constant size_t &left_size, + constant size_t &src_dim_size, + constant size_t &right_size, + constant size_t &ids_size, + const device TYPENAME *input, + const device INDEX_TYPENAME *input_ids, + device TYPENAME *output, + uint tid [[ thread_position_in_grid ]] +) { + if (tid >= dst_size) { + return; + } + const INDEX_TYPENAME input_i = input_ids[tid]; + const size_t right_rank_i = tid % right_size; + const size_t left_rank_i = tid / right_size / ids_size; + const size_t src_i = (left_rank_i * src_dim_size + input_i) * right_size + right_rank_i; + output[tid] = input[src_i]; +} -template <typename T, typename I> -void index_add( - device I *ids [[buffer(0)]], - device T *inp [[buffer(1)]], - device T *out [[buffer(2)]], - - constant uint &ids_dim_size, - constant uint &left_size, - constant uint &dst_dim_size, - constant uint &right_size, - - uint gid [[ thread_position_in_grid ]] \ -) { +# define GATHER_OP(NAME, INDEX_TYPENAME, TYPENAME) \ +kernel void NAME( \ + constant size_t &dst_size, \ + constant size_t &left_size, \ + constant size_t &src_dim_size, \ + constant size_t &right_size, \ + constant size_t &ids_size, \ + const device TYPENAME *input, \ + const device INDEX_TYPENAME *input_ids, \ + device TYPENAME *output, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + gather<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, ids_size, input, input_ids, output, tid); \ +} - if (gid >= left_size * right_size) { - return; +template<typename TYPENAME, typename INDEX_TYPENAME> +METAL_FUNC void scatter_add( + constant size_t &dst_size, + constant size_t &left_size, + constant size_t &src_dim_size, + constant size_t &right_size, + constant size_t &dst_dim_size, + const device TYPENAME *input, + const device INDEX_TYPENAME *input_ids, + device TYPENAME *output, + uint tid [[ thread_position_in_grid ]] +) { + if (tid >= dst_size) { + return; + } + const size_t right_rank_i = tid % right_size; + const size_t left_rank_i = tid / right_size; + for (unsigned int j = 0; j < src_dim_size; ++j) { + const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i; + const INDEX_TYPENAME idx = input_ids[src_i]; + const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i; + output[dst_i] += input[src_i]; } +} - const uint i = gid; - const uint pre = i / right_size; - const uint post = i % right_size; +# define SCATTER_ADD_OP(NAME, INDEX_TYPENAME, TYPENAME) \ +kernel void NAME( \ + constant size_t &dst_size, \ + constant size_t &left_size, \ + constant size_t &src_dim_size, \ + constant size_t &right_size, \ + constant size_t &dst_dim_size, \ + const device TYPENAME *input, \ + const device INDEX_TYPENAME *input_ids, \ + device TYPENAME *output, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + scatter_add<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, dst_dim_size, input, input_ids, output, tid); \ +} - for (uint j = 0; j < ids_dim_size; j++) { - const uint idx = ids[j]; - const uint src_i = (pre * ids_dim_size + j) * right_size + post; - const uint dst_i = (pre * dst_dim_size + idx) * right_size + post; - out[dst_i] += inp[src_i]; +template<typename TYPENAME, typename INDEX_TYPENAME> +METAL_FUNC void index_add( + constant size_t &dst_size, + constant size_t &left_size, + constant size_t &src_dim_size, + constant size_t &right_size, + constant size_t &dst_dim_size, + constant size_t &ids_dim_size, + const device TYPENAME *input, + const device INDEX_TYPENAME *input_ids, + device TYPENAME *output, + uint tid [[ thread_position_in_grid ]] +) { + if (tid >= dst_size) { + return; + } + const size_t right_rank_i = tid % right_size; + const size_t left_rank_i = tid / right_size; + for (unsigned int j = 0; j < ids_dim_size; ++j) { + const INDEX_TYPENAME idx = input_ids[j]; + const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i; + const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i; + output[dst_i] += input[src_i]; } } -#define IA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \ -kernel void FN_NAME( \ - device INDEX_TYPENAME *ids [[buffer(0)]], \ - device TYPENAME *inp [[buffer(1)]], \ - device TYPENAME *out [[buffer(2)]], \ - constant uint &ids_dim_size, \ - constant uint &left_size, \ - constant uint &dst_dim_size, \ - constant uint &right_size, \ - uint gid [[ thread_position_in_grid ]] \ -) { index_add<TYPENAME, INDEX_TYPENAME>(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, gid); } \ +# define INDEX_ADD_OP(NAME, INDEX_TYPENAME, TYPENAME) \ +kernel void NAME( \ + constant size_t &dst_size, \ + constant size_t &left_size, \ + constant size_t &src_dim_size, \ + constant size_t &right_size, \ + constant size_t &dst_dim_size, \ + constant size_t &ids_dim_size, \ + const device TYPENAME *input, \ + const device INDEX_TYPENAME *input_ids, \ + device TYPENAME *output, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + index_add<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, dst_dim_size, ids_dim_size, input, input_ids, output, tid); \ +} INDEX_OP(is_u32_f32, uint, float) +INDEX_OP(is_u32_f16, uint, half) +GATHER_OP(gather_u32_f32, uint, float) +GATHER_OP(gather_u32_f16, uint, half) +SCATTER_ADD_OP(sa_u32_f32, uint, float) +SCATTER_ADD_OP(sa_u32_f16, uint, half) #if __METAL_VERSION__ >= 310 -IA_OP(bfloat, int64_t, ia_i64_bf16) -IA_OP(bfloat, uint32_t, ia_u32_bf16) -IA_OP(bfloat, uint8_t, ia_u8_bf16) +INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat) +INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat) +INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat) #endif -IA_OP(half, uint32_t, ia_u32_f16) -IA_OP(half, uint8_t, ia_u8_f16) +INDEX_ADD_OP(ia_u32_f16, uint32_t, half) +INDEX_ADD_OP(ia_u8_f16, uint8_t, half) -IA_OP(float, int64_t, ia_i64_f32) -IA_OP(uint8_t, int64_t, ia_i64_u8) -IA_OP(int64_t, int64_t, ia_i64_i64) -IA_OP(uint32_t, int64_t, ia_i64_u32) +INDEX_ADD_OP(ia_i64_f32, int64_t, float) +INDEX_ADD_OP(ia_i64_u8, int64_t, uint8_t) +INDEX_ADD_OP(ia_i64_i64, int64_t, int64_t) +INDEX_ADD_OP(ia_i64_u32, int64_t, uint32_t) -IA_OP(float, uint32_t, ia_u32_f32) -IA_OP(uint8_t, uint32_t, ia_u32_u8) -IA_OP(int64_t, uint32_t, ia_u32_i64) -IA_OP(uint32_t, uint32_t, ia_u32_u32) +INDEX_ADD_OP(ia_u32_f32, uint32_t, float) +INDEX_ADD_OP(ia_u32_u8, uint32_t, uint8_t) +INDEX_ADD_OP(ia_u32_i64, uint32_t, int64_t) +INDEX_ADD_OP(ia_u32_u32, uint32_t, uint32_t) -IA_OP(float, uint8_t, ia_u8_f32) -IA_OP(uint8_t, uint8_t, ia_u8_u8) -IA_OP(uint32_t, uint8_t, ia_u8_u32) -IA_OP(int64_t, uint8_t, ia_u8_i64) +INDEX_ADD_OP(ia_u8_f32, uint8_t, float) +INDEX_ADD_OP(ia_u8_u8, uint8_t, uint8_t) +INDEX_ADD_OP(ia_u8_u32, uint8_t, uint32_t) +INDEX_ADD_OP(ia_u8_i64, uint8_t, int64_t) diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 5a6bd41b..0418c96c 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1,6 +1,6 @@ use metal::{ - Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineDescriptor, - ComputePipelineState, Device, Function, Library, MTLSize, + Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState, + Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger, }; use std::collections::HashMap; use std::ffi::c_void; @@ -13,7 +13,12 @@ const BINARY: &str = include_str!("binary.metal"); const TERNARY: &str = include_str!("ternary.metal"); const CAST: &str = include_str!("cast.metal"); const REDUCE: &str = include_str!("reduce.metal"); +const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib"); +/// Most kernels apply similarly across the tensors +/// This creates a strategy that uses the maximum amount of threads per threadgroup (capped at the +/// actual total buffer length). +/// Then kernels can just do their op on their single point in the buffer. fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) { let size = length as u64; let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size); @@ -35,6 +40,10 @@ fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTL fn set_param<P: EncoderParam>(encoder: &ComputeCommandEncoderRef, position: u64, data: P) { <P as EncoderParam>::set_param(encoder, position, data) } + +/// Helper functions to create the various objects on the compute command encoder +/// on a single line. +/// Prevents getting wrong some arguments number and mixing length and size in bytes. trait EncoderParam { fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self); } @@ -59,8 +68,8 @@ impl<T> EncoderParam for &[T] { fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { encoder.set_bytes( position, - (core::mem::size_of::<T>() * data.len()) as u64, - data.as_ptr() as *const T as *const c_void, + core::mem::size_of_val(data) as u64, + data.as_ptr() as *const c_void, ); } } @@ -105,54 +114,59 @@ pub enum Source { Ternary, Cast, Reduce, + Mfa, } macro_rules! ops{ ($($name:ident),+) => { pub mod contiguous { - #[derive(Clone, Copy)] - pub struct Kernel(pub(crate) &'static str); - impl std::fmt::Display for Kernel { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) - } - } + pub struct Kernel(pub &'static str); $( pub mod $name { use super::Kernel; - pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_float")); - pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_half")); - pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat")); + pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32")); + pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16")); + pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16")); } )+ + pub mod copy { + use super::Kernel; + pub const FLOAT: Kernel = Kernel("copy_f32"); + pub const HALF: Kernel = Kernel("copy_f16"); + pub const BFLOAT: Kernel = Kernel("copy_bf16"); + pub const U32: Kernel = Kernel("copy_u32"); + pub const U8: Kernel = Kernel("copy_u8"); + } } pub mod strided { - #[derive(Clone, Copy)] - pub struct Kernel(pub(crate) &'static str); - impl std::fmt::Display for Kernel { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) - } - } + pub struct Kernel(pub &'static str); $( pub mod $name { use super::Kernel; - pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_float_strided")); - pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_half_strided")); - pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat_strided")); + pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32_strided")); + pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16_strided")); + pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16_strided")); } )+ + pub mod copy { + use super::Kernel; + pub const FLOAT: Kernel = Kernel("copy_f32_strided"); + pub const HALF: Kernel = Kernel("copy_f16_strided"); + pub const BFLOAT: Kernel = Kernel("copy_bf16_strided"); + pub const U32: Kernel = Kernel("copy_u32_strided"); + pub const U8: Kernel = Kernel("copy_u8_strided"); + } } }; } pub mod unary { - ops!(cos, sin, exp, sqr, sqrt, neg, copy, log); + ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf, tanh); } pub mod binary { - ops!(add, sub, mul, div); + ops!(add, sub, mul, div, min, max, eq, ne, le, lt, ge, gt); } #[derive(thiserror::Error, Debug)] @@ -161,8 +175,18 @@ pub enum MetalKernelError { LockError(String), #[error("Error while loading library: {0}")] LoadLibraryError(String), - #[error("Error while loading function: {0}")] + #[error("Error while loading function: {0:?}")] LoadFunctionError(String), + #[error("Failed to create compute function")] + FailedToCreateComputeFunction, + #[error("Failed to create pipeline")] + FailedToCreatePipeline(String), + #[error("Invalid matmul arguments {lhs_stride:?} {rhs_stride:?} {mnk:?}")] + MatMulNonContiguous { + lhs_stride: Vec<usize>, + rhs_stride: Vec<usize>, + mnk: (usize, usize, usize), + }, } impl<T> From<std::sync::PoisonError<T>> for MetalKernelError { @@ -171,21 +195,25 @@ impl<T> From<std::sync::PoisonError<T>> for MetalKernelError { } } -type KernelMap<T> = HashMap<&'static str, T>; type Libraries = HashMap<Source, Library>; -type Functions = KernelMap<Function>; +type Pipelines = HashMap<(&'static str, Option<ConstantValues>), ComputePipelineState>; -#[derive(Debug, Default)] +#[derive(Debug)] pub struct Kernels { libraries: RwLock<Libraries>, - funcs: RwLock<Functions>, + pipelines: RwLock<Pipelines>, + fence: metal::Fence, } impl Kernels { - pub fn new() -> Self { + pub fn new(fence: metal::Fence) -> Self { let libraries = RwLock::new(Libraries::new()); - let funcs = RwLock::new(Functions::new()); - Self { libraries, funcs } + let pipelines = RwLock::new(Pipelines::new()); + Self { + libraries, + pipelines, + fence, + } } fn get_library_source(&self, source: Source) -> &'static str { @@ -197,9 +225,12 @@ impl Kernels { Source::Indexing => INDEXING, Source::Cast => CAST, Source::Reduce => REDUCE, + Source::Mfa => panic!("Invalid lib"), } } + /// Load the give library from its [`source`]. + /// If this has been previously loaded it will just fetch it from cache. pub fn load_library( &self, device: &Device, @@ -209,33 +240,83 @@ impl Kernels { if let Some(lib) = libraries.get(&source) { Ok(lib.clone()) } else { - let source_content = self.get_library_source(source); - let lib = device - .new_library_with_source(source_content, &CompileOptions::new()) - .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?; + let lib = match source { + Source::Mfa => { + let source_data = MFA; + device.new_library_with_data(source_data).map_err(|e| { + MetalKernelError::LoadLibraryError(format!( + "Candle metal requires macosx > 13.0 or higher, cannot load mfa: {e}" + )) + })? + } + source => { + let source_content = self.get_library_source(source); + device + .new_library_with_source(source_content, &CompileOptions::new()) + .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))? + } + }; libraries.insert(source, lib.clone()); Ok(lib) } } - pub fn load_function( + fn load_function( &self, device: &Device, source: Source, name: &'static str, + constants: Option<FunctionConstantValues>, ) -> Result<Function, MetalKernelError> { - let mut funcs = self.funcs.write()?; - if let Some(func) = funcs.get(name) { - Ok(func.clone()) + let func = self + .load_library(device, source)? + .get_function(name, constants) + .map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?; + Ok(func) + } + + /// Load the give pipeline + /// loads the library from source, then gets the function [`name`] from + /// that source + fn load_pipeline_with_constants( + &self, + device: &Device, + source: Source, + name: &'static str, + constants: Option<ConstantValues>, + ) -> Result<ComputePipelineState, MetalKernelError> { + let mut pipelines = self.pipelines.write()?; + let key = (name, constants); + if let Some(pipeline) = pipelines.get(&key) { + Ok(pipeline.clone()) } else { - let func = self - .load_library(device, source)? - .get_function(name, None) - .map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?; - funcs.insert(name, func.clone()); - Ok(func) + let (name, constants) = key; + let func = self.load_function( + device, + source, + name, + constants.as_ref().map(|c| c.function_constant_values()), + )?; + let pipeline = device + .new_compute_pipeline_state_with_function(&func) + .map_err(|e| MetalKernelError::FailedToCreatePipeline(e.to_string()))?; + pipelines.insert((name, constants), pipeline.clone()); + + Ok(pipeline) } } + + /// Load the give pipeline + /// loads the library from source, then gets the function [`name`] from + /// that source (without constants) + pub fn load_pipeline( + &self, + device: &Device, + source: Source, + name: &'static str, + ) -> Result<ComputePipelineState, MetalKernelError> { + self.load_pipeline_with_constants(device, source, name, None) + } } #[allow(clippy::too_many_arguments)] @@ -246,25 +327,20 @@ pub fn call_unary_contiguous( kernel_name: unary::contiguous::Kernel, length: usize, input: &Buffer, - output: &mut Buffer, + output: &Buffer, ) -> Result<(), MetalKernelError> { - let func = kernels.load_function(device, Source::Unary, kernel_name.0)?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); - - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); - + let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?; let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (length, input, output)); let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -279,21 +355,14 @@ pub fn call_unary_strided( input: &Buffer, strides: &[usize], offset: usize, - output: &mut Buffer, + output: &Buffer, output_offset: usize, ) -> Result<(), MetalKernelError> { - let func = kernels.load_function(device, Source::Unary, name.0)?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); - - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); + let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?; let num_dims: usize = shape.len(); let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); let length: usize = shape.iter().product(); @@ -312,7 +381,10 @@ pub fn call_unary_strided( let width: usize = shape.iter().product(); let (thread_group_count, thread_group_size) = linear_split(&pipeline, width); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -326,26 +398,23 @@ pub fn call_binary_contiguous( length: usize, left: &Buffer, right: &Buffer, - output: &mut Buffer, + output: &Buffer, ) -> Result<(), MetalKernelError> { - let func = kernels.load_function(device, Source::Binary, kernel_name.0)?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); - - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); + let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?; let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (length, left, right, output)); let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); + encoder.use_resource(left, metal::MTLResourceUsage::Read); + encoder.use_resource(right, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -363,21 +432,14 @@ pub fn call_binary_strided( right_input: &Buffer, right_strides: &[usize], right_offset: usize, - output: &mut Buffer, + output: &Buffer, ) -> Result<(), MetalKernelError> { - let func = kernels.load_function(device, Source::Binary, name.0)?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); - - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); + let pipeline = kernels.load_pipeline(device, Source::Binary, name.0)?; let num_dims: usize = shape.len(); let encoder = command_buffer.new_compute_command_encoder(); let width: usize = shape.iter().product(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); let length: usize = shape.iter().product(); @@ -398,7 +460,11 @@ pub fn call_binary_strided( let (thread_group_count, thread_group_size) = linear_split(&pipeline, width); + encoder.use_resource(left_input, metal::MTLResourceUsage::Read); + encoder.use_resource(right_input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -411,31 +477,68 @@ pub fn call_cast_contiguous( kernel_name: &'static str, length: usize, input: &Buffer, - output: &mut Buffer, + input_offset: usize, + output: &Buffer, ) -> Result<(), MetalKernelError> { - let func = kernels.load_function(device, Source::Cast, kernel_name)?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); + let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?; - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); + let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!(encoder, (length, (input, input_offset), output)); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); + encoder.end_encoding(); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_cast_strided( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + kernel_name: &'static str, + shape: &[usize], + input: &Buffer, + input_strides: &[usize], + input_offset: usize, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?; let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); - set_params!(encoder, (length, input, output)); + let length: usize = shape.iter().product(); + + set_params!( + encoder, + ( + length, + shape.len(), + shape, + input_strides, + (input, input_offset), + output + ) + ); let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } -#[allow(clippy::too_many_arguments)] pub fn call_reduce_contiguous( device: &Device, command_buffer: &CommandBufferRef, @@ -444,24 +547,78 @@ pub fn call_reduce_contiguous( length: usize, out_length: usize, input: &Buffer, - output: &mut Buffer, + input_offset: usize, + output: &Buffer, ) -> Result<(), MetalKernelError> { - let func = kernels.load_function(device, Source::Reduce, kernel_name)?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); + let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; + let elements_to_sum = length / out_length; - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); + let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!( + encoder, + (length, elements_to_sum, (input, input_offset), output) + ); + + let thread_group_count = MTLSize { + width: out_length as u64, + height: 1, + depth: 1, + }; + + let width = std::cmp::min( + pipeline.max_total_threads_per_threadgroup(), + (elements_to_sum as u64 + 2 - 1) / 2, + ) + .next_power_of_two(); + let thread_group_size = MTLSize { + width, + height: 1, + depth: 1, + }; + + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); + encoder.end_encoding(); + Ok(()) +} + +pub fn call_reduce_strided( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + kernel_name: &'static str, + shape: &[usize], + strides: &[usize], + out_length: usize, + input: &Buffer, + input_offset: usize, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let length: usize = shape.iter().product(); + let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; let elements_to_sum = length / out_length; let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); - set_params!(encoder, (length, elements_to_sum, input, output)); + set_params!( + encoder, + ( + shape.len(), + shape, + strides, + elements_to_sum, + (input, input_offset), + output + ) + ); let thread_group_count = MTLSize { width: out_length as u64, @@ -471,7 +628,7 @@ pub fn call_reduce_contiguous( let width = std::cmp::min( pipeline.max_total_threads_per_threadgroup(), - (elements_to_sum as u64 + 2 - 1) / 2, + elements_to_sum as u64, ) .next_power_of_two(); @@ -481,7 +638,10 @@ pub fn call_reduce_contiguous( depth: 1, }; + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -495,22 +655,18 @@ pub fn call_last_softmax( length: usize, elements_to_sum: usize, input: &Buffer, - output: &mut Buffer, + input_offset: usize, + output: &Buffer, ) -> Result<(), MetalKernelError> { - let func = kernels.load_function(device, Source::Reduce, kernel_name)?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); - - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); - + let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); - set_params!(encoder, (length, elements_to_sum, input, output)); + set_params!( + encoder, + (length, elements_to_sum, (input, input_offset), output) + ); let out_length = length / elements_to_sum; @@ -532,7 +688,10 @@ pub fn call_last_softmax( depth: 1, }; + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -542,34 +701,214 @@ pub fn call_affine( device: &Device, command_buffer: &CommandBufferRef, kernels: &Kernels, + name: &'static str, size: usize, input: &Buffer, - output: &mut Buffer, + output: &Buffer, + mul: f32, + add: f32, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; + + let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!(encoder, (size, mul, add, input, output)); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); + encoder.end_encoding(); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_affine_strided( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + shape: &[usize], + input: &Buffer, + input_stride: &[usize], + input_offset: usize, + output: &Buffer, mul: f32, add: f32, ) -> Result<(), MetalKernelError> { - let func = kernels.load_function(device, Source::Affine, "affine_float")?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); + let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; + let size: usize = shape.iter().product(); + + let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); + encoder.set_compute_pipeline_state(&pipeline); - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), + set_params!( + encoder, + ( + size, + shape.len(), + shape, + input_stride, + mul, + add, + (input, input_offset), + output ) - .unwrap(); + ); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); + encoder.end_encoding(); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_powf( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + size: usize, + input: &Buffer, + output: &Buffer, + mul: f32, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); - set_params!(encoder, (size, mul, add, input, output)); + set_params!(encoder, (size, mul, input, output)); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); + encoder.end_encoding(); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_powf_strided( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + shape: &[usize], + input: &Buffer, + input_stride: &[usize], + input_offset: usize, + output: &Buffer, + mul: f32, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; + let size: usize = shape.iter().product(); + + let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!( + encoder, + ( + size, + shape.len(), + shape, + input_stride, + mul, + (input, input_offset), + output + ) + ); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); + encoder.end_encoding(); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_elu( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + size: usize, + input: &Buffer, + output: &Buffer, + mul: f32, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; + + let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!(encoder, (size, mul, input, output)); let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } #[allow(clippy::too_many_arguments)] +pub fn call_elu_strided( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + shape: &[usize], + input: &Buffer, + input_stride: &[usize], + input_offset: usize, + output: &Buffer, + mul: f32, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; + let size: usize = shape.iter().product(); + + let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!( + encoder, + ( + size, + shape.len(), + shape, + input_stride, + mul, + (input, input_offset), + output + ) + ); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); + encoder.end_encoding(); + Ok(()) +} + pub fn call_where_cond_strided( device: &Device, command_buffer: &CommandBufferRef, @@ -582,19 +921,12 @@ pub fn call_where_cond_strided( (left_stride, left_offset): (&[usize], usize), right: &Buffer, (right_stride, right_offset): (&[usize], usize), - output: &mut Buffer, + output: &Buffer, ) -> Result<(), MetalKernelError> { - let func = kernels.load_function(device, Source::Ternary, name)?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); - - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); + let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?; let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); let size: usize = shape.iter().product(); @@ -618,7 +950,12 @@ pub fn call_where_cond_strided( let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); + encoder.use_resource(cond, metal::MTLResourceUsage::Read); + encoder.use_resource(left, metal::MTLResourceUsage::Read); + encoder.use_resource(right, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -634,20 +971,18 @@ pub fn call_index_select( dim: usize, input: &Buffer, ids: &Buffer, - output: &mut Buffer, + output: &Buffer, ) -> Result<(), MetalKernelError> { let left_size: usize = shape[..dim].iter().product(); let right_size: usize = shape[dim + 1..].iter().product(); let src_dim_size = shape[dim]; let dst_el = ids_size * left_size * right_size; - let func = kernels.load_function(device, Source::Indexing, name)?; - let pipeline = device - .new_compute_pipeline_state_with_function(&func) - .unwrap(); + let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?; let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -666,10 +1001,426 @@ pub fn call_index_select( let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(ids, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } +#[allow(clippy::too_many_arguments)] +pub fn call_gather( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + shape: &[usize], + ids_size: usize, + dim: usize, + input: &Buffer, + input_offset: usize, + ids: &Buffer, + ids_offset: usize, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let left_size: usize = shape[..dim].iter().product(); + let right_size: usize = shape[dim + 1..].iter().product(); + let src_dim_size = shape[dim]; + let dst_el = ids_size * left_size * right_size; + + let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?; + + let encoder = command_buffer.new_compute_command_encoder(); + + encoder.wait_for_fence(&kernels.fence); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!( + encoder, + ( + dst_el, + left_size, + src_dim_size, + right_size, + ids_size, + (input, input_offset), + (ids, ids_offset), + output + ) + ); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); + + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(ids, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); + encoder.end_encoding(); + Ok(()) +} + +pub fn call_scatter_add( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + src_shape: &[usize], + dst_shape: &[usize], + dim: usize, + input: &Buffer, + input_offset: usize, + ids: &Buffer, + ids_offset: usize, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let left_size: usize = src_shape[..dim].iter().product(); + let right_size: usize = src_shape[dim + 1..].iter().product(); + let src_dim_size = src_shape[dim]; + let dst_el = left_size * right_size; + let dst_dim_size = dst_shape[dim]; + + let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?; + + let encoder = command_buffer.new_compute_command_encoder(); + + encoder.wait_for_fence(&kernels.fence); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!( + encoder, + ( + dst_el, + left_size, + src_dim_size, + right_size, + dst_dim_size, + (input, input_offset), + (ids, ids_offset), + output + ) + ); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); + + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(ids, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); + encoder.end_encoding(); + Ok(()) +} + +pub fn call_index_add( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + src_shape: &[usize], + dst_shape: &[usize], + ids_shape: &[usize], + dim: usize, + input: &Buffer, + input_offset: usize, + ids: &Buffer, + ids_offset: usize, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let left_size: usize = src_shape[..dim].iter().product(); + let right_size: usize = src_shape[dim + 1..].iter().product(); + let src_dim_size = src_shape[dim]; + let dst_el = left_size * right_size; + let dst_dim_size = dst_shape[dim]; + let ids_dim_size = ids_shape[0]; + + let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?; + let encoder = command_buffer.new_compute_command_encoder(); + + encoder.wait_for_fence(&kernels.fence); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!( + encoder, + ( + dst_el, + left_size, + src_dim_size, + right_size, + dst_dim_size, + ids_dim_size, + (input, input_offset), + (ids, ids_offset), + output + ) + ); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); + + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(ids, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); + encoder.end_encoding(); + Ok(()) +} + +#[derive(Debug, PartialEq)] +pub enum Value { + USize(usize), + Bool(bool), + F32(f32), + U16(u16), +} + +impl std::hash::Hash for Value { + fn hash<H: std::hash::Hasher>(&self, state: &mut H) { + match self { + Value::F32(v) => v.to_bits().hash(state), + Value::USize(v) => v.hash(state), + Value::U16(v) => v.hash(state), + Value::Bool(v) => v.hash(state), + } + } +} + +impl Value { + fn data_type(&self) -> MTLDataType { + match self { + Value::USize(_) => MTLDataType::UInt, + Value::F32(_) => MTLDataType::Float, + Value::U16(_) => MTLDataType::UShort, + Value::Bool(_) => MTLDataType::Bool, + } + } +} + +/// Not true, good enough for our purposes. +impl Eq for Value {} + +#[derive(Debug, Eq, PartialEq, Hash)] +struct ConstantValues(Vec<(usize, Value)>); + +impl ConstantValues { + pub fn new(values: Vec<(usize, Value)>) -> Self { + Self(values) + } + + fn function_constant_values(&self) -> FunctionConstantValues { + let f = FunctionConstantValues::new(); + for (index, value) in &self.0 { + let ty = value.data_type(); + match value { + Value::USize(v) => { + f.set_constant_value_at_index( + v as *const usize as *const c_void, + ty, + *index as u64, + ); + } + Value::F32(v) => { + f.set_constant_value_at_index( + v as *const f32 as *const c_void, + ty, + *index as u64, + ); + } + Value::U16(v) => { + f.set_constant_value_at_index( + v as *const u16 as *const c_void, + ty, + *index as u64, + ); + } + Value::Bool(v) => { + f.set_constant_value_at_index( + v as *const bool as *const c_void, + ty, + *index as u64, + ); + } + } + } + f + } +} + +#[allow(clippy::too_many_arguments)] +pub fn call_gemm( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + (b, m, n, k): (usize, usize, usize, usize), + lhs_stride: &[usize], + lhs_offset: usize, + lhs_buffer: &Buffer, + rhs_stride: &[usize], + rhs_offset: usize, + rhs_buffer: &Buffer, + output: &Buffer, +) -> Result<(), MetalKernelError> { + assert!(rhs_stride.len() >= 2); + assert!(lhs_stride.len() >= 2); + let rhs_m1 = rhs_stride[rhs_stride.len() - 1]; + let rhs_m2 = rhs_stride[rhs_stride.len() - 2]; + let lhs_m1 = lhs_stride[lhs_stride.len() - 1]; + let lhs_m2 = lhs_stride[lhs_stride.len() - 2]; + let a_trans = if lhs_m1 == 1 && lhs_m2 == k { + false + } else if lhs_m1 == m && lhs_m2 == 1 { + true + } else { + return Err(MetalKernelError::MatMulNonContiguous { + lhs_stride: lhs_stride.to_vec(), + rhs_stride: rhs_stride.to_vec(), + mnk: (m, n, k), + })?; + }; + let b_trans = if rhs_m1 == 1 && rhs_m2 == n { + false + } else if rhs_m1 == k && rhs_m2 == 1 { + true + } else { + return Err(MetalKernelError::MatMulNonContiguous { + lhs_stride: lhs_stride.to_vec(), + rhs_stride: rhs_stride.to_vec(), + mnk: (m, n, k), + })?; + }; + let d_trans = false; + let alpha = 1.0f32; + let beta = 0.0f32; + let batched = b > 1; + let fused_activation = false; + let fused_bias = false; + let (m_simd, n_simd, k_simd, m_splits, n_splits) = if m == 1 { + let m_simd = 16; + let n_simd = 8; + let k_simd = 64; + let m_splits = 1; + let n_splits = 1; + (m_simd, n_simd, k_simd, m_splits, n_splits) + } else { + let m_simd = 40; + let n_simd = 40; + let k_simd = 8; + let m_splits = 1; + let n_splits = 1; + (m_simd, n_simd, k_simd, m_splits, n_splits) + }; + let constants = Some(ConstantValues::new(vec![ + (0, Value::USize(m)), + (1, Value::USize(n)), + (2, Value::USize(k)), + (10, Value::Bool(a_trans)), + (11, Value::Bool(b_trans)), + (13, Value::Bool(d_trans)), + (20, Value::F32(alpha)), + (21, Value::F32(beta)), + (100, Value::Bool(batched)), + (101, Value::Bool(fused_activation)), + // Garbage + (102, Value::Bool(false)), + (103, Value::Bool(false)), + (113, Value::Bool(false)), + (50_000, Value::Bool(false)), + // End garbage + (200, Value::U16(m_simd)), + (201, Value::U16(n_simd)), + (202, Value::U16(k_simd)), + (210, Value::U16(m_splits)), + (211, Value::U16(n_splits)), + (50_001, Value::Bool(fused_bias)), + ])); + let pipeline = kernels.load_pipeline_with_constants(device, Source::Mfa, name, constants)?; + let m_group = m_simd * m_splits; + let n_group = n_simd * n_splits; + + let a_block_length = m_group * k_simd; + let b_block_length = k_simd * n_group; + + let mut block_elements = a_block_length + b_block_length; + if (m % 8 != 0) && (n % 8 != 0) { + let c_block_length = m_group * n_group; + block_elements = std::cmp::max(c_block_length, block_elements) + } + if fused_bias { + if d_trans { + block_elements = std::cmp::max(block_elements, m_group); + } else { + block_elements = std::cmp::max(block_elements, n_group); + } + } + let bytes = match name { + "sgemm" => 4, + "hgemm" => 2, + other => { + return Err(MetalKernelError::LoadLibraryError(format!( + "{other} is not a valid kernel for gemm" + ))); + } + }; + let block_bytes = block_elements * bytes; + + let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); + encoder.set_compute_pipeline_state(&pipeline); + encoder.set_threadgroup_memory_length(0, block_bytes.into()); + encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger); + encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger); + encoder.set_buffer(2, Some(output), 0); + // TODO Tensor D + + let grid_z = b; + if batched { + let byte_stride_a: usize = lhs_stride[lhs_stride.len() - 3] * bytes as usize; + let byte_stride_b: usize = rhs_stride[rhs_stride.len() - 3] * bytes as usize; + let byte_stride_c = m * n * bytes as usize; + // TODO byte_stride_d + let byte_stride_d = 0; + + let mut buffer: Vec<u64> = Vec::with_capacity(b * 4); + for i in 0..b { + buffer.push((i * byte_stride_a) as u64); + buffer.push((i * byte_stride_b) as u64); + buffer.push((i * byte_stride_c) as u64); + buffer.push((i * byte_stride_d) as u64); + } + encoder.set_bytes( + 10, + (buffer.len() * core::mem::size_of::<u64>()) as NSUInteger, + buffer.as_ptr() as *const NSUInteger as *const c_void, + ); + } + + let grid_size = MTLSize { + width: divide(n, n_group.into()), + height: divide(m, m_group.into()), + depth: grid_z as NSUInteger, + }; + let group_size = MTLSize { + width: 32 * (m_splits as u64) * (n_splits as u64), + height: 1, + depth: 1, + }; + // println!("grid size {grid_size:?} group size {group_size:?}"); + encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(grid_size, group_size); + encoder.update_fence(&kernels.fence); + encoder.end_encoding(); + + Ok(()) +} + +fn divide(m: usize, b: usize) -> NSUInteger { + ((m + b - 1) / b) as NSUInteger +} + #[cfg(test)] mod tests; diff --git a/candle-metal-kernels/src/libMetalFlashAttention.metallib b/candle-metal-kernels/src/libMetalFlashAttention.metallib Binary files differnew file mode 100644 index 00000000..f5116ca6 --- /dev/null +++ b/candle-metal-kernels/src/libMetalFlashAttention.metallib diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal index c6984474..2d584917 100644 --- a/candle-metal-kernels/src/reduce.metal +++ b/candle-metal-kernels/src/reduce.metal @@ -1,6 +1,9 @@ #include <metal_stdlib> using namespace metal; +#define MAX(x, y) ((x) > (y) ? (x) : (y)) +#define MIN(x, y) ((x) < (y) ? (x) : (y)) + METAL_FUNC uint get_strided_index( uint idx, constant size_t &num_dims, @@ -16,39 +19,160 @@ METAL_FUNC uint get_strided_index( return strided_i; } -constant int THREADGROUP_SIZE = 256; +constant int THREADGROUP_SIZE = 2048; + + +#define ARGMIN(NAME, T, MAXVALUE) \ +kernel void NAME( \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *strides, \ + constant size_t &el_to_sum_per_block, \ + device const T *src, \ + device uint *dst, \ + uint id [[ thread_position_in_grid ]], \ + uint tid [[ thread_index_in_threadgroup ]], \ + uint dst_id [[ threadgroup_position_in_grid ]], \ + uint block_dim [[ threads_per_threadgroup ]] \ +) { \ + \ + threadgroup T shared_memory[THREADGROUP_SIZE]; \ + threadgroup uint shared_indices[THREADGROUP_SIZE]; \ + \ + shared_memory[tid] = MAXVALUE; \ + shared_indices[tid] = 0xFFFFFFFF; \ + bool notset = true; \ + /* \ + // Elements summed in this block range from dst_id * el_to_sum_per_block \ + // to (dst_id + 1) * el_to_sum_per_block. \ + */ \ + size_t start_idx = dst_id * el_to_sum_per_block; \ + size_t stop_idx = start_idx + el_to_sum_per_block; \ + size_t idx = start_idx + tid; \ + while (idx < stop_idx) { \ + /* \ + // TODO: Fast version for the contiguous case. \ + */ \ + size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \ + if (notset || src[strided_i] < shared_memory[tid]) { \ + shared_memory[tid] = src[strided_i]; \ + /* Assume that the reduction takes place over the last dimension which is contiguous. */ \ + shared_indices[tid] = idx % dims[num_dims - 1]; \ + notset = false; \ + } \ + idx += block_dim; \ + } \ + \ + threadgroup_barrier(mem_flags::mem_none); \ + \ + /* \ + // reduction in shared memory \ + */ \ + for (uint s = block_dim / 2; s > 0; s >>= 1) { \ + if (tid < s && shared_memory[tid + s] < shared_memory[tid]) { \ + shared_indices[tid] = shared_indices[tid + s]; \ + shared_memory[tid] = shared_memory[tid + s]; \ + } \ + threadgroup_barrier(mem_flags::mem_none); \ + } \ + \ + if (tid == 0){ \ + dst[dst_id] = shared_indices[0]; \ + } \ +} \ + -# define REDUCE(FN, NAME, TYPENAME) \ +#define ARGMAX(NAME, T, MINVALUE) \ kernel void NAME( \ - constant size_t &src_numel, \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *strides, \ constant size_t &el_to_sum_per_block, \ - device const TYPENAME *src, \ - device TYPENAME *dst, \ + device const T *src, \ + device uint *dst, \ + uint id [[ thread_position_in_grid ]], \ + uint tid [[ thread_index_in_threadgroup ]], \ + uint dst_id [[ threadgroup_position_in_grid ]], \ + uint block_dim [[ threads_per_threadgroup ]] \ +) { \ + \ + threadgroup T shared_memory[THREADGROUP_SIZE]; \ + threadgroup uint shared_indices[THREADGROUP_SIZE]; \ + \ + shared_memory[tid] = MINVALUE; \ + shared_indices[tid] = 0xFFFFFFFF; \ + /* \ + // Elements summed in this block range from dst_id * el_to_sum_per_block \ + // to (dst_id + 1) * el_to_sum_per_block. \ + */ \ + size_t start_idx = dst_id * el_to_sum_per_block; \ + size_t stop_idx = start_idx + el_to_sum_per_block; \ + size_t idx = start_idx + tid; \ + bool notset = true; \ + while (idx < stop_idx) { \ + /* \ + // TODO: Fast version for the contiguous case. \ + */ \ + size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \ + if (notset || shared_memory[tid] < src[strided_i]) { \ + shared_memory[tid] = src[strided_i]; \ + shared_indices[tid] = idx % dims[num_dims - 1]; \ + notset = false; \ + } \ + idx += block_dim; \ + } \ + \ + threadgroup_barrier(mem_flags::mem_none); \ + \ + /* \ + // reduction in shared memory \ + */ \ + for (uint s = block_dim / 2; s > 0; s >>= 1) { \ + if (tid < s && shared_memory[tid + s] > shared_memory[tid]) { \ + shared_indices[tid] = shared_indices[tid + s]; \ + shared_memory[tid] = shared_memory[tid + s]; \ + } \ + threadgroup_barrier(mem_flags::mem_none); \ + } \ + \ + if (tid == 0){ \ + dst[dst_id] = shared_indices[0]; \ + } \ +} \ + +#define REDUCE(FN, NAME, T, START) \ +kernel void NAME( \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *strides, \ + constant size_t &el_to_sum_per_block, \ + device const T *src, \ + device T *dst, \ uint id [[ thread_position_in_grid ]], \ uint tid [[ thread_index_in_threadgroup ]], \ uint dst_id [[ threadgroup_position_in_grid ]], \ - uint blockDim [[ threads_per_threadgroup ]] \ + uint block_dim [[ threads_per_threadgroup ]] \ ) { \ \ - threadgroup float shared_memory[THREADGROUP_SIZE]; \ + threadgroup T shared_memory[THREADGROUP_SIZE]; \ \ - shared_memory[tid] = 0; \ + shared_memory[tid] = START; \ /* \ // Elements summed in this block range from dst_id * el_to_sum_per_block \ // to (dst_id + 1) * el_to_sum_per_block. \ */ \ size_t start_idx = dst_id * el_to_sum_per_block; \ - size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \ + size_t stop_idx = start_idx + el_to_sum_per_block; \ size_t idx = start_idx + tid; \ while (idx < stop_idx) { \ /* \ // TODO: Fast version for the contiguous case. \ - // size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \ */ \ - TYPENAME x = shared_memory[tid]; \ - TYPENAME y = src[idx]; \ + size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \ + T x = shared_memory[tid]; \ + T y = src[strided_i]; \ shared_memory[tid] = FN; \ - idx += blockDim; \ + idx += block_dim; \ } \ \ threadgroup_barrier(mem_flags::mem_none); \ @@ -56,10 +180,10 @@ kernel void NAME( \ /* \ // reduction in shared memory \ */ \ - for (uint s = blockDim / 2; s > 0; s >>= 1) { \ + for (uint s = block_dim / 2; s > 0; s >>= 1) { \ if (tid < s) { \ - TYPENAME x = shared_memory[tid]; \ - TYPENAME y = shared_memory[tid + s]; \ + T x = shared_memory[tid]; \ + T y = shared_memory[tid + s]; \ shared_memory[tid] = FN; \ } \ threadgroup_barrier(mem_flags::mem_none); \ @@ -68,72 +192,101 @@ kernel void NAME( \ dst[dst_id] = shared_memory[0]; \ } \ -kernel void softmax_float( - constant size_t &src_numel, - constant size_t &el_to_sum_per_block, - device const float *src, - device float *dst, - uint id [[ thread_position_in_grid ]], - uint tid [[ thread_index_in_threadgroup ]], - uint dst_id [[ threadgroup_position_in_grid ]], - uint blockDim [[ threads_per_threadgroup ]] -) { - - threadgroup float shared_memory[THREADGROUP_SIZE]; - - shared_memory[tid] = -INFINITY; - // Elements summed in this block range from dst_id * el_to_sum_per_block - // to (dst_id + 1) * el_to_sum_per_block. - size_t start_idx = dst_id * el_to_sum_per_block; - size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); - size_t idx = start_idx + tid; - - while (idx < stop_idx) { - // TODO: Fast version for the contiguous case. - shared_memory[tid] = max(shared_memory[tid], src[idx]); - idx += blockDim; - } - - threadgroup_barrier(mem_flags::mem_none); - - // reduction in shared memory - for (uint s = blockDim / 2; s > 0; s >>= 1) { - if (tid < s) { - shared_memory[tid] = max(shared_memory[tid], shared_memory[tid + s]); - } - threadgroup_barrier(mem_flags::mem_none); - } - - float max = shared_memory[0]; - - shared_memory[tid] = 0; - // Restart - idx = start_idx + tid; - while (idx < stop_idx) { - // TODO: Fast version for the contiguous case. - const float val = exp(src[idx] - max); - dst[idx] = val; - shared_memory[tid] += val; - idx += blockDim; - } - // reduction in shared memory - for (uint s = blockDim / 2; s > 0; s >>= 1) { - if (tid < s) { - shared_memory[tid] += shared_memory[tid + s]; - } - threadgroup_barrier(mem_flags::mem_none); - } - - const float inv_acc = 1/shared_memory[0]; - idx = start_idx + tid; - while (idx < stop_idx) { - dst[idx] *= inv_acc; - idx += blockDim; - } -} +#define SOFTMAX(NAME, T) \ +kernel void NAME( \ + constant size_t &src_numel, \ + constant size_t &el_to_sum_per_block, \ + device const T *src, \ + device T *dst, \ + \ + uint id [[ thread_position_in_grid ]], \ + uint tid [[ thread_index_in_threadgroup ]], \ + uint dst_id [[ threadgroup_position_in_grid ]], \ + uint block_dim [[ threads_per_threadgroup ]] \ +) { \ + threadgroup float shared_memory[THREADGROUP_SIZE]; \ + shared_memory[tid] = -INFINITY; \ + size_t start_idx = dst_id * el_to_sum_per_block; \ + size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \ + size_t idx = start_idx + tid; \ + \ + \ + float tmp = -INFINITY; \ + while (idx < stop_idx) { \ + tmp = MAX(tmp, float(src[idx])); \ + idx += block_dim; \ + } \ + shared_memory[tid] = tmp; \ + \ + threadgroup_barrier(mem_flags::mem_threadgroup); \ + \ + for (uint s = block_dim / 2; s > 0; s >>= 1) { \ + if (tid < s) { \ + shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]); \ + } \ + threadgroup_barrier(mem_flags::mem_threadgroup); \ + } \ + \ + /* wait for shared_memory[0] to be filled */ \ + threadgroup_barrier(mem_flags::mem_threadgroup); \ + \ + float _max = shared_memory[0]; \ + \ + /* prevent tid=0 from overwriting _max before other threads have written it */ \ + threadgroup_barrier(mem_flags::mem_threadgroup); \ + shared_memory[tid] = 0; \ + \ + idx = start_idx + tid; \ + while (idx < stop_idx) { \ + const float val = exp(float(src[idx]) - _max); \ + dst[idx] = T(val); \ + shared_memory[tid] += val; \ + idx += block_dim; \ + } \ + threadgroup_barrier(mem_flags::mem_threadgroup); \ + for (uint s = block_dim / 2; s > 0; s >>= 1) { \ + if (tid < s) { \ + shared_memory[tid] += shared_memory[tid + s]; \ + } \ + threadgroup_barrier(mem_flags::mem_threadgroup); \ + } \ + \ + const T inv_acc = T(1.0/shared_memory[0]); \ + idx = start_idx + tid; \ + while (idx < stop_idx) { \ + dst[idx] *= inv_acc; \ + idx += block_dim; \ + } \ +} \ +REDUCE(x + y, fast_sum_f32_strided, float, 0) +REDUCE(x + y, fast_sum_u32_strided, uint, 0) +REDUCE(x + y, fast_sum_f16_strided, half, 0) +REDUCE(x * y, fast_mul_f32_strided, float, 1) +REDUCE(x * y, fast_mul_u32_strided, uint, 1) +REDUCE(x * y, fast_mul_f16_strided, half, 1) +REDUCE(MAX(x, y), fast_max_f32_strided, float, -HUGE_VALF) +REDUCE(MAX(x, y), fast_max_u32_strided, uint, 0) +REDUCE(MAX(x, y), fast_max_f16_strided, half, -HUGE_VALH) +REDUCE(MIN(x, y), fast_min_f32_strided, float, HUGE_VALF) +REDUCE(MIN(x, y), fast_min_u32_strided, uint, 0xFFFFFFFF) +REDUCE(MIN(x, y), fast_min_f16_strided, half, HUGE_VALH) +ARGMIN(fast_argmin_f32_strided, float, HUGE_VALF) +ARGMIN(fast_argmin_f16_strided, half, HUGE_VALH) +ARGMIN(fast_argmin_u32_strided, uint, 0xFFFFFFFF) +ARGMAX(fast_argmax_f32_strided, float, -HUGE_VALF) +ARGMAX(fast_argmax_f16_strided, half, -HUGE_VALH) +ARGMAX(fast_argmax_u32_strided, uint, 0) -REDUCE(x + y, fast_sum_float, float) -REDUCE(x * y, fast_mul_float, float) -REDUCE(max(x, y), fast_max_float, float) +SOFTMAX(softmax_f32, float) +SOFTMAX(softmax_f16, half) +#if __METAL_VERSION__ >= 310 +REDUCE(x + y, fast_sum_bf16, bfloat, 0) +REDUCE(x * y, fast_mul_bf16, bfloat, 1) +REDUCE(MAX(x, y), fast_max_bf16, bfloat, -HUGE_VALBF) +REDUCE(MIN(x, y), fast_min_bf16, bfloat, HUGE_VALBF) +ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF) +ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF) +SOFTMAX(softmax_bf16, bfloat) +#endif diff --git a/candle-metal-kernels/src/ternary.metal b/candle-metal-kernels/src/ternary.metal index 0945b355..1f9cb38a 100644 --- a/candle-metal-kernels/src/ternary.metal +++ b/candle-metal-kernels/src/ternary.metal @@ -32,6 +32,9 @@ kernel void FN_NAME( \ device TYPENAME *out ,\ uint i [[ thread_position_in_grid ]] \ ) { \ + if (i >= numel){ \ + return; \ + } \ uint strided_i = get_strided_index(i, num_dims, dims, strides); \ uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t); \ uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f); \ diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 2330d48d..1b3153b1 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -1,7 +1,14 @@ use super::*; -use half::f16; +use half::{bf16, f16}; use metal::{CompileOptions, Device, MTLResourceOptions, MTLSize, NSUInteger}; +fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> { + let ptr = buffer.contents() as *const T; + assert!(!ptr.is_null()); + let slice = unsafe { std::slice::from_raw_parts(ptr, n) }; + slice.to_vec() +} + fn new_buffer<T>(device: &Device, data: &[T]) -> Buffer { let options = MTLResourceOptions::StorageModeManaged; let ptr = data.as_ptr() as *const core::ffi::c_void; @@ -23,13 +30,19 @@ fn approx_f16(v: Vec<f16>, digits: i32) -> Vec<f32> { v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect() } +fn approx_bf16(v: Vec<bf16>, digits: i32) -> Vec<f32> { + let b = 10f32.powi(digits); + v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect() +} + fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> { let device = device(); - let kernels = Kernels::new(); + let fence = device.new_fence(); + let kernels = Kernels::new(fence); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); - let mut output = new_buffer(&device, v); + let output = new_buffer(&device, v); call_unary_contiguous( &device, command_buffer, @@ -37,23 +50,24 @@ fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> { name, v.len(), &input, - &mut output, + &output, ) .unwrap(); command_buffer.commit(); command_buffer.wait_until_completed(); - output.read_to_vec::<T>(v.len()) + read_to_vec(&output, v.len()) } fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> Vec<T> { let device = device(); - let kernels = Kernels::new(); + let fence = device.new_fence(); + let kernels = Kernels::new(fence); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let options = MTLResourceOptions::StorageModeManaged; let left = new_buffer(&device, x); let right = new_buffer(&device, y); - let mut output = device.new_buffer(std::mem::size_of_val(x) as u64, options); + let output = device.new_buffer(std::mem::size_of_val(x) as u64, options); call_binary_contiguous( &device, command_buffer, @@ -62,12 +76,12 @@ fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> V x.len(), &left, &right, - &mut output, + &output, ) .unwrap(); command_buffer.commit(); command_buffer.wait_until_completed(); - output.read_to_vec::<T>(x.len()) + read_to_vec(&output, x.len()) } fn run_strided<T: Clone>( @@ -81,8 +95,9 @@ fn run_strided<T: Clone>( let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); - let mut output = new_buffer(&device, v); - let kernels = Kernels::new(); + let output = new_buffer(&device, v); + let fence = device.new_fence(); + let kernels = Kernels::new(fence); call_unary_strided( &device, command_buffer, @@ -92,13 +107,13 @@ fn run_strided<T: Clone>( &input, strides, offset, - &mut output, + &output, 0, ) .unwrap(); command_buffer.commit(); command_buffer.wait_until_completed(); - output.read_to_vec::<T>(v.len()) + read_to_vec(&output, v.len()) } #[test] @@ -201,6 +216,25 @@ fn cos_strided_random() { } #[test] +fn gelu_f16() { + let v: Vec<f16> = [-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0] + .iter() + .map(|v| f16::from_f32(*v)) + .collect(); + let expected: Vec<f32> = vec![-0.0, -0.16, 0.0, 0.84, 1.96, 3.0, 10.0, 20.0]; + let results = run(&v, unary::contiguous::gelu::HALF); + assert_eq!(approx_f16(results, 2), expected); +} + +#[test] +fn gelu_f32() { + let v: Vec<f32> = vec![-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0]; + let expected: Vec<f32> = vec![-0.0, -0.159, 0.0, 0.841, 1.955, 2.996, 10.0, 20.0]; + let results = run(&v, unary::contiguous::gelu::FLOAT); + assert_eq!(approx(results, 3), expected); +} + +#[test] fn binary_add_f32() { let left = vec![1.0f32, 2.0, 3.0]; let right = vec![2.0f32, 3.1, 4.2]; @@ -216,11 +250,14 @@ fn binary_add_f32() { fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> { let device = device(); - let kernels = Kernels::new(); + let fence = device.new_fence(); + let kernels = Kernels::new(fence); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); - let mut output = new_buffer(&device, v); + let options = MTLResourceOptions::StorageModeManaged; + let size = (v.len() * std::mem::size_of::<U>()) as u64; + let output = device.new_buffer(size, options); call_cast_contiguous( &device, @@ -229,12 +266,13 @@ fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> { name, v.len(), &input, - &mut output, + 0, + &output, ) .unwrap(); command_buffer.commit(); command_buffer.wait_until_completed(); - output.read_to_vec::<U>(v.len()) + read_to_vec(&output, v.len()) } #[test] @@ -245,21 +283,28 @@ fn cast_u32_f32() { assert_eq!(approx(results, 4), vec![1.0f32, 2.0, 3.0]); assert_eq!(approx(expected, 4), vec![1.0f32, 2.0, 3.0]); + let v = vec![1.0f32, 2.0, 3.0]; + let input: Vec<f16> = v.iter().map(|v| f16::from_f32(*v)).collect(); + let results: Vec<f32> = cast(&input, "cast_f16_f32"); + assert_eq!(results, vec![1.0f32, 2.0, 3.0]); + let v = vec![1.0f32; 10_000]; - let results = run(&v, unary::contiguous::cos::FLOAT); - let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); - assert_eq!(approx(results, 4), vec![0.5403; 10_000]); - assert_eq!(approx(expected, 4), vec![0.5403; 10_000]); + let input: Vec<f16> = v.iter().map(|v| f16::from_f32(*v)).collect(); + let results: Vec<f32> = cast(&input, "cast_f16_f32"); + assert_eq!(results.len(), 10_000); + assert_eq!(&results[..10], vec![1.0f32; 10]); + assert_eq!(results, vec![1.0f32; 10_000]); } fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> { let device = device(); - let kernels = Kernels::new(); + let fence = device.new_fence(); + let kernels = Kernels::new(fence); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); - let mut output = new_buffer(&device, v); + let output = new_buffer(&device, v); let size = v.len(); @@ -267,9 +312,46 @@ fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> { &device, command_buffer, &kernels, + "affine_f32", size, &input, - &mut output, + &output, + mul as f32, + add as f32, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + read_to_vec(&output, v.len()) +} + +fn run_affine_strided<T: Clone>( + v: &[T], + shape: &[usize], + strides: &[usize], + mul: f64, + add: f64, +) -> Vec<T> { + let device = device(); + let fence = device.new_fence(); + let kernels = Kernels::new(fence); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + + let input = new_buffer(&device, v); + let output = new_buffer(&device, v); + + call_affine_strided( + &device, + command_buffer, + &kernels, + "affine_f32_strided", + shape, + &input, + strides, + 0, + &output, mul as f32, add as f32, ) @@ -277,7 +359,8 @@ fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> { command_buffer.commit(); command_buffer.wait_until_completed(); - output.read_to_vec::<T>(v.len()) + let len: usize = shape.iter().product(); + read_to_vec(&output, len) } #[test] @@ -296,6 +379,18 @@ fn affine() { } #[test] +fn affine_strided() { + let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]; + let mul = 1.5; + let add = 1.1; + let shape = [4]; + let strides = [2]; + let result = run_affine_strided(&input, &shape, &strides, mul, add); + // 1 on 2 + assert_eq!(result, vec![2.6, 5.6, 8.6, 11.6]); +} + +#[test] fn index_select() { let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; let shape = [5, 2]; @@ -313,7 +408,26 @@ fn index_select() { result, vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0] ); +} +#[test] +fn index_select_f16() { + let embedding: Vec<_> = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + .into_iter() + .map(|x| f16::from_f32(x)) + .collect(); + let shape = [5, 2]; + let ids = [0u32, 4, 2]; + let dim = 0; + let result = run_index_select(&embedding, &shape, &ids, dim); + assert_eq!( + approx_f16(result, 4), + vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0] + ); +} + +#[test] +fn index_select_dim1() { let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; let shape = [5, 2]; let ids = [0u32, 1, 0]; @@ -321,7 +435,7 @@ fn index_select() { let result = run_index_select(&embedding, &shape, &ids, dim); assert_eq!( result, - vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0] + vec![1.0f32, 2.0, 1.0, 3.0, 4.0, 3.0, 5.0, 6.0, 5.0, 7.0, 8.0f32, 7.0, 9.0, 10.0, 9.0] ); } @@ -341,27 +455,34 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>( let left_size: usize = shape[..dim].iter().product(); let right_size: usize = shape[dim + 1..].iter().product(); let dst_el = ids.len() * left_size * right_size; - let mut dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]); + let dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]); + + let name = match core::mem::size_of::<T>() { + 4 => "is_u32_f32", + 2 => "is_u32_f16", + _ => unimplemented!(), + }; - let kernels = Kernels::new(); + let fence = device.new_fence(); + let kernels = Kernels::new(fence); call_index_select( &device, &command_buffer, &kernels, - "is_u32_f32", + name, shape, ids.len(), dim, &embeddings_buffer, &ids_buffer, - &mut dst_buffer, + &dst_buffer, ) .unwrap(); command_buffer.commit(); command_buffer.wait_until_completed(); - dst_buffer.read_to_vec::<T>(dst_el) + read_to_vec(&dst_buffer, dst_el) } #[test] @@ -427,7 +548,7 @@ fn index_add() { let expected = vec![ 2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0, ]; - let result = outputs_buffer.read_to_vec::<f32>(right.len()); + let result: Vec<f32> = read_to_vec(&outputs_buffer, right.len()); assert_eq!(result, expected); } @@ -439,43 +560,49 @@ fn cos_f16() { .collect(); let results = run(&v, unary::contiguous::cos::HALF); let expected: Vec<f16> = v.iter().map(|v| f16::from_f32(v.to_f32().cos())).collect(); - assert_eq!(approx_f16(results, 4), vec![0.5405, -0.4163, -0.9902]); - assert_eq!(approx_f16(expected, 4), vec![0.5405, -0.4163, -0.9902]); + assert_eq!(approx_f16(results, 2), vec![0.54, -0.42, -0.99]); + assert_eq!(approx_f16(expected, 2), vec![0.54, -0.42, -0.99]); } fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T> { let device = device(); - let kernels = Kernels::new(); + let fence = device.new_fence(); + let kernels = Kernels::new(fence); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); let options = MTLResourceOptions::StorageModeManaged; - let mut output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options); - call_reduce_contiguous( + let output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options); + let dims = vec![v.len()]; + let strides = vec![1]; + call_reduce_strided( &device, command_buffer, &kernels, name, - v.len(), + &dims, + &strides, out_length, &input, - &mut output, + 0, + &output, ) .unwrap(); command_buffer.commit(); command_buffer.wait_until_completed(); - output.read_to_vec::<T>(out_length) + read_to_vec(&output, out_length) } fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'static str) -> Vec<T> { let device = device(); - let kernels = Kernels::new(); + let fence = device.new_fence(); + let kernels = Kernels::new(fence); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); - let mut output = new_buffer(&device, v); + let output = new_buffer(&device, v); call_last_softmax( &device, command_buffer, @@ -484,13 +611,14 @@ fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'sta v.len(), last_dim, &input, - &mut output, + 0, + &output, ) .unwrap(); command_buffer.commit(); command_buffer.wait_until_completed(); - output.read_to_vec::<T>(v.len()) + read_to_vec(&output, v.len()) } #[test] @@ -498,7 +626,7 @@ fn reduce_sum() { let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; let out_length = 1; - let results = run_reduce(&v, out_length, "fast_sum_float"); + let results = run_reduce(&v, out_length, "fast_sum_f32_strided"); assert_eq!(approx(results, 4), vec![21.0]); } @@ -507,7 +635,7 @@ fn reduce_sum2() { let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; let out_length = 2; - let results = run_reduce(&v, out_length, "fast_sum_float"); + let results = run_reduce(&v, out_length, "fast_sum_f32_strided"); assert_eq!(approx(results, 4), vec![6.0, 15.0]); } @@ -515,15 +643,33 @@ fn reduce_sum2() { fn softmax() { let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; let last_dim = 6; - let results = run_softmax(&v, last_dim, "softmax_float"); + let results = run_softmax(&v, last_dim, "softmax_f32"); assert_eq!( approx(results, 4), vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337] ); + let last_dim = 4096; + let n = 200; + let mut v = vec![0.0; n * last_dim]; + for i in 0..n { + v[i * last_dim] = 20.0; + } + let results = run_softmax(&v, last_dim, "softmax_f32"); + let results = approx(results, 4); + println!("{results:?}"); + assert_eq!( + results.iter().map(|&s| s.round() as usize).sum::<usize>(), + n + ); + assert_eq!(results[0], 1.0); + assert_eq!(results[1], 0.0); + assert_eq!(results[last_dim], 1.0); + assert_eq!(results[2 * last_dim], 1.0); + let v = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0]; let last_dim = 6; - let results = run_softmax(&v, last_dim, "softmax_float"); + let results = run_softmax(&v, last_dim, "softmax_f32"); assert_eq!( approx(results, 4), vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337] @@ -531,11 +677,33 @@ fn softmax() { let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; let last_dim = 3; - let results = run_softmax(&v, last_dim, "softmax_float"); + let results = run_softmax(&v, last_dim, "softmax_f32"); assert_eq!( approx(results, 4), vec![0.0900, 0.2447, 0.6652, 0.0900, 0.2447, 0.6652] ); + + let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0] + .iter() + .map(|v| f16::from_f32(*v)) + .collect::<Vec<_>>(); + let last_dim = 6; + let results = run_softmax(&v, last_dim, "softmax_f16"); + assert_eq!( + approx_f16(results, 4), + vec![0.0043, 0.0116, 0.0316, 0.0858, 0.2332, 0.6338] + ); + + let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0] + .iter() + .map(|v| bf16::from_f32(*v)) + .collect::<Vec<_>>(); + let last_dim = 6; + let results = run_softmax(&v, last_dim, "softmax_bf16"); + assert_eq!( + approx_bf16(results, 4), + vec![0.0043, 0.0116, 0.0315, 0.0859, 0.2324, 0.6328] + ); } fn run_where_cond<I: Clone, T: Clone>( @@ -549,7 +717,8 @@ fn run_where_cond<I: Clone, T: Clone>( name: &'static str, ) -> Vec<T> { let device = device(); - let kernels = Kernels::new(); + let fence = device.new_fence(); + let kernels = Kernels::new(fence); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let options = MTLResourceOptions::StorageModeManaged; @@ -571,7 +740,7 @@ fn run_where_cond<I: Clone, T: Clone>( options, ); - let mut output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options); + let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options); call_where_cond_strided( &device, command_buffer, @@ -584,13 +753,13 @@ fn run_where_cond<I: Clone, T: Clone>( (&left_stride, left_offset), &right, (&cond_stride, cond_offset), - &mut output, + &output, ) .unwrap(); command_buffer.commit(); command_buffer.wait_until_completed(); - output.read_to_vec::<T>(length) + read_to_vec(&output, length) } #[test] @@ -614,3 +783,93 @@ fn where_cond() { ); assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]); } + +fn run_gemm<T: Clone>( + (b, m, n, k): (usize, usize, usize, usize), + lhs: &[T], + lhs_stride: Vec<usize>, + lhs_offset: usize, + rhs: &[T], + rhs_stride: Vec<usize>, + rhs_offset: usize, +) -> Vec<T> { + let device = device(); + let fence = device.new_fence(); + let kernels = Kernels::new(fence); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let options = MTLResourceOptions::StorageModeManaged; + + let lhs = device.new_buffer_with_data( + lhs.as_ptr() as *const core::ffi::c_void, + std::mem::size_of_val(lhs) as u64, + options, + ); + let rhs = device.new_buffer_with_data( + rhs.as_ptr() as *const core::ffi::c_void, + std::mem::size_of_val(rhs) as u64, + options, + ); + let length = b * m * n; + let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options); + call_gemm( + &device, + command_buffer, + &kernels, + "sgemm", + (b, m, n, k), + &lhs_stride, + lhs_offset, + &lhs, + &rhs_stride, + rhs_offset, + &rhs, + &output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + read_to_vec(&output, length) +} + +#[test] +fn gemm() { + let (b, m, n, k) = (1, 2, 4, 3); + let lhs_stride = vec![m * k, k, 1]; + let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect(); + let rhs_stride = vec![n * k, n, 1]; + let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect(); + let results = run_gemm((b, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 0); + assert_eq!( + approx(results, 4), + vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0] + ); + + let (b, m, n, k) = (2, 2, 4, 3); + let lhs_stride = vec![m * k, k, 1]; + let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect(); + let rhs_stride = vec![n * k, n, 1]; + let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect(); + let results = run_gemm((b, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 0); + assert_eq!( + approx(results, 4), + vec![ + 20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0, 344.0, 365.0, 386.0, 407.0, 488.0, + 518.0, 548.0, 578.0 + ] + ); + + // OFFSET + let (b, m, n, k) = (2, 2, 4, 3); + let lhs_stride = vec![m * k, k, 1]; + let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect(); + let rhs_stride = vec![n * k, n, 1]; + let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect(); + // Manually set batch_size=1 and offset 12 elements * 4 the number of bytes for f32 + let results = run_gemm((1, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 12 * 4); + assert_eq!( + approx(results, 4), + vec![56.0, 59.0, 62.0, 65.0, 200.0, 212.0, 224.0, 236.0] + ); +} diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index eb6424e8..553bc506 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -1,4 +1,7 @@ #include <metal_stdlib> +#include <metal_math> +# +using namespace metal; METAL_FUNC uint get_strided_index( uint idx, @@ -17,10 +20,44 @@ METAL_FUNC uint get_strided_index( template <typename T> METAL_FUNC T sqr(T in){ return in * in; } template <typename T> METAL_FUNC T neg(T in){ return -in; } -template <typename T> METAL_FUNC T id(T in){ return in; } +template <typename T> METAL_FUNC T erf(T in){ + float x = (float) in; + // constants + float a1 = 0.254829592; + float a2 = -0.284496736; + float a3 = 1.421413741; + float a4 = -1.453152027; + float a5 = 1.061405429; + float p = 0.3275911; + + // Save the sign of x + int sign = 1; + if (x < 0) + sign = -1; + x = fabs(x); + + // A&S formula 7.1.26 + float t = 1.0/(1.0 + p*x); + float y = 1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*exp(-x*x); + + return T(sign*y); +} +template <typename T> METAL_FUNC T id(T in) { return in; } +template <typename T> METAL_FUNC T gelu_erf(T x) { + return T(x * (1 + erf(x * M_SQRT1_2_F)) / 2); +} +template <typename T> METAL_FUNC T gelu(T x) { + if (x > 5) { + return x; + } + T x_sq = x * x; + T x_cube = x_sq * x; + T alpha = x + static_cast<T>(0.044715) * x_cube; + T beta = (static_cast<T>(M_2_SQRTPI_F * M_SQRT1_2_F) * alpha); + return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(tanh(beta))); +} -using namespace metal; #define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \ kernel void FN_NAME( \ @@ -32,7 +69,7 @@ kernel void FN_NAME( \ if (thread_position_in_grid >= dim) { \ return; \ } \ - output[thread_position_in_grid] = TYPENAME(FN(input[thread_position_in_grid])); \ + output[thread_position_in_grid] = TYPENAME(FN(float(input[thread_position_in_grid]))); \ }\ kernel void FN_NAME_STRIDED( \ constant size_t &dim, \ @@ -46,15 +83,15 @@ kernel void FN_NAME_STRIDED( \ if (thread_position_in_grid >= dim) { \ return; \ } \ - output[thread_position_in_grid] = TYPENAME(FN(input[get_strided_index(thread_position_in_grid, num_dims, dims, strides)])); \ + output[thread_position_in_grid] = TYPENAME(FN(float(input[get_strided_index(thread_position_in_grid, num_dims, dims, strides)]))); \ } #define UNARY_OP(NAME) \ -UNARY(NAME, float, NAME##_float, NAME##_float_strided); \ -UNARY(NAME, half, NAME##_half, NAME##_half_strided); +UNARY(NAME, float, NAME##_f32, NAME##_f32_strided); \ +UNARY(NAME, half, NAME##_f16, NAME##_f16_strided); #define BFLOAT_UNARY_OP(NAME) \ -UNARY(NAME, bfloat, NAME##_bfloat, NAME##_bfloat_strided); +UNARY(NAME, bfloat, NAME##_bf16, NAME##_bf16_strided); UNARY_OP(cos) @@ -64,8 +101,17 @@ UNARY_OP(sqrt) UNARY_OP(neg) UNARY_OP(exp) UNARY_OP(log) -UNARY(id, float, copy_float, copy_float_strided) -UNARY(id, half, copy_half, copy_half_strided) +UNARY_OP(gelu) +UNARY_OP(ceil) +UNARY_OP(floor) +UNARY_OP(round) +UNARY_OP(gelu_erf) +UNARY_OP(erf) +UNARY_OP(tanh) +UNARY(id, float, copy_f32, copy_f32_strided) +UNARY(id, half, copy_f16, copy_f16_strided) +UNARY(id, uint8_t, copy_u8, copy_u8_strided) +UNARY(id, uint32_t, copy_u32, copy_u32_strided) #if __METAL_VERSION__ >= 310 BFLOAT_UNARY_OP(cos) @@ -75,6 +121,13 @@ BFLOAT_UNARY_OP(sqrt) BFLOAT_UNARY_OP(neg) BFLOAT_UNARY_OP(exp) BFLOAT_UNARY_OP(log) +BFLOAT_UNARY_OP(gelu) +BFLOAT_UNARY_OP(ceil) +BFLOAT_UNARY_OP(floor) +BFLOAT_UNARY_OP(round) +BFLOAT_UNARY_OP(gelu_erf) +BFLOAT_UNARY_OP(erf) +BFLOAT_UNARY_OP(tanh) -UNARY(id, bfloat, copy_bfloat, copy_bfloat_strided) +UNARY(id, bfloat, copy_bf16, copy_bf16_strided) #endif diff --git a/candle-metal-kernels/examples/affine.rs b/candle-metal-kernels/tmp/affine.rs index b8005dc0..cd019056 100644 --- a/candle-metal-kernels/examples/affine.rs +++ b/candle-metal-kernels/tmp/affine.rs @@ -50,6 +50,7 @@ fn run_affine_bench<T: Clone>(device: &Device, kernels: &Kernels, v: &[T]) { &device, command_buffer, &kernels, + "affine_float", v.len(), &input, &mut output, diff --git a/candle-metal-kernels/examples/binary.rs b/candle-metal-kernels/tmp/binary.rs index af5a8bdc..af5a8bdc 100644 --- a/candle-metal-kernels/examples/binary.rs +++ b/candle-metal-kernels/tmp/binary.rs diff --git a/candle-metal-kernels/examples/cast.rs b/candle-metal-kernels/tmp/cast.rs index 090f510d..090f510d 100644 --- a/candle-metal-kernels/examples/cast.rs +++ b/candle-metal-kernels/tmp/cast.rs diff --git a/candle-metal-kernels/examples/unary.rs b/candle-metal-kernels/tmp/unary.rs index 7039c098..66cf25c0 100644 --- a/candle-metal-kernels/examples/unary.rs +++ b/candle-metal-kernels/tmp/unary.rs @@ -147,7 +147,7 @@ fn run_unary_bench<T: Clone>( println!( "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}", type_name::<T>().split("::").last().unwrap(), - kernel_name.to_string(), + kernel_name.0, v.len(), iterations, total_time, @@ -159,7 +159,7 @@ fn run_unary_bench<T: Clone>( let shape = vec![2, 5_000]; let strides = vec![2, 1]; let offset = 0; - for kernel_name in strided { + for kernel_name in &strided { let total_time = autoreleasepool(|| { let command_buffer = command_queue.new_command_buffer(); let start = Instant::now(); @@ -187,7 +187,7 @@ fn run_unary_bench<T: Clone>( println!( "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}", type_name::<T>().split("::").last().unwrap(), - kernel_name.to_string(), + kernel_name.0, v.len(), iterations, total_time, diff --git a/candle-nn/Cargo.toml b/candle-nn/Cargo.toml index ffbe0ca1..e0daabef 100644 --- a/candle-nn/Cargo.toml +++ b/candle-nn/Cargo.toml @@ -19,6 +19,8 @@ num-traits = { workspace = true } rayon = { workspace = true } safetensors = { workspace = true } serde = { workspace = true } +metal = { workspace = true, optional = true } +candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true } [dev-dependencies] anyhow = { workspace = true } @@ -29,3 +31,4 @@ default = [] accelerate = ["dep:accelerate-src", "candle/accelerate"] cuda = ["candle/cuda"] mkl = ["dep:intel-mkl-src", "candle/mkl"] +metal = ["candle/metal", "dep:candle-metal-kernels", "dep:metal"] diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index a0269e59..abe33350 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -201,6 +201,47 @@ impl candle::CustomOp1 for SoftmaxLastDim { }; Ok((dst, layout.shape().clone())) } + + #[cfg(feature = "metal")] + fn metal_fwd( + &self, + storage: &candle::MetalStorage, + layout: &Layout, + ) -> Result<(candle::MetalStorage, Shape)> { + use candle::{backend::BackendStorage, DType}; + let device = storage.device(); + let command_buffer = device.command_buffer()?; + let kernels = device.kernels(); + let name = match storage.dtype() { + DType::F32 => "softmax_f32", + DType::F16 => "softmax_f16", + DType::BF16 => "softmax_bf16", + dtype => candle::bail!("softmax-last-dim is not implemented for {dtype:?}"), + }; + + let n = layout.stride().len(); + if !(layout.is_contiguous() && layout.stride()[n - 1] == 1) { + candle::bail!("Non contiguous softmax-last-dim is not implemented"); + } + + let last_dim = layout.dims()[layout.shape().rank() - 1]; + let elem_count = layout.shape().elem_count(); + let output = device.new_buffer(elem_count, storage.dtype(), "softmax")?; + candle_metal_kernels::call_last_softmax( + device.metal_device(), + &command_buffer, + kernels, + name, + elem_count, + last_dim, + storage.buffer(), + layout.start_offset() * storage.dtype().size_in_bytes(), + &output, + ) + .unwrap(); + let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype()); + Ok((newstorage, layout.shape().clone())) + } } pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> { diff --git a/candle-transformers/Cargo.toml b/candle-transformers/Cargo.toml index 37e5fa65..000702f9 100644 --- a/candle-transformers/Cargo.toml +++ b/candle-transformers/Cargo.toml @@ -31,3 +31,4 @@ accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate"] cuda = ["candle/cuda", "candle-nn/cuda"] flash-attn = ["cuda", "dep:candle-flash-attn"] mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl"] +metal = ["candle/metal", "candle-nn/metal"] |