Diffstat (limited to 'candle-core/src')
-rw-r--r--   candle-core/src/device.rs            7
-rw-r--r--   candle-core/src/metal_backend.rs   1240
-rw-r--r--   candle-core/src/tensor.rs            5
3 files changed, 867 insertions, 385 deletions
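The metal_backend.rs changes below replace the per-op "commit and wait" pattern with a single shared command buffer that is rotated after a bounded number of compute encoders (configurable via CANDLE_METAL_COMPUTE_PER_BUFFER, default 20). A rough, Metal-free sketch of that scheduling rule, using stand-in types instead of metal::CommandBuffer (everything here is illustrative, not the actual API):

use std::sync::{Arc, RwLock};

/// Stand-in for `metal::CommandBuffer`; only an id is tracked here.
#[derive(Clone, Debug)]
struct FakeCommandBuffer {
    id: usize,
}

struct Scheduler {
    current: Arc<RwLock<FakeCommandBuffer>>,
    index: Arc<RwLock<usize>>,
    next_id: Arc<RwLock<usize>>,
    compute_per_buffer: usize,
}

impl Scheduler {
    /// Hand out the shared command buffer, rotating to a fresh one once the
    /// encoder count exceeds `compute_per_buffer` (the real code commits the
    /// old buffer before replacing it).
    fn command_buffer(&self) -> FakeCommandBuffer {
        let mut current = self.current.write().unwrap();
        let mut index = self.index.write().unwrap();
        if *index > self.compute_per_buffer {
            let mut next_id = self.next_id.write().unwrap();
            *next_id += 1;
            *current = FakeCommandBuffer { id: *next_id };
            *index = 0;
        }
        *index += 1;
        current.clone()
    }
}

fn main() {
    let scheduler = Scheduler {
        current: Arc::new(RwLock::new(FakeCommandBuffer { id: 0 })),
        index: Arc::new(RwLock::new(0)),
        next_id: Arc::new(RwLock::new(0)),
        // The real device reads CANDLE_METAL_COMPUTE_PER_BUFFER, defaulting to 20.
        compute_per_buffer: 20,
    };
    // With the default of 20, 45 encoder requests land on three distinct buffers.
    let ids: Vec<usize> = (0..45).map(|_| scheduler.command_buffer().id).collect();
    assert_eq!(ids.iter().collect::<std::collections::HashSet<_>>().len(), 3);
    println!("{ids:?}");
}

This keeps CPU-side encoding and GPU execution overlapping instead of serializing on wait_until_completed after every kernel, which is the main point of the diff that follows.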
diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs index 3eb7f8b7..1e33021b 100644 --- a/candle-core/src/device.rs +++ b/candle-core/src/device.rs @@ -201,10 +201,9 @@ impl Device { Ok(Storage::Cuda(storage)) } } - Device::Metal(_device) => { - // let storage = device.rand_uniform(shape, dtype, lo, up)?; - // Ok(Storage::Metal(storage)) - crate::bail!("Metal rand_uniform not implemented") + Device::Metal(device) => { + let storage = device.rand_uniform(shape, dtype, lo, up)?; + Ok(Storage::Metal(storage)) } } } diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 0b72f080..27b2824f 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -4,11 +4,30 @@ use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; use crate::{CpuStorage, DType, Layout, Result, Shape}; use candle_metal_kernels; use candle_metal_kernels::Kernels; -use core::mem; -use half::{bf16, f16}; use metal; -use metal::{Buffer, CommandQueue, MTLResourceOptions, NSUInteger}; -use std::sync::Arc; +use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger}; +use std::collections::HashMap; +use std::path::Path; +use std::sync::{Arc, RwLock, TryLockError}; + +/// Simple way to catch lock error without +/// depending on T +#[derive(thiserror::Error, Debug)] +pub enum LockError { + #[error("{0}")] + Poisoned(String), + #[error("Would block")] + WouldBlock, +} + +impl<T> From<TryLockError<T>> for MetalError { + fn from(value: TryLockError<T>) -> Self { + match value { + TryLockError::Poisoned(p) => MetalError::LockError(LockError::Poisoned(p.to_string())), + TryLockError::WouldBlock => MetalError::LockError(LockError::WouldBlock), + } + } +} /// Metal related errors #[derive(thiserror::Error, Debug)] @@ -24,6 +43,14 @@ pub enum MetalError { rhs_stride: Vec<usize>, mnk: (usize, usize, usize), }, + #[error("{0:?}")] + LockError(LockError), + #[error("{msg}, expected: {expected:?}, got: {got:?}")] + UnexpectedDType { + msg: &'static str, + expected: DType, + got: DType, + }, } impl From<String> for MetalError { @@ -32,11 +59,53 @@ impl From<String> for MetalError { } } +type AllocatedBuffers = Arc<RwLock<HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>>>; + #[derive(Clone)] pub struct MetalDevice { + /// Raw metal device: <https://developer.apple.com/documentation/metal/mtldevice?language=objc> device: metal::Device, + + /// Single command queue for the entire device. command_queue: metal::CommandQueue, + /// One command buffer at a time. + /// The scheduler works by allowing multiple + /// [ComputeCommandEncoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) + /// on a single command buffer. Using a single command buffer would be fastest on the GPU but + /// prevents overlapping of CPU and GPU commands (because command buffer needs to be committed + /// to start to work). + /// Despite what the documentation says, command buffers are NOT ordered. They are ordered + /// for their START time, but there's no guarantee that command buffer1 will finish before + /// command buffer2 starts (or there are metal bugs there) + command_buffer: Arc<RwLock<metal::CommandBuffer>>, + /// Keeps track of the current amount of compute command encoders on the current + /// command buffer + /// Arc, RwLock because of the interior mutability. 
+ command_buffer_index: Arc<RwLock<usize>>, + /// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc) + compute_per_buffer: usize, + /// Every compute command encoder (and blit encoders) are defended with this Fence, forcing the + /// execution order to be linear. + /// It could be relaxed in some circumstances, by managing ourselves the dependencies in the + /// compute graph. + fence: metal::Fence, + /// Simple keeper struct to keep track of the already compiled kernels so we can reuse them. + /// Heavily used by [`candle_metal_kernels`], both fences need to match kernels: Arc<candle_metal_kernels::Kernels>, + /// Simple allocator struct. + /// The buffers are stored in size buckets since ML tends to use similar shapes over and over. + /// We store the buffers in [`Arc`] because it's much faster than Obj-c internal ref counting + /// (could be linked to FFI communication overhead). + /// + /// Whenever a buffer has a strong_count==1, we can reuse it, it means it was dropped in the + /// graph calculation, and only we the allocator kept a reference to it, therefore it's free + /// to be reused. However, in order for this to work, we need to guarantee the order of + /// operation, so that this buffer is not being used by another kernel at the same time. + /// Arc is the CPU reference count, it doesn't mean anything on the GPU side of things. + /// + /// Whenever we actually allocate a new buffer, we make a full sweep to cleanup unused buffers + /// (strong_count = 1). + buffers: AllocatedBuffers, } impl std::fmt::Debug for MetalDevice { @@ -58,10 +127,47 @@ impl MetalDevice { self.registry_id() } + pub fn metal_device(&self) -> &metal::Device { + &self.device + } + pub fn command_queue(&self) -> &CommandQueue { &self.command_queue } + pub fn command_buffer(&self) -> Result<CommandBuffer> { + let mut command_buffer_lock = self.command_buffer.try_write().map_err(MetalError::from)?; + let mut command_buffer = command_buffer_lock.to_owned(); + let mut index = self + .command_buffer_index + .try_write() + .map_err(MetalError::from)?; + if *index > self.compute_per_buffer { + command_buffer.commit(); + command_buffer = self.command_queue.new_command_buffer().to_owned(); + *command_buffer_lock = command_buffer.clone(); + *index = 0; + } + *index += 1; + Ok(command_buffer) + } + + pub fn wait_until_completed(&self) -> Result<()> { + let mut command_buffer = self.command_buffer.try_write().map_err(MetalError::from)?; + match command_buffer.status() { + metal::MTLCommandBufferStatus::Committed + | metal::MTLCommandBufferStatus::Scheduled + | metal::MTLCommandBufferStatus::Completed => { + panic!("Already committed"); + } + _ => {} + } + command_buffer.commit(); + command_buffer.wait_until_completed(); + *command_buffer = self.command_queue.new_command_buffer().to_owned(); + Ok(()) + } + pub fn kernels(&self) -> &Kernels { &self.kernels } @@ -70,17 +176,119 @@ impl MetalDevice { &self.device } - pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer { + /// Creates a new buffer (not necessarily zeroed). + /// The buffer is [MTLPrivate](https://developer.apple.com/documentation/metal/mtlstoragemode) + /// This means the buffer data cannot be read on the CPU directly. 
+ /// + /// [`name`] is only used to keep track of the resource origin in case of bugs + pub fn new_buffer( + &self, + element_count: usize, + dtype: DType, + name: &str, + ) -> Result<Arc<Buffer>> { let size = (element_count * dtype.size_in_bytes()) as NSUInteger; - self.device - .new_buffer(size, MTLResourceOptions::StorageModeManaged) + self.allocate_buffer(size, MTLResourceOptions::StorageModePrivate, name) + } + + /// Creates a new buffer (not necessarily zeroed). + /// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode) + /// This means the buffer can be read on the CPU but will require manual + /// synchronization when the CPU memory is modified + /// Used as a bridge to gather data back from the GPU + pub fn new_buffer_managed(&self, size: NSUInteger) -> Result<Arc<Buffer>> { + self.allocate_buffer(size, MTLResourceOptions::StorageModeManaged, "managed") + } + + /// Creates a new buffer from data. + /// The buffer is [MTLPrivate](https://developer.apple.com/documentation/metal/mtlstoragemode) + /// + /// This method will block the computation because of the + /// lack of lifetime management through the GPU. + /// Internal comment for technical details. + pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Result<Arc<Buffer>> { + let size = core::mem::size_of_val(data) as NSUInteger; + let tmp = self.device.new_buffer_with_data( + data.as_ptr() as *const core::ffi::c_void, + size, + metal::MTLResourceOptions::StorageModeManaged, + ); + let real = self.allocate_buffer( + size, + metal::MTLResourceOptions::StorageModePrivate, + "with_data", + )?; + let command_buffer = self.command_buffer()?; + command_buffer.set_label("with_data"); + let blit = command_buffer.new_blit_command_encoder(); + blit.wait_for_fence(&self.fence); + blit.set_label("with_data_blit"); + blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length()); + blit.update_fence(&self.fence); + blit.end_encoding(); + + // This is necessary, for mmaped safetensors + // Because of the unsafe slice cast we're doing. + // The slice might not live long enough for metal + // To actually fill the GPU buffer. + // Putting this wait forces the GPU buffer to be filled + // with the actual data allowing the CPU storage todo + // deallocate properly. + self.wait_until_completed()?; + Ok(real) + } + + /// The critical allocator algorithm + fn allocate_buffer( + &self, + size: NSUInteger, + option: MTLResourceOptions, + _name: &str, + ) -> Result<Arc<Buffer>> { + let mut buffers = self.buffers.try_write().map_err(MetalError::from)?; + let subbuffers = buffers.entry((size, option)).or_insert(vec![]); + + for sub in &mut *subbuffers { + if Arc::strong_count(sub) == 1 { + return Ok(sub.clone()); + } + } + let new_buffer = self.device.new_buffer(size as NSUInteger, option); + let new_buffer = Arc::new(new_buffer); + subbuffers.push(new_buffer.clone()); + for subbuffers in buffers.values_mut() { + let newbuffers = subbuffers + .iter() + .filter(|s| Arc::strong_count(s) > 1) + .map(Arc::clone) + .collect(); + *subbuffers = newbuffers; + } + Ok(new_buffer) + } + + /// Create a metal GPU capture trace on [`path`]. 
+ pub fn capture<P: AsRef<Path>>(&self, path: P) -> Result<()> { + let capture = metal::CaptureManager::shared(); + let descriptor = metal::CaptureDescriptor::new(); + descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument); + descriptor.set_capture_device(self); + descriptor.set_output_url(path); + + capture + .start_capture(&descriptor) + .map_err(MetalError::from)?; + Ok(()) } } #[derive(Debug, Clone)] pub struct MetalStorage { - buffer: metal::Buffer, + /// The actual buffer containing the data. + buffer: Arc<metal::Buffer>, + /// a reference to the device owning this buffer device: MetalDevice, + /// The dtype is kept since buffers are untyped. dtype: DType, } @@ -108,14 +316,27 @@ impl BackendStorage for MetalStorage { self.dtype ); } + let buffer = self.device.new_buffer_managed(self.buffer.length())?; + { + let command_buffer = self.device.command_buffer()?; + command_buffer.set_label("to_cpu"); + let blit = command_buffer.new_blit_command_encoder(); + blit.set_label("blit_to_cpu"); + blit.wait_for_fence(&self.device.fence); + blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length()); + blit.update_fence(&self.device.fence); + blit.end_encoding(); + } + self.device.wait_until_completed()?; + match self.dtype { - DType::U8 => Ok(CpuStorage::U8(self.buffer.read_to_vec(length / size))), - DType::U32 => Ok(CpuStorage::U32(self.buffer.read_to_vec(length / size))), - DType::I64 => Ok(CpuStorage::I64(self.buffer.read_to_vec(length / size))), - DType::F16 => Ok(CpuStorage::F16(self.buffer.read_to_vec(length / size))), - DType::BF16 => Ok(CpuStorage::BF16(self.buffer.read_to_vec(length / size))), - DType::F32 => Ok(CpuStorage::F32(self.buffer.read_to_vec(length / size))), - DType::F64 => Ok(CpuStorage::F64(self.buffer.read_to_vec(length / size))), + DType::U8 => Ok(CpuStorage::U8(read_to_vec(&buffer, length / size))), + DType::U32 => Ok(CpuStorage::U32(read_to_vec(&buffer, length / size))), + DType::I64 => Ok(CpuStorage::I64(read_to_vec(&buffer, length / size))), + DType::F16 => Ok(CpuStorage::F16(read_to_vec(&buffer, length / size))), + DType::BF16 => Ok(CpuStorage::BF16(read_to_vec(&buffer, length / size))), + DType::F32 => Ok(CpuStorage::F32(read_to_vec(&buffer, length / size))), + DType::F64 => Ok(CpuStorage::F64(read_to_vec(&buffer, length / size))), } } @@ -126,52 +347,152 @@ impl BackendStorage for MetalStorage { let el = shape.elem_count(); let dtype = self.dtype; - if layout.is_contiguous() || layout.start_offset() != 0 || dtype != DType::F32 { - crate::bail!("Not contiguous, non-f32 affine is not implemented yet."); + let buffer = device.new_buffer(el, self.dtype, "affine")?; + let command_buffer = self.device.command_buffer()?; + if layout.is_contiguous() && layout.start_offset() == 0 { + let name = match self.dtype { + DType::F32 => "affine_f32", + DType::F16 => "affine_f16", + dtype => crate::bail!("Affine {dtype:?}"), + }; + candle_metal_kernels::call_affine( + &device.device, + &command_buffer, + &device.kernels, + name, + el, + &self.buffer, + &buffer, + mul as f32, + add as f32, + ) + .map_err(MetalError::from)?; + } else { + let name = match self.dtype { + DType::F32 => "affine_f32_strided", + DType::F16 => "affine_f16_strided", + dtype => crate::bail!("Affine {dtype:?}"), + }; + candle_metal_kernels::call_affine_strided( + &device.device, + &command_buffer, + &device.kernels, + name, + layout.dims(), + &self.buffer, + layout.stride(), + layout.start_offset() * dtype.size_in_bytes(), + &buffer, + mul as f32, + add as f32, + ) + 
.map_err(MetalError::from)?; } - - let mut buffer = device.new_buffer(el, self.dtype); - let command_buffer = self.device.command_queue.new_command_buffer(); - candle_metal_kernels::call_affine( - &device.device, - &command_buffer, - &device.kernels, - el, - &self.buffer, - &mut buffer, - mul as f32, - add as f32, - ) - .map_err(MetalError::from)?; - command_buffer.commit(); - command_buffer.wait_until_completed(); - return Ok(Self { - buffer, - device: device.clone(), - dtype, - }); + Ok(Self::new(buffer, device.clone(), dtype)) } - fn powf(&self, _: &Layout, _: f64) -> Result<Self> { - crate::bail!("powf metal") + fn powf(&self, layout: &Layout, pow: f64) -> Result<Self> { + let device = self.device().clone(); + + let shape = layout.shape(); + let el = shape.elem_count(); + let dtype = self.dtype; + + let buffer = device.new_buffer(el, self.dtype, "powf")?; + let command_buffer = self.device.command_buffer()?; + if layout.is_contiguous() && layout.start_offset() == 0 { + let name = match self.dtype { + DType::F32 => "powf_f32", + DType::F16 => "powf_f16", + dtype => crate::bail!("Powf {dtype:?}"), + }; + candle_metal_kernels::call_powf( + &device.device, + &command_buffer, + &device.kernels, + name, + el, + &self.buffer, + &buffer, + pow as f32, + ) + .map_err(MetalError::from)?; + } else { + let name = match self.dtype { + DType::F32 => "powf_f32_strided", + DType::F16 => "powf_f16_strided", + dtype => crate::bail!("Powf {dtype:?}"), + }; + candle_metal_kernels::call_powf_strided( + &device.device, + &command_buffer, + &device.kernels, + name, + layout.dims(), + &self.buffer, + layout.stride(), + layout.start_offset() * dtype.size_in_bytes(), + &buffer, + pow as f32, + ) + .map_err(MetalError::from)?; + } + Ok(Self::new(buffer, device.clone(), dtype)) } - fn elu(&self, _: &Layout, _: f64) -> Result<Self> { - crate::bail!("elu metal") + fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> { + let device = self.device().clone(); + + let shape = layout.shape(); + let el = shape.elem_count(); + let dtype = self.dtype; + + let buffer = device.new_buffer(el, self.dtype, "elu")?; + let command_buffer = self.device.command_buffer()?; + if layout.is_contiguous() && layout.start_offset() == 0 { + let name = match self.dtype { + DType::F32 => "elu_f32", + DType::F16 => "elu_f16", + dtype => crate::bail!("Powf {dtype:?}"), + }; + candle_metal_kernels::call_elu( + &device.device, + &command_buffer, + &device.kernels, + name, + el, + &self.buffer, + &buffer, + alpha as f32, + ) + .map_err(MetalError::from)?; + } else { + let name = match self.dtype { + DType::F32 => "elu_f32_strided", + DType::F16 => "elu_f16_strided", + dtype => crate::bail!("Powf {dtype:?}"), + }; + candle_metal_kernels::call_elu_strided( + &device.device, + &command_buffer, + &device.kernels, + name, + layout.dims(), + &self.buffer, + layout.stride(), + layout.start_offset() * dtype.size_in_bytes(), + &buffer, + alpha as f32, + ) + .map_err(MetalError::from)?; + } + Ok(Self::new(buffer, device.clone(), dtype)) } fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> { - if !(sum_dims.len() == 1 - && sum_dims[0] == layout.shape().rank() - 1 - && layout.is_contiguous() - && layout.start_offset() == 0) - { - crate::bail!("Non contiguous reduce op not supported yet"); - } let device = self.device.clone(); let src_stride = layout.stride(); let src_dims = layout.shape().dims(); - let src_el: usize = src_dims.iter().product(); // Source dims and strides with the sum dims at the end. 
let mut dims = vec![]; let mut stride = vec![]; @@ -191,53 +512,77 @@ impl BackendStorage for MetalStorage { // The reduction loop requires the shared array to be properly initialized and for // this we want the number of threads to be a power of two. let (name, check_empty, return_index) = match (op, self.dtype) { - (ReduceOp::Sum, DType::F32) => ("fast_sum_float", false, false), - (ReduceOp::Min, DType::F32) => ("fast_min_float", true, false), - (ReduceOp::Max, DType::F32) => ("fast_max_float", true, false), - (ReduceOp::ArgMin, DType::F32) => ("fast_argmin_float", true, true), - (ReduceOp::ArgMax, DType::F32) => ("fast_argmax_float", true, true), - _ => crate::bail!("Reduce op for non float"), + (ReduceOp::Sum, DType::F32) => ("fast_sum_f32_strided", false, false), + (ReduceOp::Min, DType::F32) => ("fast_min_f32_strided", true, false), + (ReduceOp::Max, DType::F32) => ("fast_max_f32_strided", true, false), + (ReduceOp::ArgMin, DType::F32) => ("fast_argmin_f32_strided", true, true), + (ReduceOp::ArgMax, DType::F32) => ("fast_argmax_f32_strided", true, true), + (ReduceOp::Sum, DType::U32) => ("fast_sum_u32_strided", false, false), + (ReduceOp::Min, DType::U32) => ("fast_min_u32_strided", true, false), + (ReduceOp::Max, DType::U32) => ("fast_max_u32_strided", true, false), + (ReduceOp::ArgMin, DType::U32) => ("fast_argmin_u32_strided", true, true), + (ReduceOp::ArgMax, DType::U32) => ("fast_argmax_u32_strided", true, true), + (ReduceOp::Sum, DType::F16) => ("fast_sum_f16_strided", false, false), + (ReduceOp::Min, DType::F16) => ("fast_min_f16_strided", true, false), + (ReduceOp::Max, DType::F16) => ("fast_max_f16_strided", true, false), + (ReduceOp::ArgMin, DType::F16) => ("fast_argmin_f16_strided", true, true), + (ReduceOp::ArgMax, DType::F16) => ("fast_argmax_f16_strided", true, true), + (ReduceOp::Sum, DType::BF16) => ("fast_sum_bf16_strided", false, false), + (ReduceOp::Min, DType::BF16) => ("fast_min_bf16_strided", true, false), + (ReduceOp::Max, DType::BF16) => ("fast_max_bf16_strided", true, false), + (ReduceOp::ArgMin, DType::BF16) => ("fast_argmin_bf16_strided", true, true), + (ReduceOp::ArgMax, DType::BF16) => ("fast_argmax_bf16_strided", true, true), + (k, dtype) => crate::bail!("Reduce op for non float {k:?} {dtype:?}"), }; if check_empty && layout.shape().elem_count() == 0 { Err(crate::Error::EmptyTensor { op: "reduce" }.bt())? 
} let dtype = if return_index { DType::U32 } else { self.dtype }; - let mut buffer = device.new_buffer(dst_el, dtype); - let command_buffer = self.device.command_queue.new_command_buffer(); - candle_metal_kernels::call_reduce_contiguous( + let buffer = device.new_buffer(dst_el, dtype, "reduce")?; + let command_buffer = self.device.command_buffer()?; + candle_metal_kernels::call_reduce_strided( &device.device, &command_buffer, &device.kernels, name, - src_el, + &dims, + &stride, dst_el, &self.buffer, - &mut buffer, + layout.start_offset() * self.dtype.size_in_bytes(), + &buffer, ) .map_err(MetalError::from)?; - command_buffer.commit(); - command_buffer.wait_until_completed(); - Ok(Self { - buffer, - device, - dtype, - }) + Ok(Self::new(buffer, device, dtype)) } - fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> { - crate::bail!("cmp metal") + fn cmp(&self, op: CmpOp, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result<Self> { + let name = match op { + CmpOp::Eq => "eq", + CmpOp::Ne => "ne", + CmpOp::Le => "le", + CmpOp::Ge => "ge", + CmpOp::Lt => "lt", + CmpOp::Gt => "gt", + }; + self.binary(name, rhs, lhs_l, rhs_l) } fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> { let device = self.device(); let shape = layout.shape(); let el_count = shape.elem_count(); - let mut buffer = device.new_buffer(el_count, dtype); - let command_buffer = device.command_queue.new_command_buffer(); - if layout.is_contiguous() { + let buffer = device.new_buffer(el_count, dtype, "todtype")?; + let command_buffer = device.command_buffer()?; + if layout.is_contiguous() && layout.start_offset() == 0 { let kernel_name = match (self.dtype, dtype) { (DType::U32, DType::F32) => "cast_u32_f32", + (DType::U32, DType::U8) => "cast_u32_u8", + (DType::U8, DType::U32) => "cast_u8_u32", + (DType::U8, DType::F32) => "cast_u8_f32", + (DType::F32, DType::F16) => "cast_f32_f16", + (DType::F16, DType::F32) => "cast_f16_f32", (left, right) => crate::bail!("to dtype {left:?} - {right:?}"), }; candle_metal_kernels::call_cast_contiguous( @@ -247,24 +592,35 @@ impl BackendStorage for MetalStorage { kernel_name, el_count, &self.buffer, - &mut buffer, + layout.start_offset() * self.dtype.size_in_bytes(), + &buffer, ) .map_err(MetalError::from)?; } else { - crate::bail!( - "TODO Implement the kernel calling cast {:?}-{:?}", - self.dtype, - dtype - ); + let kernel_name = match (self.dtype, dtype) { + (DType::U32, DType::F32) => "cast_u32_f32_strided", + (DType::U32, DType::U8) => "cast_u32_u8_strided", + (DType::U8, DType::U32) => "cast_u8_u32_strided", + (DType::U8, DType::F32) => "cast_u8_f32_strided", + (DType::F32, DType::F16) => "cast_f32_f16_strided", + (DType::F16, DType::F32) => "cast_f16_f32_strided", + (left, right) => crate::bail!("to dtype {left:?} - {right:?}"), + }; + candle_metal_kernels::call_cast_strided( + &device.device, + &command_buffer, + &device.kernels, + kernel_name, + layout.dims(), + &self.buffer, + layout.stride(), + layout.start_offset() * self.dtype.size_in_bytes(), + &buffer, + ) + .map_err(MetalError::from)?; } - - command_buffer.commit(); - command_buffer.wait_until_completed(); - Ok(Self { - buffer, - device: device.clone(), - dtype, - }) + command_buffer.set_label("to_dtype"); + Ok(Self::new(buffer, device.clone(), dtype)) } fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> { @@ -272,8 +628,9 @@ impl BackendStorage for MetalStorage { let dtype = self.dtype; let shape = layout.shape(); let el_count = shape.elem_count(); - let mut buffer = 
device.new_buffer(el_count, dtype); - let command_buffer = device.command_queue.new_command_buffer(); + let buffer = device.new_buffer(el_count, dtype, B::KERNEL)?; + let command_buffer = device.command_buffer()?; + command_buffer.set_label(B::KERNEL); if layout.is_contiguous() && layout.start_offset() == 0 { use candle_metal_kernels::unary::contiguous; @@ -285,6 +642,27 @@ impl BackendStorage for MetalStorage { ("uneg", DType::F32) => contiguous::neg::FLOAT, ("uexp", DType::F32) => contiguous::exp::FLOAT, ("ulog", DType::F32) => contiguous::log::FLOAT, + ("ugelu", DType::F32) => contiguous::gelu::FLOAT, + ("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT, + ("uerf", DType::F32) => contiguous::erf::FLOAT, + ("uceil", DType::F32) => contiguous::ceil::FLOAT, + ("ufloor", DType::F32) => contiguous::floor::FLOAT, + ("uround", DType::F32) => contiguous::round::FLOAT, + ("utanh", DType::F32) => contiguous::tanh::FLOAT, + ("ucos", DType::F16) => contiguous::cos::HALF, + ("usin", DType::F16) => contiguous::sin::HALF, + ("usqr", DType::F16) => contiguous::sqr::HALF, + ("usqrt", DType::F16) => contiguous::sqrt::HALF, + ("uneg", DType::F16) => contiguous::neg::HALF, + ("uexp", DType::F16) => contiguous::exp::HALF, + ("ulog", DType::F16) => contiguous::log::HALF, + ("ugelu", DType::F16) => contiguous::gelu::HALF, + ("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF, + ("uerf", DType::F16) => contiguous::erf::HALF, + ("uceil", DType::F16) => contiguous::ceil::HALF, + ("ufloor", DType::F16) => contiguous::floor::HALF, + ("uround", DType::F16) => contiguous::round::HALF, + ("utanh", DType::F16) => contiguous::tanh::HALF, (name, dtype) => crate::bail!("Match {name} - {dtype:?}"), }; candle_metal_kernels::call_unary_contiguous( @@ -294,95 +672,64 @@ impl BackendStorage for MetalStorage { kernel_name, el_count, &self.buffer, - &mut buffer, - ) - .map_err(MetalError::from)?; - } else { - crate::bail!("TODO Implement the kernel calling {}", B::KERNEL); - } - command_buffer.commit(); - command_buffer.wait_until_completed(); - - Ok(Self { - buffer, - device: device.clone(), - dtype, - }) - } - - fn binary_impl<B: BinaryOpT>( - &self, - rhs: &Self, - lhs_l: &Layout, - rhs_l: &Layout, - ) -> Result<Self> { - let device = self.device(); - let dtype = self.dtype; - let shape = lhs_l.shape(); - let el_count = shape.elem_count(); - let mut buffer = device.new_buffer(el_count, dtype); - let command_buffer = device.command_queue.new_command_buffer(); - if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0) - && (rhs_l.is_contiguous() && rhs_l.start_offset() == 0) - { - use candle_metal_kernels::binary::contiguous; - - let kernel_name = match (B::KERNEL, dtype) { - ("add", DType::F32) => contiguous::add::FLOAT, - ("badd", DType::F32) => contiguous::add::FLOAT, - ("sub", DType::F32) => contiguous::sub::FLOAT, - ("bsub", DType::F32) => contiguous::sub::FLOAT, - ("mul", DType::F32) => contiguous::mul::FLOAT, - ("bmul", DType::F32) => contiguous::mul::FLOAT, - ("div", DType::F32) => contiguous::div::FLOAT, - ("bdiv", DType::F32) => contiguous::div::FLOAT, - (name, dtype) => crate::bail!("Match {name} - {dtype:?}"), - }; - candle_metal_kernels::call_binary_contiguous( - &device.device, - &command_buffer, - &device.kernels, - kernel_name, - el_count, - &self.buffer, - &rhs.buffer, - &mut buffer, + &buffer, ) .map_err(MetalError::from)?; } else { - use candle_metal_kernels::binary::strided; - + use candle_metal_kernels::unary::strided; let kernel_name = match (B::KERNEL, dtype) { - ("badd", DType::F32) => 
strided::add::FLOAT, - ("bsub", DType::F32) => strided::sub::FLOAT, - ("bmul", DType::F32) => strided::mul::FLOAT, - ("bdiv", DType::F32) => strided::div::FLOAT, + ("ucos", DType::F32) => strided::cos::FLOAT, + ("usin", DType::F32) => strided::sin::FLOAT, + ("usqr", DType::F32) => strided::sqr::FLOAT, + ("usqrt", DType::F32) => strided::sqrt::FLOAT, + ("uneg", DType::F32) => strided::neg::FLOAT, + ("uexp", DType::F32) => strided::exp::FLOAT, + ("ulog", DType::F32) => strided::log::FLOAT, + ("ugelu", DType::F32) => strided::gelu::FLOAT, + ("ugelu_erf", DType::F32) => strided::gelu_erf::FLOAT, + ("uerf", DType::F32) => strided::erf::FLOAT, + ("uceil", DType::F32) => strided::ceil::FLOAT, + ("ufloor", DType::F32) => strided::floor::FLOAT, + ("uround", DType::F32) => strided::round::FLOAT, + ("ucos", DType::F16) => strided::cos::HALF, + ("usin", DType::F16) => strided::sin::HALF, + ("usqr", DType::F16) => strided::sqr::HALF, + ("usqrt", DType::F16) => strided::sqrt::HALF, + ("uneg", DType::F16) => strided::neg::HALF, + ("uexp", DType::F16) => strided::exp::HALF, + ("ulog", DType::F16) => strided::log::HALF, + ("ugelu", DType::F16) => strided::gelu::HALF, + ("ugelu_erf", DType::F16) => strided::gelu_erf::HALF, + ("uerf", DType::F16) => strided::erf::HALF, + ("uceil", DType::F16) => strided::ceil::HALF, + ("ufloor", DType::F16) => strided::floor::HALF, + ("uround", DType::F16) => strided::round::HALF, (name, dtype) => crate::bail!("Match {name} - {dtype:?}"), }; - candle_metal_kernels::call_binary_strided( + candle_metal_kernels::call_unary_strided( &device.device, &command_buffer, &device.kernels, kernel_name, - lhs_l.dims(), + layout.dims(), &self.buffer, - &lhs_l.stride(), - lhs_l.start_offset() * self.dtype.size_in_bytes(), - &rhs.buffer, - &rhs_l.stride(), - rhs_l.start_offset() * rhs.dtype.size_in_bytes(), - &mut buffer, + layout.stride(), + layout.start_offset() * self.dtype.size_in_bytes(), + &buffer, + 0, ) .map_err(MetalError::from)?; } - command_buffer.commit(); - command_buffer.wait_until_completed(); + Ok(Self::new(buffer, device.clone(), dtype)) + } - Ok(Self { - buffer, - device: device.clone(), - dtype, - }) + fn binary_impl<B: BinaryOpT>( + &self, + rhs: &Self, + lhs_l: &Layout, + rhs_l: &Layout, + ) -> Result<Self> { + self.binary(B::KERNEL, rhs, lhs_l, rhs_l) } fn where_cond( @@ -398,14 +745,26 @@ impl BackendStorage for MetalStorage { let dims = shape.dims(); let el = shape.elem_count(); let dtype = t.dtype; - let mut buffer = self.device.new_buffer(el, dtype); - let command_buffer = self.device.command_queue.new_command_buffer(); + let buffer = self.device.new_buffer(el, dtype, "where")?; + let command_buffer = self.device.command_buffer()?; + if t.dtype() != f.dtype() { + crate::bail!( + "Invalid where: different dtypes for values {:?} != {:?}", + t.dtype(), + f.dtype() + ); + } + let name = match (self.dtype, t.dtype()) { + (DType::U8, DType::F32) => "where_u8_f32", + (DType::U8, DType::F16) => "where_u8_f16", + (left, right) => crate::bail!("where {left:?} - {right:?} not implemented"), + }; candle_metal_kernels::call_where_cond_strided( &device.device, &command_buffer, &device.kernels, - "where_u8_f32", - &dims, + name, + dims, &self.buffer, ( layout.stride(), @@ -415,16 +774,10 @@ impl BackendStorage for MetalStorage { (&t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()), &f.buffer, (&f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()), - &mut buffer, + &buffer, ) .map_err(MetalError::from)?; - command_buffer.commit(); - 
command_buffer.wait_until_completed(); - Ok(Self { - buffer, - device, - dtype, - }) + Ok(Self::new(buffer, device, dtype)) } fn conv1d( @@ -483,20 +836,84 @@ impl BackendStorage for MetalStorage { crate::bail!("upsample_nearest2d metal") } - fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> { - crate::bail!("gather metal") + fn gather(&self, src_l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> { + let (ids_o1, _) = match ids_l.contiguous_offsets() { + Some(o12) => o12, + None => Err(crate::Error::RequiresContiguous { op: "gather" }.bt())?, + }; + let ids_el = ids_l.dims()[dim]; + let dst_el = ids_l.shape().elem_count(); + let dtype = self.dtype; + let device = self.device(); + let buffer = device.new_buffer(dst_el, dtype, "index_select")?; + let name = match (ids.dtype, self.dtype) { + (DType::U32, DType::F32) => "gather_u32_f32", + (DType::U32, DType::F16) => "gather_u32_f16", + (left, right) => crate::bail!("gather metal {left:?} {right:?} not implemented"), + }; + let command_buffer = self.device.command_buffer()?; + candle_metal_kernels::call_gather( + &device.device, + &command_buffer, + &self.device.kernels, + name, + src_l.dims(), + ids_el, + dim, + &self.buffer, + src_l.start_offset() * dtype.size_in_bytes(), + &ids.buffer, + ids_o1 * ids.dtype.size_in_bytes(), + &buffer, + ) + .map_err(MetalError::from)?; + Ok(Self::new(buffer, device.clone(), dtype)) } fn scatter_add( &self, - _: &Layout, - _: &Self, - _: &Layout, - _: &Self, - _: &Layout, - _: usize, + l: &Layout, + ids: &Self, + ids_l: &Layout, + src: &Self, + src_l: &Layout, + dim: usize, ) -> Result<Self> { - crate::bail!("scatter_add metal") + let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?; + self.copy_strided_src(&mut acc, 0, l)?; + let (ids_offset, _) = match ids_l.contiguous_offsets() { + Some(o12) => o12, + None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?, + }; + let src_offset = match src_l.contiguous_offsets() { + Some((o1, _)) => o1, + None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?, + }; + let name = match (ids.dtype, self.dtype) { + (DType::U32, DType::F32) => "sa_u32_f32", + _ => Err(MetalError::UnexpectedDType { + msg: "scatter-add ids should be u8/u32/i64", + expected: DType::U32, + got: ids.dtype(), + })?, + }; + let command_buffer = self.device.command_buffer()?; + candle_metal_kernels::call_scatter_add( + &self.device.device, + &command_buffer, + &self.device.kernels, + name, + src_l.dims(), + l.dims(), + dim, + &src.buffer, + src_offset * src.dtype.size_in_bytes(), + &ids.buffer, + ids_offset * ids.dtype.size_in_bytes(), + &acc.buffer, + ) + .map_err(MetalError::from)?; + Ok(acc) } fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> { @@ -513,12 +930,13 @@ impl BackendStorage for MetalStorage { let dst_el = ids_el * left_size * right_size; let dtype = self.dtype; let device = self.device(); - let mut buffer = device.new_buffer(dst_el, dtype); + let buffer = device.new_buffer(dst_el, dtype, "index_select")?; let name = match (ids.dtype, self.dtype) { (DType::U32, DType::F32) => "is_u32_f32", + (DType::U32, DType::F16) => "is_u32_f16", (left, right) => crate::bail!("index select metal {left:?} {right:?}"), }; - let command_buffer = self.device.command_queue.new_command_buffer(); + let command_buffer = self.device.command_buffer()?; candle_metal_kernels::call_index_select( &device.device, &command_buffer, @@ -529,30 +947,58 @@ impl BackendStorage for 
MetalStorage { dim, &self.buffer, &ids.buffer, - &mut buffer, + &buffer, ) .map_err(MetalError::from)?; - command_buffer.commit(); - command_buffer.wait_until_completed(); - Ok(Self { - buffer, - device: device.clone(), - dtype, - }) + Ok(Self::new(buffer, device.clone(), dtype)) } fn index_add( &self, - _: &Layout, - _: &Self, - _: &Layout, - _: &Self, - _: &Layout, - _: usize, + l: &Layout, + ids: &Self, + ids_l: &Layout, + src: &Self, + src_l: &Layout, + dim: usize, ) -> Result<Self> { - crate::bail!("index_add metal") + let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?; + self.copy_strided_src(&mut acc, 0, l)?; + let (ids_offset, _) = match ids_l.contiguous_offsets() { + Some(o12) => o12, + None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?, + }; + let src_offset = match src_l.contiguous_offsets() { + Some((o1, _)) => o1, + None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?, + }; + let name = match (ids.dtype, self.dtype) { + (DType::U32, DType::F32) => "ia_u32_f32", + _ => Err(MetalError::UnexpectedDType { + msg: "index-add ids should be u8/u32/i64", + expected: DType::U32, + got: ids.dtype(), + })?, + }; + let command_buffer = self.device.command_buffer()?; + candle_metal_kernels::call_index_add( + &self.device.device, + &command_buffer, + &self.device.kernels, + name, + src_l.dims(), + l.dims(), + ids_l.dims(), + dim, + &src.buffer, + src_offset * src.dtype.size_in_bytes(), + &ids.buffer, + ids_offset * ids.dtype.size_in_bytes(), + &acc.buffer, + ) + .map_err(MetalError::from)?; + Ok(acc) } - fn matmul( &self, rhs: &Self, @@ -560,147 +1006,81 @@ impl BackendStorage for MetalStorage { lhs_l: &Layout, rhs_l: &Layout, ) -> Result<Self> { - // Create descriptors - use metal::mps::matrix::*; - let type_id = metal::mps::MPS_FLOATBIT_ENCODING | 32; - let size = core::mem::size_of::<f32>() as NSUInteger; - - let elem_count = b * m * n; - - let lhs_stride = lhs_l.stride(); - let rhs_stride = rhs_l.stride(); - let rhs_m1 = rhs_stride[rhs_stride.len() - 1]; - let rhs_m2 = rhs_stride[rhs_stride.len() - 2]; - let lhs_m1 = lhs_stride[lhs_stride.len() - 1]; - let lhs_m2 = lhs_stride[lhs_stride.len() - 2]; - // The a tensor has dims batching, k, n (rhs) - let transpose_left = if lhs_m1 == 1 && lhs_m2 == k { - false - } else if lhs_m1 == m && lhs_m2 == 1 { - true - } else { - Err(MetalError::MatMulNonContiguous { - lhs_stride: lhs_stride.to_vec(), - rhs_stride: rhs_stride.to_vec(), - mnk: (m, n, k), - })? - }; - let transpose_right = if rhs_m1 == 1 && rhs_m2 == n { - false - } else if rhs_m1 == k && rhs_m2 == 1 { - true - } else { - Err(MetalError::MatMulNonContiguous { - lhs_stride: lhs_stride.to_vec(), - rhs_stride: rhs_stride.to_vec(), - mnk: (m, n, k), - })? 
- }; - - let b = b as NSUInteger; - let m = m as NSUInteger; - let n = n as NSUInteger; - let k = k as NSUInteger; - - let left_descriptor = if transpose_left { - MatrixDescriptor::init_single(k, m, m * size, type_id) - } else { - MatrixDescriptor::init_single(m, k, k * size, type_id) - }; - let right_descriptor = if transpose_right { - MatrixDescriptor::init_single(n, k, k * size, type_id) - } else { - MatrixDescriptor::init_single(k, n, n * size, type_id) + let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul")?; + let name = match self.dtype { + DType::F32 => "sgemm", + DType::F16 => "hgemm", + dtype => { + return Err(MetalError::Message(format!("matmul doesn't support {dtype:?}")).into()) + } }; - let result_descriptor = MatrixDescriptor::init_single(m, n, n * size, type_id); - - // Create matrix objects - let left_matrix = Matrix::init_with_buffer_descriptor(&self.buffer, 0, &left_descriptor) - .ok_or_else(|| { - MetalError::from("Failed to create matrix multiplication kernel".to_string()) - })?; - let right_matrix = Matrix::init_with_buffer_descriptor(&rhs.buffer, 0, &right_descriptor) - .ok_or_else(|| { - MetalError::from("Failed to create matrix multiplication kernel".to_string()) - })?; - - let out_buffer = self.device.new_buffer(elem_count, self.dtype); - let result_matrix = Matrix::init_with_buffer_descriptor(&out_buffer, 0, &result_descriptor) - .ok_or_else(|| { - MetalError::from("Failed to create matrix multiplication kernel".to_string()) - })?; - - let alpha = 1.0f64; - let beta = 0.0f64; - // Create kernel - let matrix_multiplication = MatrixMultiplication::init( - &self.device, - transpose_left, - transpose_right, - m, - n, - k, - alpha, - beta, - ) - .ok_or_else(|| { - MetalError::from("Failed to create matrix multiplication kernel".to_string()) - })?; - - matrix_multiplication.set_batch_size(b); - - // Encode kernel to command buffer - let command_buffer = self.device.command_queue.new_command_buffer(); - matrix_multiplication.encode_to_command_buffer( - command_buffer, - &left_matrix, - &right_matrix, - &result_matrix, - ); - command_buffer.commit(); - command_buffer.wait_until_completed(); - - Ok(Self { - buffer: out_buffer, - device: self.device.clone(), - dtype: self.dtype(), - }) - } - fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> { - let src_shape = src_l.shape(); - let el_count = src_shape.elem_count(); - if el_count == 0 { - return Ok(()); - } - let command_buffer = self.device.command_queue.new_command_buffer(); - let kernel_name = match self.dtype { - DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT, - DType::F16 => candle_metal_kernels::unary::strided::copy::HALF, - DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT, - dtype => crate::bail!("copy_strided not implemented for {dtype:?}"), - }; - candle_metal_kernels::call_unary_strided( + let command_buffer = self.device.command_buffer()?; + command_buffer.set_label("matmul"); + candle_metal_kernels::call_gemm( &self.device.device, &command_buffer, &self.device.kernels, - kernel_name, - src_l.dims(), + name, + (b, m, n, k), + lhs_l.stride(), + lhs_l.start_offset() * self.dtype.size_in_bytes(), &self.buffer, - &src_l.stride(), - src_l.start_offset() * self.dtype.size_in_bytes(), - &mut dst.buffer, - dst_offset, + rhs_l.stride(), + rhs_l.start_offset() * rhs.dtype.size_in_bytes(), + &rhs.buffer, + &buffer, ) .map_err(MetalError::from)?; - command_buffer.commit(); - command_buffer.wait_until_completed(); + Ok(Self::new(buffer, 
self.device.clone(), self.dtype())) + } + + fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> { + let command_buffer = self.device.command_buffer()?; + if src_l.is_contiguous() && self.dtype == dst.dtype() { + command_buffer.set_label("copy_contiguous"); + let blit = command_buffer.new_blit_command_encoder(); + blit.set_label("copy_contiguous"); + let src_offset = (src_l.start_offset() * self.dtype.size_in_bytes()) as NSUInteger; + let length = (src_l.shape().elem_count() * self.dtype.size_in_bytes()) as NSUInteger; + let dst_offset = (dst_offset * dst.dtype().size_in_bytes()) as NSUInteger; + blit.copy_from_buffer(&self.buffer, src_offset, dst.buffer(), dst_offset, length); + blit.end_encoding(); + } else { + let src_shape = src_l.shape(); + let el_count = src_shape.elem_count(); + if el_count == 0 { + return Ok(()); + } + let kernel_name = match self.dtype { + DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT, + DType::F16 => candle_metal_kernels::unary::strided::copy::HALF, + DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT, + DType::U32 => candle_metal_kernels::unary::strided::copy::U32, + DType::U8 => candle_metal_kernels::unary::strided::copy::U8, + dtype => crate::bail!("copy_strided not implemented for {dtype:?}"), + }; + candle_metal_kernels::call_unary_strided( + &self.device.device, + &command_buffer, + &self.device.kernels, + kernel_name, + src_l.dims(), + &self.buffer, + src_l.stride(), + src_l.start_offset() * self.dtype.size_in_bytes(), + &dst.buffer, + dst_offset * dst.dtype.size_in_bytes(), + ) + .map_err(MetalError::from)?; + command_buffer.set_label("copy_strided"); + } Ok(()) } } impl MetalStorage { - pub fn new(buffer: Buffer, device: MetalDevice, dtype: DType) -> Self { + pub fn new(buffer: Arc<Buffer>, device: MetalDevice, dtype: DType) -> Self { Self { buffer, device, @@ -711,6 +1091,111 @@ impl MetalStorage { pub fn buffer(&self) -> &Buffer { &self.buffer } + + pub fn binary( + &self, + op: &'static str, + rhs: &Self, + lhs_l: &Layout, + rhs_l: &Layout, + ) -> Result<Self> { + let device = self.device(); + let shape = lhs_l.shape(); + let el_count = shape.elem_count(); + let command_buffer = device.command_buffer()?; + let (buffer, dtype) = if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0) + && (rhs_l.is_contiguous() && rhs_l.start_offset() == 0) + && &op[..1] != "b" + { + use candle_metal_kernels::binary::contiguous; + + let (kernel_name, dtype) = match (op, self.dtype) { + ("add", DType::F32) => (contiguous::add::FLOAT, self.dtype), + ("sub", DType::F32) => (contiguous::sub::FLOAT, self.dtype), + ("mul", DType::F32) => (contiguous::mul::FLOAT, self.dtype), + ("div", DType::F32) => (contiguous::div::FLOAT, self.dtype), + ("eq", DType::F32) => (contiguous::eq::FLOAT, DType::U8), + ("ne", DType::F32) => (contiguous::ne::FLOAT, DType::U8), + ("le", DType::F32) => (contiguous::le::FLOAT, DType::U8), + ("lt", DType::F32) => (contiguous::lt::FLOAT, DType::U8), + ("ge", DType::F32) => (contiguous::ge::FLOAT, DType::U8), + ("gt", DType::F32) => (contiguous::gt::FLOAT, DType::U8), + ("add", DType::F16) => (contiguous::add::HALF, self.dtype), + ("sub", DType::F16) => (contiguous::sub::HALF, self.dtype), + ("mul", DType::F16) => (contiguous::mul::HALF, self.dtype), + ("div", DType::F16) => (contiguous::div::HALF, self.dtype), + ("eq", DType::F16) => (contiguous::eq::HALF, DType::U8), + ("ne", DType::F16) => (contiguous::ne::HALF, DType::U8), + ("le", DType::F16) => (contiguous::le::HALF, 
DType::U8), + ("lt", DType::F16) => (contiguous::lt::HALF, DType::U8), + ("ge", DType::F16) => (contiguous::ge::HALF, DType::U8), + ("gt", DType::F16) => (contiguous::gt::HALF, DType::U8), + (name, dtype) => crate::bail!("Binary {name} - {dtype:?} not implemented"), + }; + let buffer = device.new_buffer(el_count, dtype, op)?; + candle_metal_kernels::call_binary_contiguous( + &device.device, + &command_buffer, + &device.kernels, + kernel_name, + el_count, + &self.buffer, + &rhs.buffer, + &buffer, + ) + .map_err(MetalError::from)?; + (buffer, dtype) + } else { + use candle_metal_kernels::binary::strided; + + let (kernel_name, dtype) = match (op, self.dtype) { + ("badd", DType::F32) => (strided::add::FLOAT, self.dtype), + ("bsub", DType::F32) => (strided::sub::FLOAT, self.dtype), + ("bmul", DType::F32) => (strided::mul::FLOAT, self.dtype), + ("bdiv", DType::F32) => (strided::div::FLOAT, self.dtype), + ("bminimum", DType::F32) => (strided::min::FLOAT, self.dtype), + ("bmaximum", DType::F32) => (strided::max::FLOAT, self.dtype), + ("eq", DType::F32) => (strided::eq::FLOAT, DType::U8), + ("ne", DType::F32) => (strided::ne::FLOAT, DType::U8), + ("le", DType::F32) => (strided::le::FLOAT, DType::U8), + ("lt", DType::F32) => (strided::lt::FLOAT, DType::U8), + ("ge", DType::F32) => (strided::ge::FLOAT, DType::U8), + ("gt", DType::F32) => (strided::gt::FLOAT, DType::U8), + ("badd", DType::F16) => (strided::add::HALF, self.dtype), + ("bsub", DType::F16) => (strided::sub::HALF, self.dtype), + ("bmul", DType::F16) => (strided::mul::HALF, self.dtype), + ("bdiv", DType::F16) => (strided::div::HALF, self.dtype), + ("bminimum", DType::F16) => (strided::min::HALF, self.dtype), + ("bmaximum", DType::F16) => (strided::max::HALF, self.dtype), + ("eq", DType::F16) => (strided::eq::HALF, DType::U8), + ("ne", DType::F16) => (strided::ne::HALF, DType::U8), + ("le", DType::F16) => (strided::le::HALF, DType::U8), + ("lt", DType::F16) => (strided::lt::HALF, DType::U8), + ("ge", DType::F16) => (strided::ge::HALF, DType::U8), + ("gt", DType::F16) => (strided::gt::HALF, DType::U8), + (name, dtype) => crate::bail!("Binary strided {name} - {dtype:?} not implemented"), + }; + let buffer = device.new_buffer(el_count, dtype, op)?; + candle_metal_kernels::call_binary_strided( + &device.device, + &command_buffer, + &device.kernels, + kernel_name, + lhs_l.dims(), + &self.buffer, + lhs_l.stride(), + lhs_l.start_offset() * self.dtype.size_in_bytes(), + &rhs.buffer, + rhs_l.stride(), + rhs_l.start_offset() * rhs.dtype.size_in_bytes(), + &buffer, + ) + .map_err(MetalError::from)?; + (buffer, dtype) + }; + command_buffer.set_label("binary"); + Ok(Self::new(buffer, device.clone(), dtype)) + } } impl BackendDevice for MetalDevice { @@ -718,12 +1203,26 @@ impl BackendDevice for MetalDevice { fn new(ordinal: usize) -> Result<Self> { let device = metal::Device::all().swap_remove(ordinal); - let command_queue = device.new_command_queue(); - let kernels = Arc::new(Kernels::new()); + let command_buffer = command_queue.new_command_buffer().to_owned(); + command_buffer.enqueue(); + let command_buffer = Arc::new(RwLock::new(command_buffer)); + let command_buffer_index = Arc::new(RwLock::new(0)); + let fence = device.new_fence(); + let kernels = Arc::new(Kernels::new(fence.clone())); + let buffers = Arc::new(RwLock::new(HashMap::new())); + let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") { + Ok(val) => val.parse()?, + _ => 20, + }; Ok(Self { device, + fence, command_queue, + command_buffer, + command_buffer_index, + 
compute_per_buffer, + buffers, kernels, }) } @@ -743,9 +1242,22 @@ impl BackendDevice for MetalDevice { } fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> { - // TODO Is there a faster way ? - let cpu_storage = crate::cpu_backend::CpuDevice.zeros_impl(shape, dtype)?; - self.storage_from_cpu_storage(&cpu_storage) + let buffer = self.new_buffer(shape.elem_count(), dtype, "zeros")?; + let command_buffer = self.command_buffer()?; + command_buffer.set_label("zeros"); + let blit = command_buffer.new_blit_command_encoder(); + blit.wait_for_fence(&self.fence); + blit.fill_buffer( + &buffer, + metal::NSRange { + location: 0, + length: buffer.length(), + }, + 0, + ); + blit.update_fence(&self.fence); + blit.end_encoding(); + Ok(MetalStorage::new(buffer, self.clone(), dtype)) } fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> { @@ -755,49 +1267,16 @@ impl BackendDevice for MetalDevice { } fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> { - let option = metal::MTLResourceOptions::StorageModeManaged; let buffer = match storage { - CpuStorage::U8(storage) => self.device.new_buffer_with_data( - storage.as_ptr() as *const core::ffi::c_void, - (storage.len() * mem::size_of::<u8>()) as NSUInteger, - option, - ), - CpuStorage::U32(storage) => self.device.new_buffer_with_data( - storage.as_ptr() as *const core::ffi::c_void, - (storage.len() * mem::size_of::<u32>()) as NSUInteger, - option, - ), - CpuStorage::I64(storage) => self.device.new_buffer_with_data( - storage.as_ptr() as *const core::ffi::c_void, - (storage.len() * mem::size_of::<i64>()) as NSUInteger, - option, - ), - CpuStorage::BF16(storage) => self.device.new_buffer_with_data( - storage.as_ptr() as *const core::ffi::c_void, - (storage.len() * mem::size_of::<bf16>()) as NSUInteger, - option, - ), - CpuStorage::F16(storage) => self.device.new_buffer_with_data( - storage.as_ptr() as *const core::ffi::c_void, - (storage.len() * mem::size_of::<f16>()) as NSUInteger, - option, - ), - CpuStorage::F32(storage) => self.device.new_buffer_with_data( - storage.as_ptr() as *const core::ffi::c_void, - (storage.len() * mem::size_of::<f32>()) as NSUInteger, - option, - ), - CpuStorage::F64(storage) => self.device.new_buffer_with_data( - storage.as_ptr() as *const core::ffi::c_void, - (storage.len() * mem::size_of::<f64>()) as NSUInteger, - option, - ), - }; - Ok(Self::Storage { - buffer, - device: self.clone(), - dtype: storage.dtype(), - }) + CpuStorage::U8(storage) => self.new_buffer_with_data(storage), + CpuStorage::U32(storage) => self.new_buffer_with_data(storage), + CpuStorage::I64(storage) => self.new_buffer_with_data(storage), + CpuStorage::BF16(storage) => self.new_buffer_with_data(storage), + CpuStorage::F16(storage) => self.new_buffer_with_data(storage), + CpuStorage::F32(storage) => self.new_buffer_with_data(storage), + CpuStorage::F64(storage) => self.new_buffer_with_data(storage), + }?; + Ok(Self::Storage::new(buffer, self.clone(), storage.dtype())) } fn rand_uniform( @@ -824,3 +1303,10 @@ impl BackendDevice for MetalDevice { self.storage_from_cpu_storage(&cpu_storage) } } + +fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> { + let ptr = buffer.contents() as *const T; + assert!(!ptr.is_null()); + let slice = unsafe { std::slice::from_raw_parts(ptr, n) }; + slice.to_vec() +} diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index e6e7b415..f15f8c1c 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -1877,10 +1877,7 
@@ impl Tensor { Storage::Metal(metal.storage_from_cpu_storage(storage)?) } (Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?), - (Storage::Metal(storage), Device::Cpu) => { - println!("{storage:?} - {:?}", storage.to_cpu_storage()?); - Storage::Cpu(storage.to_cpu_storage()?) - } + (Storage::Metal(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?), (Storage::Cuda(storage), Device::Cuda(cuda)) => { // TODO: Avoid passing through the cpu storage here, especially if the gpu ids // are the same. |
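For reference, the allocator introduced above buckets buffers by (size, resource options) and hands back any buffer whose Arc::strong_count has fallen to 1, meaning the compute graph dropped it and only the allocator still holds it. A minimal, Metal-free sketch of that reuse rule, with Vec<u8> standing in for metal::Buffer (names and types here are illustrative only):

use std::collections::HashMap;
use std::sync::{Arc, RwLock};

/// Buckets keyed by (size, resource options), mirroring `AllocatedBuffers`.
type Buffers = RwLock<HashMap<(usize, u8), Vec<Arc<Vec<u8>>>>>;

fn allocate(buffers: &Buffers, size: usize, option: u8) -> Arc<Vec<u8>> {
    let mut buffers = buffers.write().unwrap();
    let subbuffers = buffers.entry((size, option)).or_insert_with(Vec::new);
    // strong_count == 1 means only the allocator references the buffer,
    // so it is free to be handed out again.
    if let Some(sub) = subbuffers.iter().find(|s| Arc::strong_count(s) == 1) {
        return Arc::clone(sub);
    }
    // Otherwise allocate a fresh buffer and sweep out fully-unused entries.
    let new_buffer = Arc::new(vec![0u8; size]);
    subbuffers.push(new_buffer.clone());
    for subs in buffers.values_mut() {
        subs.retain(|s| Arc::strong_count(s) > 1);
    }
    new_buffer
}

fn main() {
    let buffers: Buffers = RwLock::new(HashMap::new());
    let a = allocate(&buffers, 1024, 0);
    let a_ptr = Arc::as_ptr(&a);
    let b = allocate(&buffers, 1024, 0); // `a` is still live, so a second buffer is created
    assert!(!Arc::ptr_eq(&a, &b));
    drop(a); // the allocator now holds the only reference to `a`'s buffer
    let c = allocate(&buffers, 1024, 0); // ...and hands that same allocation back out
    assert_eq!(a_ptr, Arc::as_ptr(&c));
    drop((b, c));
}

As the doc comments in the diff note, this CPU-side reference count says nothing about GPU usage; the scheme is only sound because encoder execution order is kept linear via the shared fence.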