Diffstat (limited to 'candle-core/src/metal_backend.rs')
-rw-r--r-- | candle-core/src/metal_backend.rs | 702 |
1 file changed, 475 insertions, 227 deletions
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index 0b72f080..12f56d50 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -4,11 +4,13 @@ use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape};
use candle_metal_kernels;
use candle_metal_kernels::Kernels;
-use core::mem;
-use half::{bf16, f16};
+use half::f16;
use metal;
-use metal::{Buffer, CommandQueue, MTLResourceOptions, NSUInteger};
-use std::sync::Arc;
+use metal::mps::matrix::{Matrix, MatrixDescriptor, MatrixMultiplication};
+use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
+use std::collections::HashMap;
+use std::path::Path;
+use std::sync::{Arc, RwLock};

/// Metal related errors
#[derive(thiserror::Error, Debug)]
@@ -36,7 +38,9 @@ impl From<String> for MetalError {
pub struct MetalDevice {
    device: metal::Device,
    command_queue: metal::CommandQueue,
+    command_buffer: Arc<RwLock<metal::CommandBuffer>>,
    kernels: Arc<candle_metal_kernels::Kernels>,
+    buffers: Arc<RwLock<HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>>>,
}

impl std::fmt::Debug for MetalDevice {
@@ -58,10 +62,48 @@ impl MetalDevice {
        self.registry_id()
    }

+    pub fn metal_device(&self) -> &metal::Device {
+        &self.device
+    }
+
    pub fn command_queue(&self) -> &CommandQueue {
        &self.command_queue
    }

+    pub fn command_buffer(&self) -> std::sync::RwLockReadGuard<CommandBuffer> {
+        self.command_buffer.try_read().unwrap()
+    }
+
+    pub fn commit(&self) {
+        let mut old = self.command_buffer.try_write().unwrap();
+        match old.status() {
+            metal::MTLCommandBufferStatus::NotEnqueued
+            | metal::MTLCommandBufferStatus::Enqueued => {
+                old.commit();
+                let command_buffer = self.command_queue.new_command_buffer().to_owned();
+                *old = command_buffer;
+            }
+            _ => {}
+        }
+    }
+
+    pub fn wait_until_completed(&self) {
+        let mut old = self.command_buffer.try_write().unwrap();
+        match old.status() {
+            metal::MTLCommandBufferStatus::NotEnqueued
+            | metal::MTLCommandBufferStatus::Enqueued => {
+                old.commit();
+                old.wait_until_completed();
+            }
+            metal::MTLCommandBufferStatus::Committed | metal::MTLCommandBufferStatus::Scheduled => {
+                old.wait_until_completed();
+            }
+            _ => {}
+        }
+        let command_buffer = self.command_queue.new_command_buffer().to_owned();
+        *old = command_buffer;
+    }
+
    pub fn kernels(&self) -> &Kernels {
        &self.kernels
    }
@@ -70,16 +112,107 @@ impl MetalDevice {
        &self.device
    }

-    pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
+    pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Arc<Buffer> {
        let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
-        self.device
-            .new_buffer(size, MTLResourceOptions::StorageModeManaged)
+        self._new_buffer(size, MTLResourceOptions::StorageModePrivate)
+    }
+
+    fn _new_buffer(&self, size: NSUInteger, option: MTLResourceOptions) -> Arc<Buffer> {
+        let mut buffers = self.buffers.try_write().unwrap();
+        let subbuffers = buffers.entry((size, option)).or_insert(vec![]);
+
+        for sub in &mut *subbuffers {
+            if Arc::strong_count(sub) == 1 {
+                return sub.clone();
+            }
+        }
+        let new_buffer = self.device.new_buffer(size as NSUInteger, option);
+        let new_buffer = Arc::new(new_buffer);
+        subbuffers.push(new_buffer.clone());
+        new_buffer
+    }
+
+    pub fn new_buffer_managed(&self, size: NSUInteger) -> Arc<Buffer> {
+        self._new_buffer(size, MTLResourceOptions::StorageModeManaged)
+    }
+
+    pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Arc<Buffer> {
+        let size = core::mem::size_of_val(data) as NSUInteger;
+        let tmp = self.device.new_buffer_with_data(
+            data.as_ptr() as *const core::ffi::c_void,
+            size,
+            metal::MTLResourceOptions::StorageModeManaged,
+        );
+        let real = self._new_buffer(size, metal::MTLResourceOptions::StorageModePrivate);
+        {
+            let command = self.command_buffer();
+            let blit = command.new_blit_command_encoder();
+            blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length());
+            blit.end_encoding();
+        }
+        // This is necessary for mmapped safetensors.
+        // Because of the unsafe slice cast we're doing,
+        // the slice might not live long enough for Metal
+        // to actually fill the GPU buffer.
+        // Putting this wait forces the GPU buffer to be filled
+        // with the actual data, allowing the CPU storage to
+        // deallocate properly.
+        self.wait_until_completed();
+        real
+    }
+
+    pub fn new_matrix(
+        &self,
+        (b, m, n): (NSUInteger, NSUInteger, NSUInteger),
+        size: NSUInteger,
+        type_id: u32,
+        dtype: DType,
+    ) -> Result<(Matrix, Arc<Buffer>)> {
+        let elem_count = (b * m * n) as usize;
+        let out_buffer = self.new_buffer(elem_count, dtype);
+
+        let result_descriptor =
+            MatrixDescriptor::init_multiple(m, n, b, n * size, m * n * size, type_id);
+        let result_matrix = Matrix::init_with_buffer_descriptor(&out_buffer, 0, &result_descriptor)
+            .ok_or_else(|| {
+                MetalError::from("Failed to create matrix multiplication kernel".to_string())
+            })?;
+        Ok((result_matrix, out_buffer))
+    }
+
+    pub fn capture<P: AsRef<Path>>(&self, path: P) -> Result<()> {
+        let capture = metal::CaptureManager::shared();
+        let descriptor = metal::CaptureDescriptor::new();
+        descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
+        descriptor.set_capture_device(&self);
+        descriptor.set_output_url(path);
+
+        capture
+            .start_capture(&descriptor)
+            .map_err(MetalError::from)?;
+        Ok(())
    }
}

#[derive(Debug, Clone)]
pub struct MetalStorage {
-    buffer: metal::Buffer,
+    buffer: Arc<metal::Buffer>,
+    matrices: Arc<
+        RwLock<
+            HashMap<
+                (
+                    NSUInteger,
+                    NSUInteger,
+                    NSUInteger,
+                    bool,
+                    NSUInteger,
+                    NSUInteger,
+                    u32,
+                ),
+                Matrix,
+            >,
+        >,
+    >,
    device: MetalDevice,
    dtype: DType,
}
@@ -108,14 +241,23 @@ impl BackendStorage for MetalStorage {
                self.dtype
            );
        }
+
+        let buffer = self.device.new_buffer_managed(self.buffer.length());
+        let command_buffer = self.device.command_buffer();
+        let blit = command_buffer.new_blit_command_encoder();
+        blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
+        blit.end_encoding();
+        drop(command_buffer);
+        self.device.wait_until_completed();
+
        match self.dtype {
-            DType::U8 => Ok(CpuStorage::U8(self.buffer.read_to_vec(length / size))),
-            DType::U32 => Ok(CpuStorage::U32(self.buffer.read_to_vec(length / size))),
-            DType::I64 => Ok(CpuStorage::I64(self.buffer.read_to_vec(length / size))),
-            DType::F16 => Ok(CpuStorage::F16(self.buffer.read_to_vec(length / size))),
-            DType::BF16 => Ok(CpuStorage::BF16(self.buffer.read_to_vec(length / size))),
-            DType::F32 => Ok(CpuStorage::F32(self.buffer.read_to_vec(length / size))),
-            DType::F64 => Ok(CpuStorage::F64(self.buffer.read_to_vec(length / size))),
+            DType::U8 => Ok(CpuStorage::U8(buffer.read_to_vec(length / size))),
+            DType::U32 => Ok(CpuStorage::U32(buffer.read_to_vec(length / size))),
+            DType::I64 => Ok(CpuStorage::I64(buffer.read_to_vec(length / size))),
+            DType::F16 => Ok(CpuStorage::F16(buffer.read_to_vec(length / size))),
+            DType::BF16 => Ok(CpuStorage::BF16(buffer.read_to_vec(length / size))),
+            DType::F32 => Ok(CpuStorage::F32(buffer.read_to_vec(length / size))),
+            DType::F64 => Ok(CpuStorage::F64(buffer.read_to_vec(length / size))),
        }
    }

@@ -126,30 +268,48 @@ impl BackendStorage for MetalStorage {
        let el = shape.elem_count();
        let dtype = self.dtype;

-        if layout.is_contiguous() || layout.start_offset() != 0 || dtype != DType::F32 {
-            crate::bail!("Not contiguous, non-f32 affine is not implemented yet.");
+        let buffer = device.new_buffer(el, self.dtype);
+        let command_buffer = self.device.command_buffer();
+        if layout.is_contiguous() && layout.start_offset() == 0 {
+            let name = match self.dtype {
+                DType::F32 => "affine_float",
+                DType::F16 => "affine_half",
+                dtype => crate::bail!("Affine {dtype:?}"),
+            };
+            candle_metal_kernels::call_affine(
+                &device.device,
+                &command_buffer,
+                &device.kernels,
+                name,
+                el,
+                &self.buffer,
+                &buffer,
+                mul as f32,
+                add as f32,
+            )
+            .map_err(MetalError::from)?;
+        } else {
+            let name = match self.dtype {
+                DType::F32 => "affine_float_strided",
+                DType::F16 => "affine_half_strided",
+                dtype => crate::bail!("Affine {dtype:?}"),
+            };
+            candle_metal_kernels::call_affine_strided(
+                &device.device,
+                &command_buffer,
+                &device.kernels,
+                name,
+                layout.dims(),
+                &self.buffer,
+                layout.stride(),
+                layout.start_offset() * dtype.size_in_bytes(),
+                &buffer,
+                mul as f32,
+                add as f32,
+            )
+            .map_err(MetalError::from)?;
        }
-
-        let mut buffer = device.new_buffer(el, self.dtype);
-        let command_buffer = self.device.command_queue.new_command_buffer();
-        candle_metal_kernels::call_affine(
-            &device.device,
-            &command_buffer,
-            &device.kernels,
-            el,
-            &self.buffer,
-            &mut buffer,
-            mul as f32,
-            add as f32,
-        )
-        .map_err(MetalError::from)?;
-        command_buffer.commit();
-        command_buffer.wait_until_completed();
-        return Ok(Self {
-            buffer,
-            device: device.clone(),
-            dtype,
-        });
+        Ok(Self::new(buffer, device.clone(), dtype))
    }

    fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
@@ -163,11 +323,11 @@ impl BackendStorage for MetalStorage {
    fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
        if !(sum_dims.len() == 1
            && sum_dims[0] == layout.shape().rank() - 1
-            && layout.is_contiguous()
-            && layout.start_offset() == 0)
+            && layout.stride()[sum_dims[0]] == 1)
        {
-            crate::bail!("Non contiguous reduce op not supported yet");
+            crate::bail!("Non last dim reduce op not supported yet");
        }
+
        let device = self.device.clone();
        let src_stride = layout.stride();
        let src_dims = layout.shape().dims();
@@ -202,8 +362,11 @@ impl BackendStorage for MetalStorage {
            Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
        }
        let dtype = if return_index { DType::U32 } else { self.dtype };
-        let mut buffer = device.new_buffer(dst_el, dtype);
-        let command_buffer = self.device.command_queue.new_command_buffer();
+        if dtype == DType::U32 {
+            crate::bail!("Implement return index reduce op");
+        }
+        let buffer = device.new_buffer(dst_el, dtype);
+        let command_buffer = self.device.command_buffer();
        candle_metal_kernels::call_reduce_contiguous(
            &device.device,
            &command_buffer,
@@ -212,17 +375,12 @@ impl BackendStorage for MetalStorage {
            src_el,
            dst_el,
            &self.buffer,
-            &mut buffer,
+            layout.start_offset() * self.dtype.size_in_bytes(),
+            &buffer,
        )
        .map_err(MetalError::from)?;
-        command_buffer.commit();
-        command_buffer.wait_until_completed();
-
-        Ok(Self {
-            buffer,
-            device,
-            dtype,
-        })
+        Ok(Self::new(buffer, device, dtype))
    }

    fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
@@ -233,11 +391,15 @@ impl BackendStorage for MetalStorage {
        let device = self.device();
        let shape = layout.shape();
        let el_count = shape.elem_count();
-        let mut buffer = device.new_buffer(el_count, dtype);
-        let command_buffer = device.command_queue.new_command_buffer();
+        let buffer = device.new_buffer(el_count, dtype);
+        let command_buffer = device.command_buffer();
        if layout.is_contiguous() {
            let kernel_name = match (self.dtype, dtype) {
                (DType::U32, DType::F32) => "cast_u32_f32",
+                (DType::U32, DType::U8) => "cast_u32_u8",
+                (DType::U8, DType::U32) => "cast_u8_u32",
+                (DType::F32, DType::F16) => "cast_f32_f16",
+                (DType::F16, DType::F32) => "cast_f16_f32",
                (left, right) => crate::bail!("to dtype {left:?} - {right:?}"),
            };
            candle_metal_kernels::call_cast_contiguous(
@@ -247,24 +409,34 @@ impl BackendStorage for MetalStorage {
                kernel_name,
                el_count,
                &self.buffer,
-                &mut buffer,
+                layout.start_offset() * self.dtype.size_in_bytes(),
+                &buffer,
            )
            .map_err(MetalError::from)?;
        } else {
-            crate::bail!(
-                "TODO Implement the kernel calling cast {:?}-{:?}",
-                self.dtype,
-                dtype
-            );
+            let kernel_name = match (self.dtype, dtype) {
+                (DType::U32, DType::F32) => "cast_u32_f32_strided",
+                (DType::U32, DType::U8) => "cast_u32_u8_strided",
+                (DType::U8, DType::U32) => "cast_u8_u32_strided",
+                (DType::F32, DType::F16) => "cast_f32_f16_strided",
+                (DType::F16, DType::F32) => "cast_f16_f32_strided",
+                (left, right) => crate::bail!("to dtype {left:?} - {right:?}"),
+            };
+            candle_metal_kernels::call_cast_strided(
+                &device.device,
+                &command_buffer,
+                &device.kernels,
+                kernel_name,
+                layout.dims(),
+                &self.buffer,
+                layout.stride(),
+                layout.start_offset() * self.dtype.size_in_bytes(),
+                &buffer,
+            )
+            .map_err(MetalError::from)?;
        }
-        command_buffer.commit();
-        command_buffer.wait_until_completed();
-
-        Ok(Self {
-            buffer,
-            device: device.clone(),
-            dtype,
-        })
+        Ok(Self::new(buffer, device.clone(), dtype))
    }

    fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
@@ -272,8 +444,8 @@ impl BackendStorage for MetalStorage {
        let dtype = self.dtype;
        let shape = layout.shape();
        let el_count = shape.elem_count();
-        let mut buffer = device.new_buffer(el_count, dtype);
-        let command_buffer = device.command_queue.new_command_buffer();
+        let buffer = device.new_buffer(el_count, dtype);
+        let command_buffer = device.command_buffer();
        if layout.is_contiguous() && layout.start_offset() == 0 {
            use candle_metal_kernels::unary::contiguous;

@@ -285,6 +457,25 @@ impl BackendStorage for MetalStorage {
                ("uneg", DType::F32) => contiguous::neg::FLOAT,
                ("uexp", DType::F32) => contiguous::exp::FLOAT,
                ("ulog", DType::F32) => contiguous::log::FLOAT,
+                ("ugelu", DType::F32) => contiguous::gelu::FLOAT,
+                ("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT,
+                ("uerf", DType::F32) => contiguous::erf::FLOAT,
+                ("uceil", DType::F32) => contiguous::ceil::FLOAT,
+                ("ufloor", DType::F32) => contiguous::floor::FLOAT,
+                ("uround", DType::F32) => contiguous::round::FLOAT,
+                ("ucos", DType::F16) => contiguous::cos::HALF,
+                ("usin", DType::F16) => contiguous::sin::HALF,
+                ("usqr", DType::F16) => contiguous::sqr::HALF,
+                ("usqrt", DType::F16) => contiguous::sqrt::HALF,
+                ("uneg", DType::F16) => contiguous::neg::HALF,
+                ("uexp", DType::F16) => contiguous::exp::HALF,
+                ("ulog", DType::F16) => contiguous::log::HALF,
+                ("ugelu", DType::F16) => contiguous::gelu::HALF,
+                ("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF,
+                ("uerf", DType::F16) => contiguous::erf::HALF,
+                ("uceil", DType::F16) => contiguous::ceil::HALF,
+                ("ufloor", DType::F16) => contiguous::floor::HALF,
+                ("uround", DType::F16) => contiguous::round::HALF,
                (name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
            };
            candle_metal_kernels::call_unary_contiguous(
@@ -294,20 +485,58 @@ impl BackendStorage for MetalStorage {
                kernel_name,
                el_count,
                &self.buffer,
-                &mut buffer,
+                &buffer,
            )
            .map_err(MetalError::from)?;
        } else {
-            crate::bail!("TODO Implement the kernel calling {}", B::KERNEL);
+            use candle_metal_kernels::unary::strided;
+            let kernel_name = match (B::KERNEL, dtype) {
+                ("ucos", DType::F32) => strided::cos::FLOAT,
+                ("usin", DType::F32) => strided::sin::FLOAT,
+                ("usqr", DType::F32) => strided::sqr::FLOAT,
+                ("usqrt", DType::F32) => strided::sqrt::FLOAT,
+                ("uneg", DType::F32) => strided::neg::FLOAT,
+                ("uexp", DType::F32) => strided::exp::FLOAT,
+                ("ulog", DType::F32) => strided::log::FLOAT,
+                ("ugelu", DType::F32) => strided::gelu::FLOAT,
+                ("ugelu_erf", DType::F32) => strided::gelu_erf::FLOAT,
+                ("uerf", DType::F32) => strided::erf::FLOAT,
+                ("uceil", DType::F32) => strided::ceil::FLOAT,
+                ("ufloor", DType::F32) => strided::floor::FLOAT,
+                ("uround", DType::F32) => strided::round::FLOAT,
+                ("ucos", DType::F16) => strided::cos::HALF,
+                ("usin", DType::F16) => strided::sin::HALF,
+                ("usqr", DType::F16) => strided::sqr::HALF,
+                ("usqrt", DType::F16) => strided::sqrt::HALF,
+                ("uneg", DType::F16) => strided::neg::HALF,
+                ("uexp", DType::F16) => strided::exp::HALF,
+                ("ulog", DType::F16) => strided::log::HALF,
+                ("ugelu", DType::F16) => strided::gelu::HALF,
+                ("ugelu_erf", DType::F16) => strided::gelu_erf::HALF,
+                ("uerf", DType::F16) => strided::erf::HALF,
+                ("uceil", DType::F16) => strided::ceil::HALF,
+                ("ufloor", DType::F16) => strided::floor::HALF,
+                ("uround", DType::F16) => strided::round::HALF,
+                (name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
+            };
+            candle_metal_kernels::call_unary_strided(
+                &device.device,
+                &command_buffer,
+                &device.kernels,
+                kernel_name,
+                layout.dims(),
+                &self.buffer,
+                layout.stride(),
+                layout.start_offset() * self.dtype.size_in_bytes(),
+                &buffer,
+                0,
+            )
+            .map_err(MetalError::from)?;
        }
-        command_buffer.commit();
-        command_buffer.wait_until_completed();
-
-        Ok(Self {
-            buffer,
-            device: device.clone(),
-            dtype,
-        })
+        command_buffer.set_label("unary");
+        drop(command_buffer);
+        self.device.commit();
+        Ok(Self::new(buffer, device.clone(), dtype))
    }

    fn binary_impl<B: BinaryOpT>(
@@ -320,8 +549,8 @@ impl BackendStorage for MetalStorage {
        let dtype = self.dtype;
        let shape = lhs_l.shape();
        let el_count = shape.elem_count();
-        let mut buffer = device.new_buffer(el_count, dtype);
-        let command_buffer = device.command_queue.new_command_buffer();
+        let buffer = device.new_buffer(el_count, dtype);
+        let command_buffer = device.command_buffer();
        if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0)
            && (rhs_l.is_contiguous() && rhs_l.start_offset() == 0)
        {
@@ -336,6 +565,14 @@ impl BackendStorage for MetalStorage {
                ("bmul", DType::F32) => contiguous::mul::FLOAT,
                ("div", DType::F32) => contiguous::div::FLOAT,
                ("bdiv", DType::F32) => contiguous::div::FLOAT,
+                ("add", DType::F16) => contiguous::add::HALF,
+                ("badd", DType::F16) => contiguous::add::HALF,
+                ("sub", DType::F16) => contiguous::sub::HALF,
+                ("bsub", DType::F16) => contiguous::sub::HALF,
+                ("mul", DType::F16) => contiguous::mul::HALF,
+                ("bmul", DType::F16) => contiguous::mul::HALF,
+                ("div", DType::F16) => contiguous::div::HALF,
+                ("bdiv", DType::F16) => contiguous::div::HALF,
                (name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
            };
            candle_metal_kernels::call_binary_contiguous(
@@ -346,7 +583,7 @@ impl BackendStorage for MetalStorage {
                el_count,
                &self.buffer,
                &rhs.buffer,
-                &mut buffer,
+                &buffer,
            )
            .map_err(MetalError::from)?;
        } else {
@@ -357,6 +594,10 @@ impl BackendStorage for MetalStorage {
                ("bsub", DType::F32) => strided::sub::FLOAT,
                ("bmul", DType::F32) => strided::mul::FLOAT,
                ("bdiv", DType::F32) => strided::div::FLOAT,
+                ("badd", DType::F16) => strided::add::HALF,
+                ("bsub", DType::F16) => strided::sub::HALF,
+                ("bmul", DType::F16) => strided::mul::HALF,
+                ("bdiv", DType::F16) => strided::div::HALF,
                (name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
            };
            candle_metal_kernels::call_binary_strided(
@@ -366,23 +607,19 @@ impl BackendStorage for MetalStorage {
                kernel_name,
                lhs_l.dims(),
                &self.buffer,
-                &lhs_l.stride(),
+                lhs_l.stride(),
                lhs_l.start_offset() * self.dtype.size_in_bytes(),
                &rhs.buffer,
-                &rhs_l.stride(),
+                rhs_l.stride(),
                rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
-                &mut buffer,
+                &buffer,
            )
            .map_err(MetalError::from)?;
        }
-        command_buffer.commit();
-        command_buffer.wait_until_completed();
-
-        Ok(Self {
-            buffer,
-            device: device.clone(),
-            dtype,
-        })
+        command_buffer.set_label("binary");
+        drop(command_buffer);
+        self.device.commit();
+        Ok(Self::new(buffer, device.clone(), dtype))
    }

    fn where_cond(
@@ -398,14 +635,22 @@ impl BackendStorage for MetalStorage {
        let dims = shape.dims();
        let el = shape.elem_count();
        let dtype = t.dtype;
-        let mut buffer = self.device.new_buffer(el, dtype);
-        let command_buffer = self.device.command_queue.new_command_buffer();
+        let buffer = self.device.new_buffer(el, dtype);
+        let command_buffer = self.device.command_buffer();
+        if t.dtype() != f.dtype() {
+            crate::bail!("Invalid ternary different dtypes for values");
+        }
+        let name = match (self.dtype, t.dtype()) {
+            (DType::U8, DType::F32) => "where_u8_f32",
+            (DType::U8, DType::F16) => "where_u8_f16",
+            (left, right) => crate::bail!("Ternary {left:?} - {right:?} not implemented"),
+        };
        candle_metal_kernels::call_where_cond_strided(
            &device.device,
            &command_buffer,
            &device.kernels,
-            "where_u8_f32",
-            &dims,
+            name,
+            dims,
            &self.buffer,
            (
                layout.stride(),
@@ -415,16 +660,10 @@ impl BackendStorage for MetalStorage {
            (&t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()),
            &f.buffer,
            (&f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()),
-            &mut buffer,
+            &buffer,
        )
        .map_err(MetalError::from)?;
-        command_buffer.commit();
-        command_buffer.wait_until_completed();
-        Ok(Self {
-            buffer,
-            device,
-            dtype,
-        })
+        Ok(Self::new(buffer, device, dtype))
    }

    fn conv1d(
@@ -513,12 +752,13 @@ impl BackendStorage for MetalStorage {
        let dst_el = ids_el * left_size * right_size;
        let dtype = self.dtype;
        let device = self.device();
-        let mut buffer = device.new_buffer(dst_el, dtype);
+        let buffer = device.new_buffer(dst_el, dtype);
        let name = match (ids.dtype, self.dtype) {
            (DType::U32, DType::F32) => "is_u32_f32",
+            (DType::U32, DType::F16) => "is_u32_f16",
            (left, right) => crate::bail!("index select metal {left:?} {right:?}"),
        };
-        let command_buffer = self.device.command_queue.new_command_buffer();
+        let command_buffer = self.device.command_buffer();
        candle_metal_kernels::call_index_select(
            &device.device,
            &command_buffer,
@@ -529,16 +769,10 @@ impl BackendStorage for MetalStorage {
            dim,
            &self.buffer,
            &ids.buffer,
-            &mut buffer,
+            &buffer,
        )
        .map_err(MetalError::from)?;
-        command_buffer.commit();
-        command_buffer.wait_until_completed();
-        Ok(Self {
-            buffer,
-            device: device.clone(),
-            dtype,
-        })
+        Ok(Self::new(buffer, device.clone(), dtype))
    }

    fn index_add(
@@ -561,11 +795,18 @@ impl BackendStorage for MetalStorage {
        rhs_l: &Layout,
    ) -> Result<Self> {
        // Create descriptors
-        use metal::mps::matrix::*;
-        let type_id = metal::mps::MPS_FLOATBIT_ENCODING | 32;
-        let size = core::mem::size_of::<f32>() as NSUInteger;
-        let elem_count = b * m * n;
+        let (type_id, size) = match self.dtype {
+            DType::F32 => (
+                metal::mps::MPS_FLOATBIT_ENCODING | 32,
+                core::mem::size_of::<f32>() as NSUInteger,
+            ),
+            DType::F16 => (
+                metal::mps::MPS_FLOATBIT_ENCODING | 16,
+                core::mem::size_of::<f16>() as NSUInteger,
+            ),
+            dtype => todo!("Dtype for matmul {dtype:?} is not supported"),
+        };

        let lhs_stride = lhs_l.stride();
        let rhs_stride = rhs_l.stride();
@@ -596,39 +837,30 @@ impl BackendStorage for MetalStorage {
                mnk: (m, n, k),
            })?
        };
-
        let b = b as NSUInteger;
        let m = m as NSUInteger;
        let n = n as NSUInteger;
        let k = k as NSUInteger;

-        let left_descriptor = if transpose_left {
-            MatrixDescriptor::init_single(k, m, m * size, type_id)
-        } else {
-            MatrixDescriptor::init_single(m, k, k * size, type_id)
-        };
-        let right_descriptor = if transpose_right {
-            MatrixDescriptor::init_single(n, k, k * size, type_id)
-        } else {
-            MatrixDescriptor::init_single(k, n, n * size, type_id)
-        };
-        let result_descriptor = MatrixDescriptor::init_single(m, n, n * size, type_id);
-
-        // Create matrix objects
-        let left_matrix = Matrix::init_with_buffer_descriptor(&self.buffer, 0, &left_descriptor)
-            .ok_or_else(|| {
-                MetalError::from("Failed to create matrix multiplication kernel".to_string())
-            })?;
-        let right_matrix = Matrix::init_with_buffer_descriptor(&rhs.buffer, 0, &right_descriptor)
-            .ok_or_else(|| {
-                MetalError::from("Failed to create matrix multiplication kernel".to_string())
-            })?;
+        let left_matrix = self.matrix(
+            (b, m, k),
+            transpose_left,
+            size,
+            lhs_l.start_offset() as NSUInteger * size,
+            type_id,
+        )?;
+        let right_matrix = rhs.matrix(
+            (b, k, n),
+            transpose_right,
+            size,
+            rhs_l.start_offset() as NSUInteger * size,
+            type_id,
+        )?;
+        let (result_matrix, out_buffer) =
+            self.device
+                .new_matrix((b, m, n), size, type_id, self.dtype)?;

-        let out_buffer = self.device.new_buffer(elem_count, self.dtype);
-        let result_matrix = Matrix::init_with_buffer_descriptor(&out_buffer, 0, &result_descriptor)
-            .ok_or_else(|| {
-                MetalError::from("Failed to create matrix multiplication kernel".to_string())
-            })?;
+        let command_buffer = self.device.command_buffer();

        let alpha = 1.0f64;
        let beta = 0.0f64;
@@ -647,70 +879,112 @@ impl BackendStorage for MetalStorage {
                MetalError::from("Failed to create matrix multiplication kernel".to_string())
            })?;

-        matrix_multiplication.set_batch_size(b);
-
-        // Encode kernel to command buffer
-        let command_buffer = self.device.command_queue.new_command_buffer();
        matrix_multiplication.encode_to_command_buffer(
-            command_buffer,
+            &command_buffer,
            &left_matrix,
            &right_matrix,
            &result_matrix,
        );
-        command_buffer.commit();
-        command_buffer.wait_until_completed();
+        command_buffer.set_label("matmul");
+        drop(command_buffer);
+        self.device.commit();

-        Ok(Self {
-            buffer: out_buffer,
-            device: self.device.clone(),
-            dtype: self.dtype(),
-        })
+        Ok(Self::new(out_buffer, self.device.clone(), self.dtype()))
    }

    fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
-        let src_shape = src_l.shape();
-        let el_count = src_shape.elem_count();
-        if el_count == 0 {
-            return Ok(());
+        let command_buffer = self.device.command_buffer();
+        if src_l.is_contiguous() && self.dtype == dst.dtype() {
+            command_buffer.set_label("copy_contiguous");
+            let blit = command_buffer.new_blit_command_encoder();
+            let src_offset = (src_l.start_offset() * self.dtype.size_in_bytes()) as NSUInteger;
+            let dst_offset = (dst_offset * dst.dtype().size_in_bytes()) as NSUInteger;
+            blit.copy_from_buffer(
+                &self.buffer,
+                src_offset,
+                dst.buffer(),
+                dst_offset,
+                self.buffer.length() - src_offset,
+            );
+            blit.end_encoding();
+        } else {
+            let src_shape = src_l.shape();
+            let el_count = src_shape.elem_count();
+            if el_count == 0 {
+                return Ok(());
+            }
+            let kernel_name = match self.dtype {
+                DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT,
+                DType::F16 => candle_metal_kernels::unary::strided::copy::HALF,
+                DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT,
+                DType::U32 => candle_metal_kernels::unary::strided::copy::U32,
+                DType::U8 => candle_metal_kernels::unary::strided::copy::U8,
+                dtype => crate::bail!("copy_strided not implemented for {dtype:?}"),
+            };
+            candle_metal_kernels::call_unary_strided(
+                &self.device.device,
+                &command_buffer,
+                &self.device.kernels,
+                kernel_name,
+                src_l.dims(),
+                &self.buffer,
+                src_l.stride(),
+                src_l.start_offset() * self.dtype.size_in_bytes(),
+                &dst.buffer,
+                dst_offset * dst.dtype.size_in_bytes(),
+            )
+            .map_err(MetalError::from)?;
+            command_buffer.set_label("copy_strided");
        }
-        let command_buffer = self.device.command_queue.new_command_buffer();
-        let kernel_name = match self.dtype {
-            DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT,
-            DType::F16 => candle_metal_kernels::unary::strided::copy::HALF,
-            DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT,
-            dtype => crate::bail!("copy_strided not implemented for {dtype:?}"),
-        };
-        candle_metal_kernels::call_unary_strided(
-            &self.device.device,
-            &command_buffer,
-            &self.device.kernels,
-            kernel_name,
-            src_l.dims(),
-            &self.buffer,
-            &src_l.stride(),
-            src_l.start_offset() * self.dtype.size_in_bytes(),
-            &mut dst.buffer,
-            dst_offset,
-        )
-        .map_err(MetalError::from)?;
-        command_buffer.commit();
-        command_buffer.wait_until_completed();
+        drop(command_buffer);
+        self.device.commit();
        Ok(())
    }
}

impl MetalStorage {
-    pub fn new(buffer: Buffer, device: MetalDevice, dtype: DType) -> Self {
+    pub fn new(buffer: Arc<Buffer>, device: MetalDevice, dtype: DType) -> Self {
+        let matrices = Arc::new(RwLock::new(HashMap::new()));
        Self {
            buffer,
            device,
            dtype,
+            matrices,
        }
    }

    pub fn buffer(&self) -> &Buffer {
        &self.buffer
    }
+
+    fn matrix(
+        &self,
+        (b, m, n): (NSUInteger, NSUInteger, NSUInteger),
+        transpose: bool,
+        size: NSUInteger,
+        offset: NSUInteger,
+        type_id: u32,
+    ) -> Result<Matrix> {
+        let key = (b, m, n, transpose, size, offset, type_id);
+
+        let mut matrices = self.matrices.try_write().unwrap();
+        if let Some(matrix) = matrices.get(&key) {
+            Ok(matrix.clone())
+        } else {
+            let descriptor = if transpose {
+                MatrixDescriptor::init_multiple(n, m, b, m * size, m * n * size, type_id)
+            } else {
+                MatrixDescriptor::init_multiple(m, n, b, n * size, m * n * size, type_id)
+            };
+            let matrix = Matrix::init_with_buffer_descriptor(&self.buffer, offset, &descriptor)
+                .ok_or_else(|| {
+                    MetalError::from("Failed to create matrix multiplication kernel".to_string())
                })?;
+            matrices.insert(key, matrix.clone());
+            Ok(matrix)
+        }
+    }
}

impl BackendDevice for MetalDevice {
@@ -720,10 +994,14 @@ impl BackendDevice for MetalDevice {
        let device = metal::Device::all().swap_remove(ordinal);

        let command_queue = device.new_command_queue();
+        let command_buffer = Arc::new(RwLock::new(command_queue.new_command_buffer().to_owned()));
        let kernels = Arc::new(Kernels::new());
+        let buffers = Arc::new(RwLock::new(HashMap::new()));
        Ok(Self {
            device,
            command_queue,
+            command_buffer,
+            buffers,
            kernels,
        })
    }
@@ -743,9 +1021,8 @@ impl BackendDevice for MetalDevice {
    }

    fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
-        // TODO Is there a faster way ?
-        let cpu_storage = crate::cpu_backend::CpuDevice.zeros_impl(shape, dtype)?;
-        self.storage_from_cpu_storage(&cpu_storage)
+        let buffer = self.new_buffer(shape.elem_count(), dtype);
+        Ok(MetalStorage::new(buffer, self.clone(), dtype))
    }

    fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
@@ -755,49 +1032,20 @@ impl BackendDevice for MetalDevice {
    }

    fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
-        let option = metal::MTLResourceOptions::StorageModeManaged;
        let buffer = match storage {
-            CpuStorage::U8(storage) => self.device.new_buffer_with_data(
-                storage.as_ptr() as *const core::ffi::c_void,
-                (storage.len() * mem::size_of::<u8>()) as NSUInteger,
-                option,
-            ),
-            CpuStorage::U32(storage) => self.device.new_buffer_with_data(
-                storage.as_ptr() as *const core::ffi::c_void,
-                (storage.len() * mem::size_of::<u32>()) as NSUInteger,
-                option,
-            ),
-            CpuStorage::I64(storage) => self.device.new_buffer_with_data(
-                storage.as_ptr() as *const core::ffi::c_void,
-                (storage.len() * mem::size_of::<i64>()) as NSUInteger,
-                option,
-            ),
-            CpuStorage::BF16(storage) => self.device.new_buffer_with_data(
-                storage.as_ptr() as *const core::ffi::c_void,
-                (storage.len() * mem::size_of::<bf16>()) as NSUInteger,
-                option,
-            ),
-            CpuStorage::F16(storage) => self.device.new_buffer_with_data(
-                storage.as_ptr() as *const core::ffi::c_void,
-                (storage.len() * mem::size_of::<f16>()) as NSUInteger,
-                option,
-            ),
-            CpuStorage::F32(storage) => self.device.new_buffer_with_data(
-                storage.as_ptr() as *const core::ffi::c_void,
-                (storage.len() * mem::size_of::<f32>()) as NSUInteger,
-                option,
-            ),
-            CpuStorage::F64(storage) => self.device.new_buffer_with_data(
-                storage.as_ptr() as *const core::ffi::c_void,
-                (storage.len() * mem::size_of::<f64>()) as NSUInteger,
-                option,
-            ),
+            CpuStorage::U8(storage) => self.new_buffer_with_data(storage),
+            CpuStorage::U32(storage) => self.new_buffer_with_data(storage),
+            CpuStorage::I64(storage) => self.new_buffer_with_data(storage),
+            CpuStorage::BF16(storage) => self.new_buffer_with_data(storage),
+            CpuStorage::F16(storage) => self.new_buffer_with_data(storage),
+            CpuStorage::F32(storage) => self.new_buffer_with_data(storage),
+            CpuStorage::F64(storage) => self.new_buffer_with_data(storage),
        };
-        Ok(Self::Storage {
-            buffer,
-            device: self.clone(),
-            dtype: storage.dtype(),
-        })
+        Ok(Self::Storage::new(
+            buffer.into(),
+            self.clone(),
+            storage.dtype(),
+        ))
    }

    fn rand_uniform(
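The most consequential change above is the buffer pool behind `_new_buffer`: allocations are keyed by `(size, MTLResourceOptions)`, and a pooled `Arc<Buffer>` whose strong count has dropped back to 1 is referenced only by the pool itself, so it can be handed out again instead of allocating a fresh GPU buffer. Below is a minimal, platform-independent sketch of that reuse rule; `BufferPool` and the `Vec<u8>` standing in for `metal::Buffer` are hypothetical names, not part of the diff.

```rust
use std::collections::HashMap;
use std::sync::{Arc, RwLock};

/// Simplified model of the (size, options) -> Vec<Arc<Buffer>> pool.
/// An Arc with strong_count == 1 is owned only by the pool, so no live
/// tensor can still reference it and it is safe to recycle.
struct BufferPool {
    buffers: RwLock<HashMap<usize, Vec<Arc<Vec<u8>>>>>,
}

impl BufferPool {
    fn new() -> Self {
        Self { buffers: RwLock::new(HashMap::new()) }
    }

    fn alloc(&self, size: usize) -> Arc<Vec<u8>> {
        let mut buffers = self.buffers.write().unwrap();
        let subbuffers = buffers.entry(size).or_insert_with(Vec::new);
        for sub in subbuffers.iter() {
            if Arc::strong_count(sub) == 1 {
                return sub.clone(); // recycle: nobody else holds it
            }
        }
        let new_buffer = Arc::new(vec![0u8; size]);
        subbuffers.push(new_buffer.clone());
        new_buffer
    }
}

fn main() {
    let pool = BufferPool::new();
    let a = pool.alloc(1024);
    let first = Arc::as_ptr(&a);
    drop(a); // strong count falls back to 1 (only the pool's copy)
    let b = pool.alloc(1024);
    assert_eq!(first, Arc::as_ptr(&b)); // the same allocation was reused
    println!("buffer reused: {}", first == Arc::as_ptr(&b));
}
```

Note that the real implementation pairs this with `commit`/`wait_until_completed`, so that a recycled buffer cannot still be referenced by kernels in flight on the GPU.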
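The per-operation `new_command_buffer` / `commit` / `wait_until_completed` round-trips are replaced by one shared command buffer behind an `RwLock`, with `commit` and `wait_until_completed` acting as a small state machine over `MTLCommandBufferStatus`. A sketch of that state machine under the same transitions, with a hypothetical `FakeCommandBuffer` standing in for `metal::CommandBuffer`:

```rust
use std::sync::RwLock;

/// Stand-in for metal::MTLCommandBufferStatus.
#[derive(Clone, Copy, PartialEq, Debug)]
enum Status { NotEnqueued, Enqueued, Committed, Scheduled, Completed }

/// Stand-in for a metal::CommandBuffer: it only tracks its status.
struct FakeCommandBuffer { status: Status }

impl FakeCommandBuffer {
    fn new() -> Self { Self { status: Status::NotEnqueued } }
    fn commit(&mut self) { self.status = Status::Committed; }
    fn wait_until_completed(&mut self) { self.status = Status::Completed; }
}

struct Device { command_buffer: RwLock<FakeCommandBuffer> }

impl Device {
    /// Mirrors MetalDevice::commit: flush pending work and swap in a fresh
    /// command buffer, but only if the current one was never committed.
    fn commit(&self) {
        let mut old = self.command_buffer.write().unwrap();
        match old.status {
            Status::NotEnqueued | Status::Enqueued => {
                old.commit();
                *old = FakeCommandBuffer::new();
            }
            _ => {} // already committed or further along: nothing to do
        }
    }

    /// Mirrors MetalDevice::wait_until_completed: commit if needed, block
    /// until done, then always install a fresh command buffer.
    fn wait_until_completed(&self) {
        let mut old = self.command_buffer.write().unwrap();
        match old.status {
            Status::NotEnqueued | Status::Enqueued => {
                old.commit();
                old.wait_until_completed();
            }
            Status::Committed | Status::Scheduled => old.wait_until_completed(),
            _ => {}
        }
        *old = FakeCommandBuffer::new();
    }
}

fn main() {
    let device = Device { command_buffer: RwLock::new(FakeCommandBuffer::new()) };
    device.commit(); // commits the empty buffer and swaps in a fresh one
    device.wait_until_completed();
    assert_eq!(device.command_buffer.read().unwrap().status, Status::NotEnqueued);
}
```

This is why the operations above only `drop(command_buffer)` (releasing the read guard) and call `self.device.commit()` rather than blocking: synchronization is deferred until a CPU read actually needs the results.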
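Finally, `MetalStorage::matrix` memoizes the MPS matrix views created over a storage buffer, keyed by `(b, m, n, transpose, size, offset, type_id)`, so repeated matmuls over the same layout skip descriptor creation. A sketch of that memoization pattern, with a hypothetical `MatrixHandle` in place of `mps::matrix::Matrix`:

```rust
use std::collections::HashMap;
use std::sync::{Arc, RwLock};

/// Hypothetical stand-in for mps::matrix::Matrix: a cheap-to-clone handle
/// carrying the same geometry that MatrixDescriptor::init_multiple encodes.
#[derive(Clone, Debug)]
struct MatrixHandle {
    rows: u64,
    columns: u64,
    row_bytes: u64,
    matrix_bytes: u64,
    batch: u64,
}

/// Key mirrors the diff: (b, m, n, transpose, size, offset, type_id).
type Key = (u64, u64, u64, bool, u64, u64, u32);

struct Storage {
    matrices: Arc<RwLock<HashMap<Key, MatrixHandle>>>,
}

impl Storage {
    fn matrix(&self, key: Key) -> MatrixHandle {
        // Fast path: return a clone of the cached handle.
        if let Some(cached) = self.matrices.read().unwrap().get(&key) {
            return cached.clone();
        }
        // Slow path: build the descriptor once and memoize it. The transpose
        // case swaps rows/columns exactly as the diff does.
        let (b, m, n, transpose, size, _offset, _type_id) = key;
        let handle = if transpose {
            MatrixHandle { rows: n, columns: m, row_bytes: m * size, matrix_bytes: m * n * size, batch: b }
        } else {
            MatrixHandle { rows: m, columns: n, row_bytes: n * size, matrix_bytes: m * n * size, batch: b }
        };
        self.matrices.write().unwrap().insert(key, handle.clone());
        handle
    }
}

fn main() {
    let storage = Storage { matrices: Arc::new(RwLock::new(HashMap::new())) };
    let key = (2, 4, 8, false, 4, 0, 0x1000_0000 | 32); // last field models MPS_FLOATBIT_ENCODING | 32
    let _a = storage.matrix(key);
    let _b = storage.matrix(key); // served from the cache
    assert_eq!(storage.matrices.read().unwrap().len(), 1);
    println!("cached descriptors: {}", storage.matrices.read().unwrap().len());
}
```

The cache lives per `MetalStorage` rather than per device, which keeps the key small: the buffer identity is implicit, so only the view geometry needs to be hashed.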