diff options
-rw-r--r-- | candle-core/src/metal_backend.rs | 702 | ||||
-rw-r--r-- | candle-examples/Cargo.toml | 1 | ||||
-rw-r--r-- | candle-metal-kernels/src/affine.metal | 18 | ||||
-rw-r--r-- | candle-metal-kernels/src/cast.metal | 18 | ||||
-rw-r--r-- | candle-metal-kernels/src/indexing.metal | 9 | ||||
-rw-r--r-- | candle-metal-kernels/src/lib.rs | 303 | ||||
-rw-r--r-- | candle-metal-kernels/src/reduce.metal | 156 | ||||
-rw-r--r-- | candle-metal-kernels/src/ternary.metal | 3 | ||||
-rw-r--r-- | candle-metal-kernels/src/tests.rs | 158 | ||||
-rw-r--r-- | candle-metal-kernels/src/unary.metal | 48 | ||||
-rw-r--r-- | candle-metal-kernels/tmp/affine.rs (renamed from candle-metal-kernels/examples/affine.rs) | 1 | ||||
-rw-r--r-- | candle-metal-kernels/tmp/binary.rs (renamed from candle-metal-kernels/examples/binary.rs) | 0 | ||||
-rw-r--r-- | candle-metal-kernels/tmp/cast.rs (renamed from candle-metal-kernels/examples/cast.rs) | 0 | ||||
-rw-r--r-- | candle-metal-kernels/tmp/unary.rs (renamed from candle-metal-kernels/examples/unary.rs) | 6 | ||||
-rw-r--r-- | candle-nn/Cargo.toml | 2 | ||||
-rw-r--r-- | candle-nn/src/ops.rs | 40 |
16 files changed, 988 insertions, 477 deletions
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 0b72f080..12f56d50 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -4,11 +4,13 @@ use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; use crate::{CpuStorage, DType, Layout, Result, Shape}; use candle_metal_kernels; use candle_metal_kernels::Kernels; -use core::mem; -use half::{bf16, f16}; +use half::f16; use metal; -use metal::{Buffer, CommandQueue, MTLResourceOptions, NSUInteger}; -use std::sync::Arc; +use metal::mps::matrix::{Matrix, MatrixDescriptor, MatrixMultiplication}; +use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger}; +use std::collections::HashMap; +use std::path::Path; +use std::sync::{Arc, RwLock}; /// Metal related errors #[derive(thiserror::Error, Debug)] @@ -36,7 +38,9 @@ impl From<String> for MetalError { pub struct MetalDevice { device: metal::Device, command_queue: metal::CommandQueue, + command_buffer: Arc<RwLock<metal::CommandBuffer>>, kernels: Arc<candle_metal_kernels::Kernels>, + buffers: Arc<RwLock<HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>>>, } impl std::fmt::Debug for MetalDevice { @@ -58,10 +62,48 @@ impl MetalDevice { self.registry_id() } + pub fn metal_device(&self) -> &metal::Device { + &self.device + } + pub fn command_queue(&self) -> &CommandQueue { &self.command_queue } + pub fn command_buffer(&self) -> std::sync::RwLockReadGuard<CommandBuffer> { + self.command_buffer.try_read().unwrap() + } + + pub fn commit(&self) { + let mut old = self.command_buffer.try_write().unwrap(); + match old.status() { + metal::MTLCommandBufferStatus::NotEnqueued + | metal::MTLCommandBufferStatus::Enqueued => { + old.commit(); + let command_buffer = self.command_queue.new_command_buffer().to_owned(); + *old = command_buffer; + } + _ => {} + } + } + + pub fn wait_until_completed(&self) { + let mut old = self.command_buffer.try_write().unwrap(); + match old.status() { + metal::MTLCommandBufferStatus::NotEnqueued + | metal::MTLCommandBufferStatus::Enqueued => { + old.commit(); + old.wait_until_completed(); + } + metal::MTLCommandBufferStatus::Committed | metal::MTLCommandBufferStatus::Scheduled => { + old.wait_until_completed(); + } + _ => {} + } + let command_buffer = self.command_queue.new_command_buffer().to_owned(); + *old = command_buffer; + } + pub fn kernels(&self) -> &Kernels { &self.kernels } @@ -70,16 +112,107 @@ impl MetalDevice { &self.device } - pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer { + pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Arc<Buffer> { let size = (element_count * dtype.size_in_bytes()) as NSUInteger; - self.device - .new_buffer(size, MTLResourceOptions::StorageModeManaged) + self._new_buffer(size, MTLResourceOptions::StorageModePrivate) + } + + fn _new_buffer(&self, size: NSUInteger, option: MTLResourceOptions) -> Arc<Buffer> { + let mut buffers = self.buffers.try_write().unwrap(); + let subbuffers = buffers.entry((size, option)).or_insert(vec![]); + + for sub in &mut *subbuffers { + if Arc::strong_count(sub) == 1 { + return sub.clone(); + } + } + let new_buffer = self.device.new_buffer(size as NSUInteger, option); + let new_buffer = Arc::new(new_buffer); + subbuffers.push(new_buffer.clone()); + new_buffer + } + + pub fn new_buffer_managed(&self, size: NSUInteger) -> Arc<Buffer> { + self._new_buffer(size, MTLResourceOptions::StorageModeManaged) + } + + pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Arc<Buffer> { + let size = core::mem::size_of_val(data) as NSUInteger; + let tmp = self.device.new_buffer_with_data( + data.as_ptr() as *const core::ffi::c_void, + size, + metal::MTLResourceOptions::StorageModeManaged, + ); + let real = self._new_buffer(size, metal::MTLResourceOptions::StorageModePrivate); + { + let command = self.command_buffer(); + let blit = command.new_blit_command_encoder(); + blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length()); + blit.end_encoding(); + } + // This is necessary, for mmaped safetensors + // Because of the unsafe slice cast we're doing. + // The slice might not live long enough for metal + // To actually fill the GPU buffer. + // Putting this wait forces the GPU buffer to be filled + // with the actual data allowing the CPU storage todo + // deallocate properly. + self.wait_until_completed(); + real + } + + pub fn new_matrix( + &self, + (b, m, n): (NSUInteger, NSUInteger, NSUInteger), + size: NSUInteger, + type_id: u32, + dtype: DType, + ) -> Result<(Matrix, Arc<Buffer>)> { + let elem_count = (b * m * n) as usize; + let out_buffer = self.new_buffer(elem_count, dtype); + + let result_descriptor = + MatrixDescriptor::init_multiple(m, n, b, n * size, m * n * size, type_id); + let result_matrix = Matrix::init_with_buffer_descriptor(&out_buffer, 0, &result_descriptor) + .ok_or_else(|| { + MetalError::from("Failed to create matrix multiplication kernel".to_string()) + })?; + Ok((result_matrix, out_buffer)) + } + + pub fn capture<P: AsRef<Path>>(&self, path: P) -> Result<()> { + let capture = metal::CaptureManager::shared(); + let descriptor = metal::CaptureDescriptor::new(); + descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument); + descriptor.set_capture_device(&self); + descriptor.set_output_url(path); + + capture + .start_capture(&descriptor) + .map_err(MetalError::from)?; + Ok(()) } } #[derive(Debug, Clone)] pub struct MetalStorage { - buffer: metal::Buffer, + buffer: Arc<metal::Buffer>, + matrices: Arc< + RwLock< + HashMap< + ( + NSUInteger, + NSUInteger, + NSUInteger, + bool, + NSUInteger, + NSUInteger, + u32, + ), + Matrix, + >, + >, + >, device: MetalDevice, dtype: DType, } @@ -108,14 +241,23 @@ impl BackendStorage for MetalStorage { self.dtype ); } + + let buffer = self.device.new_buffer_managed(self.buffer.length()); + let command_buffer = self.device.command_buffer(); + let blit = command_buffer.new_blit_command_encoder(); + blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length()); + blit.end_encoding(); + drop(command_buffer); + self.device.wait_until_completed(); + match self.dtype { - DType::U8 => Ok(CpuStorage::U8(self.buffer.read_to_vec(length / size))), - DType::U32 => Ok(CpuStorage::U32(self.buffer.read_to_vec(length / size))), - DType::I64 => Ok(CpuStorage::I64(self.buffer.read_to_vec(length / size))), - DType::F16 => Ok(CpuStorage::F16(self.buffer.read_to_vec(length / size))), - DType::BF16 => Ok(CpuStorage::BF16(self.buffer.read_to_vec(length / size))), - DType::F32 => Ok(CpuStorage::F32(self.buffer.read_to_vec(length / size))), - DType::F64 => Ok(CpuStorage::F64(self.buffer.read_to_vec(length / size))), + DType::U8 => Ok(CpuStorage::U8(buffer.read_to_vec(length / size))), + DType::U32 => Ok(CpuStorage::U32(buffer.read_to_vec(length / size))), + DType::I64 => Ok(CpuStorage::I64(buffer.read_to_vec(length / size))), + DType::F16 => Ok(CpuStorage::F16(buffer.read_to_vec(length / size))), + DType::BF16 => Ok(CpuStorage::BF16(buffer.read_to_vec(length / size))), + DType::F32 => Ok(CpuStorage::F32(buffer.read_to_vec(length / size))), + DType::F64 => Ok(CpuStorage::F64(buffer.read_to_vec(length / size))), } } @@ -126,30 +268,48 @@ impl BackendStorage for MetalStorage { let el = shape.elem_count(); let dtype = self.dtype; - if layout.is_contiguous() || layout.start_offset() != 0 || dtype != DType::F32 { - crate::bail!("Not contiguous, non-f32 affine is not implemented yet."); + let buffer = device.new_buffer(el, self.dtype); + let command_buffer = self.device.command_buffer(); + if layout.is_contiguous() && layout.start_offset() == 0 { + let name = match self.dtype { + DType::F32 => "affine_float", + DType::F16 => "affine_half", + dtype => crate::bail!("Affine {dtype:?}"), + }; + candle_metal_kernels::call_affine( + &device.device, + &command_buffer, + &device.kernels, + name, + el, + &self.buffer, + &buffer, + mul as f32, + add as f32, + ) + .map_err(MetalError::from)?; + } else { + let name = match self.dtype { + DType::F32 => "affine_float_strided", + DType::F16 => "affine_half_strided", + dtype => crate::bail!("Affine {dtype:?}"), + }; + candle_metal_kernels::call_affine_strided( + &device.device, + &command_buffer, + &device.kernels, + name, + layout.dims(), + &self.buffer, + layout.stride(), + layout.start_offset() * dtype.size_in_bytes(), + &buffer, + mul as f32, + add as f32, + ) + .map_err(MetalError::from)?; } - - let mut buffer = device.new_buffer(el, self.dtype); - let command_buffer = self.device.command_queue.new_command_buffer(); - candle_metal_kernels::call_affine( - &device.device, - &command_buffer, - &device.kernels, - el, - &self.buffer, - &mut buffer, - mul as f32, - add as f32, - ) - .map_err(MetalError::from)?; - command_buffer.commit(); - command_buffer.wait_until_completed(); - return Ok(Self { - buffer, - device: device.clone(), - dtype, - }); + Ok(Self::new(buffer, device.clone(), dtype)) } fn powf(&self, _: &Layout, _: f64) -> Result<Self> { @@ -163,11 +323,11 @@ impl BackendStorage for MetalStorage { fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> { if !(sum_dims.len() == 1 && sum_dims[0] == layout.shape().rank() - 1 - && layout.is_contiguous() - && layout.start_offset() == 0) + && layout.stride()[sum_dims[0]] == 1) { - crate::bail!("Non contiguous reduce op not supported yet"); + crate::bail!("Non last dim reduce op not supported yet"); } + let device = self.device.clone(); let src_stride = layout.stride(); let src_dims = layout.shape().dims(); @@ -202,8 +362,11 @@ impl BackendStorage for MetalStorage { Err(crate::Error::EmptyTensor { op: "reduce" }.bt())? } let dtype = if return_index { DType::U32 } else { self.dtype }; - let mut buffer = device.new_buffer(dst_el, dtype); - let command_buffer = self.device.command_queue.new_command_buffer(); + if dtype == DType::U32 { + crate::bail!("Implement return index reduce op"); + } + let buffer = device.new_buffer(dst_el, dtype); + let command_buffer = self.device.command_buffer(); candle_metal_kernels::call_reduce_contiguous( &device.device, &command_buffer, @@ -212,17 +375,12 @@ impl BackendStorage for MetalStorage { src_el, dst_el, &self.buffer, - &mut buffer, + layout.start_offset() * self.dtype.size_in_bytes(), + &buffer, ) .map_err(MetalError::from)?; - command_buffer.commit(); - command_buffer.wait_until_completed(); - Ok(Self { - buffer, - device, - dtype, - }) + Ok(Self::new(buffer, device, dtype)) } fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> { @@ -233,11 +391,15 @@ impl BackendStorage for MetalStorage { let device = self.device(); let shape = layout.shape(); let el_count = shape.elem_count(); - let mut buffer = device.new_buffer(el_count, dtype); - let command_buffer = device.command_queue.new_command_buffer(); + let buffer = device.new_buffer(el_count, dtype); + let command_buffer = device.command_buffer(); if layout.is_contiguous() { let kernel_name = match (self.dtype, dtype) { (DType::U32, DType::F32) => "cast_u32_f32", + (DType::U32, DType::U8) => "cast_u32_u8", + (DType::U8, DType::U32) => "cast_u8_u32", + (DType::F32, DType::F16) => "cast_f32_f16", + (DType::F16, DType::F32) => "cast_f16_f32", (left, right) => crate::bail!("to dtype {left:?} - {right:?}"), }; candle_metal_kernels::call_cast_contiguous( @@ -247,24 +409,34 @@ impl BackendStorage for MetalStorage { kernel_name, el_count, &self.buffer, - &mut buffer, + layout.start_offset() * self.dtype.size_in_bytes(), + &buffer, ) .map_err(MetalError::from)?; } else { - crate::bail!( - "TODO Implement the kernel calling cast {:?}-{:?}", - self.dtype, - dtype - ); + let kernel_name = match (self.dtype, dtype) { + (DType::U32, DType::F32) => "cast_u32_f32_strided", + (DType::U32, DType::U8) => "cast_u32_u8_strided", + (DType::U8, DType::U32) => "cast_u8_u32_strided", + (DType::F32, DType::F16) => "cast_f32_f16_strided", + (DType::F16, DType::F32) => "cast_f16_f32_strided", + (left, right) => crate::bail!("to dtype {left:?} - {right:?}"), + }; + candle_metal_kernels::call_cast_strided( + &device.device, + &command_buffer, + &device.kernels, + kernel_name, + layout.dims(), + &self.buffer, + layout.stride(), + layout.start_offset() * self.dtype.size_in_bytes(), + &buffer, + ) + .map_err(MetalError::from)?; } - command_buffer.commit(); - command_buffer.wait_until_completed(); - Ok(Self { - buffer, - device: device.clone(), - dtype, - }) + Ok(Self::new(buffer, device.clone(), dtype)) } fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> { @@ -272,8 +444,8 @@ impl BackendStorage for MetalStorage { let dtype = self.dtype; let shape = layout.shape(); let el_count = shape.elem_count(); - let mut buffer = device.new_buffer(el_count, dtype); - let command_buffer = device.command_queue.new_command_buffer(); + let buffer = device.new_buffer(el_count, dtype); + let command_buffer = device.command_buffer(); if layout.is_contiguous() && layout.start_offset() == 0 { use candle_metal_kernels::unary::contiguous; @@ -285,6 +457,25 @@ impl BackendStorage for MetalStorage { ("uneg", DType::F32) => contiguous::neg::FLOAT, ("uexp", DType::F32) => contiguous::exp::FLOAT, ("ulog", DType::F32) => contiguous::log::FLOAT, + ("ugelu", DType::F32) => contiguous::gelu::FLOAT, + ("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT, + ("uerf", DType::F32) => contiguous::erf::FLOAT, + ("uceil", DType::F32) => contiguous::ceil::FLOAT, + ("ufloor", DType::F32) => contiguous::floor::FLOAT, + ("uround", DType::F32) => contiguous::round::FLOAT, + ("ucos", DType::F16) => contiguous::cos::HALF, + ("usin", DType::F16) => contiguous::sin::HALF, + ("usqr", DType::F16) => contiguous::sqr::HALF, + ("usqrt", DType::F16) => contiguous::sqrt::HALF, + ("uneg", DType::F16) => contiguous::neg::HALF, + ("uexp", DType::F16) => contiguous::exp::HALF, + ("ulog", DType::F16) => contiguous::log::HALF, + ("ugelu", DType::F16) => contiguous::gelu::HALF, + ("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF, + ("uerf", DType::F16) => contiguous::erf::HALF, + ("uceil", DType::F16) => contiguous::ceil::HALF, + ("ufloor", DType::F16) => contiguous::floor::HALF, + ("uround", DType::F16) => contiguous::round::HALF, (name, dtype) => crate::bail!("Match {name} - {dtype:?}"), }; candle_metal_kernels::call_unary_contiguous( @@ -294,20 +485,58 @@ impl BackendStorage for MetalStorage { kernel_name, el_count, &self.buffer, - &mut buffer, + &buffer, ) .map_err(MetalError::from)?; } else { - crate::bail!("TODO Implement the kernel calling {}", B::KERNEL); + use candle_metal_kernels::unary::strided; + let kernel_name = match (B::KERNEL, dtype) { + ("ucos", DType::F32) => strided::cos::FLOAT, + ("usin", DType::F32) => strided::sin::FLOAT, + ("usqr", DType::F32) => strided::sqr::FLOAT, + ("usqrt", DType::F32) => strided::sqrt::FLOAT, + ("uneg", DType::F32) => strided::neg::FLOAT, + ("uexp", DType::F32) => strided::exp::FLOAT, + ("ulog", DType::F32) => strided::log::FLOAT, + ("ugelu", DType::F32) => strided::gelu::FLOAT, + ("ugelu_erf", DType::F32) => strided::gelu_erf::FLOAT, + ("uerf", DType::F32) => strided::erf::FLOAT, + ("uceil", DType::F32) => strided::ceil::FLOAT, + ("ufloor", DType::F32) => strided::floor::FLOAT, + ("uround", DType::F32) => strided::round::FLOAT, + ("ucos", DType::F16) => strided::cos::HALF, + ("usin", DType::F16) => strided::sin::HALF, + ("usqr", DType::F16) => strided::sqr::HALF, + ("usqrt", DType::F16) => strided::sqrt::HALF, + ("uneg", DType::F16) => strided::neg::HALF, + ("uexp", DType::F16) => strided::exp::HALF, + ("ulog", DType::F16) => strided::log::HALF, + ("ugelu", DType::F16) => strided::gelu::HALF, + ("ugelu_erf", DType::F16) => strided::gelu_erf::HALF, + ("uerf", DType::F16) => strided::erf::HALF, + ("uceil", DType::F16) => strided::ceil::HALF, + ("ufloor", DType::F16) => strided::floor::HALF, + ("uround", DType::F16) => strided::round::HALF, + (name, dtype) => crate::bail!("Match {name} - {dtype:?}"), + }; + candle_metal_kernels::call_unary_strided( + &device.device, + &command_buffer, + &device.kernels, + kernel_name, + layout.dims(), + &self.buffer, + layout.stride(), + layout.start_offset() * self.dtype.size_in_bytes(), + &buffer, + 0, + ) + .map_err(MetalError::from)?; } - command_buffer.commit(); - command_buffer.wait_until_completed(); - - Ok(Self { - buffer, - device: device.clone(), - dtype, - }) + command_buffer.set_label("unary"); + drop(command_buffer); + self.device.commit(); + Ok(Self::new(buffer, device.clone(), dtype)) } fn binary_impl<B: BinaryOpT>( @@ -320,8 +549,8 @@ impl BackendStorage for MetalStorage { let dtype = self.dtype; let shape = lhs_l.shape(); let el_count = shape.elem_count(); - let mut buffer = device.new_buffer(el_count, dtype); - let command_buffer = device.command_queue.new_command_buffer(); + let buffer = device.new_buffer(el_count, dtype); + let command_buffer = device.command_buffer(); if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0) && (rhs_l.is_contiguous() && rhs_l.start_offset() == 0) { @@ -336,6 +565,14 @@ impl BackendStorage for MetalStorage { ("bmul", DType::F32) => contiguous::mul::FLOAT, ("div", DType::F32) => contiguous::div::FLOAT, ("bdiv", DType::F32) => contiguous::div::FLOAT, + ("add", DType::F16) => contiguous::add::HALF, + ("badd", DType::F16) => contiguous::add::HALF, + ("sub", DType::F16) => contiguous::sub::HALF, + ("bsub", DType::F16) => contiguous::sub::HALF, + ("mul", DType::F16) => contiguous::mul::HALF, + ("bmul", DType::F16) => contiguous::mul::HALF, + ("div", DType::F16) => contiguous::div::HALF, + ("bdiv", DType::F16) => contiguous::div::HALF, (name, dtype) => crate::bail!("Match {name} - {dtype:?}"), }; candle_metal_kernels::call_binary_contiguous( @@ -346,7 +583,7 @@ impl BackendStorage for MetalStorage { el_count, &self.buffer, &rhs.buffer, - &mut buffer, + &buffer, ) .map_err(MetalError::from)?; } else { @@ -357,6 +594,10 @@ impl BackendStorage for MetalStorage { ("bsub", DType::F32) => strided::sub::FLOAT, ("bmul", DType::F32) => strided::mul::FLOAT, ("bdiv", DType::F32) => strided::div::FLOAT, + ("badd", DType::F16) => strided::add::HALF, + ("bsub", DType::F16) => strided::sub::HALF, + ("bmul", DType::F16) => strided::mul::HALF, + ("bdiv", DType::F16) => strided::div::HALF, (name, dtype) => crate::bail!("Match {name} - {dtype:?}"), }; candle_metal_kernels::call_binary_strided( @@ -366,23 +607,19 @@ impl BackendStorage for MetalStorage { kernel_name, lhs_l.dims(), &self.buffer, - &lhs_l.stride(), + lhs_l.stride(), lhs_l.start_offset() * self.dtype.size_in_bytes(), &rhs.buffer, - &rhs_l.stride(), + rhs_l.stride(), rhs_l.start_offset() * rhs.dtype.size_in_bytes(), - &mut buffer, + &buffer, ) .map_err(MetalError::from)?; } - command_buffer.commit(); - command_buffer.wait_until_completed(); - - Ok(Self { - buffer, - device: device.clone(), - dtype, - }) + command_buffer.set_label("binary"); + drop(command_buffer); + self.device.commit(); + Ok(Self::new(buffer, device.clone(), dtype)) } fn where_cond( @@ -398,14 +635,22 @@ impl BackendStorage for MetalStorage { let dims = shape.dims(); let el = shape.elem_count(); let dtype = t.dtype; - let mut buffer = self.device.new_buffer(el, dtype); - let command_buffer = self.device.command_queue.new_command_buffer(); + let buffer = self.device.new_buffer(el, dtype); + let command_buffer = self.device.command_buffer(); + if t.dtype() != f.dtype() { + crate::bail!("Invalid ternary different dtypes for values"); + } + let name = match (self.dtype, t.dtype()) { + (DType::U8, DType::F32) => "where_u8_f32", + (DType::U8, DType::F16) => "where_u8_f16", + (left, right) => crate::bail!("Ternary {left:?} - {right:?} not implemented"), + }; candle_metal_kernels::call_where_cond_strided( &device.device, &command_buffer, &device.kernels, - "where_u8_f32", - &dims, + name, + dims, &self.buffer, ( layout.stride(), @@ -415,16 +660,10 @@ impl BackendStorage for MetalStorage { (&t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()), &f.buffer, (&f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()), - &mut buffer, + &buffer, ) .map_err(MetalError::from)?; - command_buffer.commit(); - command_buffer.wait_until_completed(); - Ok(Self { - buffer, - device, - dtype, - }) + Ok(Self::new(buffer, device, dtype)) } fn conv1d( @@ -513,12 +752,13 @@ impl BackendStorage for MetalStorage { let dst_el = ids_el * left_size * right_size; let dtype = self.dtype; let device = self.device(); - let mut buffer = device.new_buffer(dst_el, dtype); + let buffer = device.new_buffer(dst_el, dtype); let name = match (ids.dtype, self.dtype) { (DType::U32, DType::F32) => "is_u32_f32", + (DType::U32, DType::F16) => "is_u32_f16", (left, right) => crate::bail!("index select metal {left:?} {right:?}"), }; - let command_buffer = self.device.command_queue.new_command_buffer(); + let command_buffer = self.device.command_buffer(); candle_metal_kernels::call_index_select( &device.device, &command_buffer, @@ -529,16 +769,10 @@ impl BackendStorage for MetalStorage { dim, &self.buffer, &ids.buffer, - &mut buffer, + &buffer, ) .map_err(MetalError::from)?; - command_buffer.commit(); - command_buffer.wait_until_completed(); - Ok(Self { - buffer, - device: device.clone(), - dtype, - }) + Ok(Self::new(buffer, device.clone(), dtype)) } fn index_add( @@ -561,11 +795,18 @@ impl BackendStorage for MetalStorage { rhs_l: &Layout, ) -> Result<Self> { // Create descriptors - use metal::mps::matrix::*; - let type_id = metal::mps::MPS_FLOATBIT_ENCODING | 32; - let size = core::mem::size_of::<f32>() as NSUInteger; - let elem_count = b * m * n; + let (type_id, size) = match self.dtype { + DType::F32 => ( + metal::mps::MPS_FLOATBIT_ENCODING | 32, + core::mem::size_of::<f32>() as NSUInteger, + ), + DType::F16 => ( + metal::mps::MPS_FLOATBIT_ENCODING | 16, + core::mem::size_of::<f16>() as NSUInteger, + ), + dtype => todo!("Dtype for matmul {dtype:?} is not supported"), + }; let lhs_stride = lhs_l.stride(); let rhs_stride = rhs_l.stride(); @@ -596,39 +837,30 @@ impl BackendStorage for MetalStorage { mnk: (m, n, k), })? }; - let b = b as NSUInteger; let m = m as NSUInteger; let n = n as NSUInteger; let k = k as NSUInteger; - let left_descriptor = if transpose_left { - MatrixDescriptor::init_single(k, m, m * size, type_id) - } else { - MatrixDescriptor::init_single(m, k, k * size, type_id) - }; - let right_descriptor = if transpose_right { - MatrixDescriptor::init_single(n, k, k * size, type_id) - } else { - MatrixDescriptor::init_single(k, n, n * size, type_id) - }; - let result_descriptor = MatrixDescriptor::init_single(m, n, n * size, type_id); - - // Create matrix objects - let left_matrix = Matrix::init_with_buffer_descriptor(&self.buffer, 0, &left_descriptor) - .ok_or_else(|| { - MetalError::from("Failed to create matrix multiplication kernel".to_string()) - })?; - let right_matrix = Matrix::init_with_buffer_descriptor(&rhs.buffer, 0, &right_descriptor) - .ok_or_else(|| { - MetalError::from("Failed to create matrix multiplication kernel".to_string()) - })?; + let left_matrix = self.matrix( + (b, m, k), + transpose_left, + size, + lhs_l.start_offset() as NSUInteger * size, + type_id, + )?; + let right_matrix = rhs.matrix( + (b, k, n), + transpose_right, + size, + rhs_l.start_offset() as NSUInteger * size, + type_id, + )?; + let (result_matrix, out_buffer) = + self.device + .new_matrix((b, m, n), size, type_id, self.dtype)?; - let out_buffer = self.device.new_buffer(elem_count, self.dtype); - let result_matrix = Matrix::init_with_buffer_descriptor(&out_buffer, 0, &result_descriptor) - .ok_or_else(|| { - MetalError::from("Failed to create matrix multiplication kernel".to_string()) - })?; + let command_buffer = self.device.command_buffer(); let alpha = 1.0f64; let beta = 0.0f64; @@ -647,70 +879,112 @@ impl BackendStorage for MetalStorage { MetalError::from("Failed to create matrix multiplication kernel".to_string()) })?; - matrix_multiplication.set_batch_size(b); - // Encode kernel to command buffer - let command_buffer = self.device.command_queue.new_command_buffer(); matrix_multiplication.encode_to_command_buffer( - command_buffer, + &command_buffer, &left_matrix, &right_matrix, &result_matrix, ); - command_buffer.commit(); - command_buffer.wait_until_completed(); + command_buffer.set_label("matmul"); + drop(command_buffer); + self.device.commit(); - Ok(Self { - buffer: out_buffer, - device: self.device.clone(), - dtype: self.dtype(), - }) + Ok(Self::new(out_buffer, self.device.clone(), self.dtype())) } fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> { - let src_shape = src_l.shape(); - let el_count = src_shape.elem_count(); - if el_count == 0 { - return Ok(()); + let command_buffer = self.device.command_buffer(); + if src_l.is_contiguous() && self.dtype == dst.dtype() { + command_buffer.set_label("copy_contiguous"); + let blit = command_buffer.new_blit_command_encoder(); + let src_offset = (src_l.start_offset() * self.dtype.size_in_bytes()) as NSUInteger; + let dst_offset = (dst_offset * dst.dtype().size_in_bytes()) as NSUInteger; + blit.copy_from_buffer( + &self.buffer, + src_offset, + dst.buffer(), + dst_offset, + self.buffer.length() - src_offset, + ); + blit.end_encoding(); + } else { + let src_shape = src_l.shape(); + let el_count = src_shape.elem_count(); + if el_count == 0 { + return Ok(()); + } + let kernel_name = match self.dtype { + DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT, + DType::F16 => candle_metal_kernels::unary::strided::copy::HALF, + DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT, + DType::U32 => candle_metal_kernels::unary::strided::copy::U32, + DType::U8 => candle_metal_kernels::unary::strided::copy::U8, + dtype => crate::bail!("copy_strided not implemented for {dtype:?}"), + }; + candle_metal_kernels::call_unary_strided( + &self.device.device, + &command_buffer, + &self.device.kernels, + kernel_name, + src_l.dims(), + &self.buffer, + src_l.stride(), + src_l.start_offset() * self.dtype.size_in_bytes(), + &dst.buffer, + dst_offset * dst.dtype.size_in_bytes(), + ) + .map_err(MetalError::from)?; + command_buffer.set_label("copy_strided"); } - let command_buffer = self.device.command_queue.new_command_buffer(); - let kernel_name = match self.dtype { - DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT, - DType::F16 => candle_metal_kernels::unary::strided::copy::HALF, - DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT, - dtype => crate::bail!("copy_strided not implemented for {dtype:?}"), - }; - candle_metal_kernels::call_unary_strided( - &self.device.device, - &command_buffer, - &self.device.kernels, - kernel_name, - src_l.dims(), - &self.buffer, - &src_l.stride(), - src_l.start_offset() * self.dtype.size_in_bytes(), - &mut dst.buffer, - dst_offset, - ) - .map_err(MetalError::from)?; - command_buffer.commit(); - command_buffer.wait_until_completed(); + drop(command_buffer); + self.device.commit(); Ok(()) } } impl MetalStorage { - pub fn new(buffer: Buffer, device: MetalDevice, dtype: DType) -> Self { + pub fn new(buffer: Arc<Buffer>, device: MetalDevice, dtype: DType) -> Self { + let matrices = Arc::new(RwLock::new(HashMap::new())); Self { buffer, device, dtype, + matrices, } } pub fn buffer(&self) -> &Buffer { &self.buffer } + + fn matrix( + &self, + (b, m, n): (NSUInteger, NSUInteger, NSUInteger), + transpose: bool, + size: NSUInteger, + offset: NSUInteger, + type_id: u32, + ) -> Result<Matrix> { + let key = (b, m, n, transpose, size, offset, type_id); + + let mut matrices = self.matrices.try_write().unwrap(); + if let Some(matrix) = matrices.get(&key) { + Ok(matrix.clone()) + } else { + let descriptor = if transpose { + MatrixDescriptor::init_multiple(n, m, b, m * size, m * n * size, type_id) + } else { + MatrixDescriptor::init_multiple(m, n, b, n * size, m * n * size, type_id) + }; + let matrix = Matrix::init_with_buffer_descriptor(&self.buffer, offset, &descriptor) + .ok_or_else(|| { + MetalError::from("Failed to create matrix multiplication kernel".to_string()) + })?; + matrices.insert(key, matrix.clone()); + Ok(matrix) + } + } } impl BackendDevice for MetalDevice { @@ -720,10 +994,14 @@ impl BackendDevice for MetalDevice { let device = metal::Device::all().swap_remove(ordinal); let command_queue = device.new_command_queue(); + let command_buffer = Arc::new(RwLock::new(command_queue.new_command_buffer().to_owned())); let kernels = Arc::new(Kernels::new()); + let buffers = Arc::new(RwLock::new(HashMap::new())); Ok(Self { device, command_queue, + command_buffer, + buffers, kernels, }) } @@ -743,9 +1021,8 @@ impl BackendDevice for MetalDevice { } fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> { - // TODO Is there a faster way ? - let cpu_storage = crate::cpu_backend::CpuDevice.zeros_impl(shape, dtype)?; - self.storage_from_cpu_storage(&cpu_storage) + let buffer = self.new_buffer(shape.elem_count(), dtype); + Ok(MetalStorage::new(buffer, self.clone(), dtype)) } fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> { @@ -755,49 +1032,20 @@ impl BackendDevice for MetalDevice { } fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> { - let option = metal::MTLResourceOptions::StorageModeManaged; let buffer = match storage { - CpuStorage::U8(storage) => self.device.new_buffer_with_data( - storage.as_ptr() as *const core::ffi::c_void, - (storage.len() * mem::size_of::<u8>()) as NSUInteger, - option, - ), - CpuStorage::U32(storage) => self.device.new_buffer_with_data( - storage.as_ptr() as *const core::ffi::c_void, - (storage.len() * mem::size_of::<u32>()) as NSUInteger, - option, - ), - CpuStorage::I64(storage) => self.device.new_buffer_with_data( - storage.as_ptr() as *const core::ffi::c_void, - (storage.len() * mem::size_of::<i64>()) as NSUInteger, - option, - ), - CpuStorage::BF16(storage) => self.device.new_buffer_with_data( - storage.as_ptr() as *const core::ffi::c_void, - (storage.len() * mem::size_of::<bf16>()) as NSUInteger, - option, - ), - CpuStorage::F16(storage) => self.device.new_buffer_with_data( - storage.as_ptr() as *const core::ffi::c_void, - (storage.len() * mem::size_of::<f16>()) as NSUInteger, - option, - ), - CpuStorage::F32(storage) => self.device.new_buffer_with_data( - storage.as_ptr() as *const core::ffi::c_void, - (storage.len() * mem::size_of::<f32>()) as NSUInteger, - option, - ), - CpuStorage::F64(storage) => self.device.new_buffer_with_data( - storage.as_ptr() as *const core::ffi::c_void, - (storage.len() * mem::size_of::<f64>()) as NSUInteger, - option, - ), + CpuStorage::U8(storage) => self.new_buffer_with_data(storage), + CpuStorage::U32(storage) => self.new_buffer_with_data(storage), + CpuStorage::I64(storage) => self.new_buffer_with_data(storage), + CpuStorage::BF16(storage) => self.new_buffer_with_data(storage), + CpuStorage::F16(storage) => self.new_buffer_with_data(storage), + CpuStorage::F32(storage) => self.new_buffer_with_data(storage), + CpuStorage::F64(storage) => self.new_buffer_with_data(storage), }; - Ok(Self::Storage { - buffer, - device: self.clone(), - dtype: storage.dtype(), - }) + Ok(Self::Storage::new( + buffer.into(), + self.clone(), + storage.dtype(), + )) } fn rand_uniform( diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml index 38d26ead..adfa529e 100644 --- a/candle-examples/Cargo.toml +++ b/candle-examples/Cargo.toml @@ -57,6 +57,7 @@ flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"] mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"] nccl = ["cuda", "cudarc/nccl", "dep:half"] onnx = ["candle-onnx"] +metal = ["candle/metal", "candle-nn/metal"] [[example]] name = "llama_multiprocess" diff --git a/candle-metal-kernels/src/affine.metal b/candle-metal-kernels/src/affine.metal index e5f0a841..a08bfbc0 100644 --- a/candle-metal-kernels/src/affine.metal +++ b/candle-metal-kernels/src/affine.metal @@ -33,6 +33,24 @@ kernel void FN_NAME( \ const TYPENAME a = TYPENAME(add); \ output[id] = input[id] * m + a; \ } \ +kernel void FN_NAME##_strided( \ + constant size_t &dim, \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *strides, \ + constant float &mul, \ + constant float &add, \ + device const TYPENAME *input, \ + device TYPENAME *output, \ + uint id [[ thread_position_in_grid ]] \ +) { \ + if (id >= dim) { \ + return; \ + } \ + const TYPENAME m = TYPENAME(mul); \ + const TYPENAME a = TYPENAME(add); \ + output[id] = input[get_strided_index(id, num_dims, dims, strides)] * m + a; \ +} \ AFFINE(affine_float, float) AFFINE(affine_half, half) diff --git a/candle-metal-kernels/src/cast.metal b/candle-metal-kernels/src/cast.metal index d1788253..4398e9d4 100644 --- a/candle-metal-kernels/src/cast.metal +++ b/candle-metal-kernels/src/cast.metal @@ -23,12 +23,12 @@ kernel void FN_NAME( \ constant size_t &dim, \ device const LEFT_TYPENAME *input, \ device RIGHT_TYPENAME *output, \ - uint thread_position_in_grid [[ thread_position_in_grid ]] \ + uint tid [[ thread_position_in_grid ]] \ ) { \ - if (thread_position_in_grid >= dim) { \ + if (tid >= dim) { \ return; \ } \ - output[thread_position_in_grid] = RIGHT_TYPENAME(input[thread_position_in_grid]); \ + output[tid] = RIGHT_TYPENAME(input[tid]); \ } \ kernel void FN_NAME_STRIDED( \ constant size_t &dim, \ @@ -37,15 +37,19 @@ kernel void FN_NAME_STRIDED( \ constant size_t *strides, \ device const LEFT_TYPENAME *input, \ device RIGHT_TYPENAME *output, \ - uint i [[ thread_position_in_grid ]] \ + uint tid [[ thread_position_in_grid ]] \ ) { \ - if (i >= dim) { \ + if (tid >= dim) { \ return; \ } \ - output[i] = RIGHT_TYPENAME(input[get_strided_index(i, num_dims, dims, strides)]); \ + output[tid] = RIGHT_TYPENAME(input[get_strided_index(tid, num_dims, dims, strides)]); \ } \ -CAST(cast_u32_f32, cast_u32_f32_strided, int32_t, float) +CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float) +CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t) +CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t) +CAST(cast_f16_f32, cast_f16_f32_strided, half, float) +CAST(cast_f32_f16, cast_f32_f16_strided, float, half) #if __METAL_VERSION__ >= 310 #endif diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal index 444fa322..312b27c7 100644 --- a/candle-metal-kernels/src/indexing.metal +++ b/candle-metal-kernels/src/indexing.metal @@ -16,16 +16,16 @@ kernel void NAME( \ if (gid >= dst_size) { \ return; \ } \ - const size_t id_i = gid / right_size / left_size; \ + const size_t id_i = (gid / right_size) % ids_size; \ + const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); \ const size_t right_rank_i = gid % right_size; \ - const size_t left_rank_i = gid % left_size; \ + const size_t left_rank_i = gid / right_size / ids_size; \ /* \ // Force prevent out of bounds indexing \ // since there doesn't seem to be a good way to force crash \ // No need to check for zero we're only allowing unsized. \ */ \ - const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); \ - const size_t src_i = ((input_i * right_size) + right_rank_i) * left_size + left_rank_i; \ + const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i; \ output[gid] = input[src_i]; \ } @@ -75,6 +75,7 @@ kernel void FN_NAME( \ INDEX_OP(is_u32_f32, uint, float) +INDEX_OP(is_u32_f16, uint, half) #if __METAL_VERSION__ >= 310 diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 5a6bd41b..a0b852a4 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1,6 +1,6 @@ use metal::{ - Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineDescriptor, - ComputePipelineState, Device, Function, Library, MTLSize, + Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState, + Device, Function, Library, MTLSize, }; use std::collections::HashMap; use std::ffi::c_void; @@ -59,8 +59,8 @@ impl<T> EncoderParam for &[T] { fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { encoder.set_bytes( position, - (core::mem::size_of::<T>() * data.len()) as u64, - data.as_ptr() as *const T as *const c_void, + core::mem::size_of_val(data) as u64, + data.as_ptr() as *const c_void, ); } } @@ -111,13 +111,7 @@ macro_rules! ops{ ($($name:ident),+) => { pub mod contiguous { - #[derive(Clone, Copy)] - pub struct Kernel(pub(crate) &'static str); - impl std::fmt::Display for Kernel { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) - } - } + pub struct Kernel(pub &'static str); $( pub mod $name { use super::Kernel; @@ -126,16 +120,18 @@ macro_rules! ops{ pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat")); } )+ + pub mod copy { + use super::Kernel; + pub const FLOAT: Kernel = Kernel("copy_float"); + pub const HALF: Kernel = Kernel("copy_half"); + pub const BFLOAT: Kernel = Kernel("copy_bfloat"); + pub const U32: Kernel = Kernel("copy_u32"); + pub const U8: Kernel = Kernel("copy_u8"); + } } pub mod strided { - #[derive(Clone, Copy)] - pub struct Kernel(pub(crate) &'static str); - impl std::fmt::Display for Kernel { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) - } - } + pub struct Kernel(pub &'static str); $( pub mod $name { use super::Kernel; @@ -144,12 +140,20 @@ macro_rules! ops{ pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat_strided")); } )+ + pub mod copy { + use super::Kernel; + pub const FLOAT: Kernel = Kernel("copy_float_strided"); + pub const HALF: Kernel = Kernel("copy_half_strided"); + pub const BFLOAT: Kernel = Kernel("copy_bfloat_strided"); + pub const U32: Kernel = Kernel("copy_u32_strided"); + pub const U8: Kernel = Kernel("copy_u8_strided"); + } } }; } pub mod unary { - ops!(cos, sin, exp, sqr, sqrt, neg, copy, log); + ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf); } pub mod binary { ops!(add, sub, mul, div); @@ -161,8 +165,12 @@ pub enum MetalKernelError { LockError(String), #[error("Error while loading library: {0}")] LoadLibraryError(String), - #[error("Error while loading function: {0}")] + #[error("Error while loading function: {0:?}")] LoadFunctionError(String), + #[error("Failed to create compute function")] + FailedToCreateComputeFunction, + #[error("Failed to create pipeline")] + FailedToCreatePipeline(String), } impl<T> From<std::sync::PoisonError<T>> for MetalKernelError { @@ -173,19 +181,22 @@ impl<T> From<std::sync::PoisonError<T>> for MetalKernelError { type KernelMap<T> = HashMap<&'static str, T>; type Libraries = HashMap<Source, Library>; -type Functions = KernelMap<Function>; +type Pipelines = KernelMap<ComputePipelineState>; #[derive(Debug, Default)] pub struct Kernels { libraries: RwLock<Libraries>, - funcs: RwLock<Functions>, + pipelines: RwLock<Pipelines>, } impl Kernels { pub fn new() -> Self { let libraries = RwLock::new(Libraries::new()); - let funcs = RwLock::new(Functions::new()); - Self { libraries, funcs } + let pipelines = RwLock::new(Pipelines::new()); + Self { + libraries, + pipelines, + } } fn get_library_source(&self, source: Source) -> &'static str { @@ -218,22 +229,43 @@ impl Kernels { } } - pub fn load_function( + fn load_function( &self, device: &Device, source: Source, name: &'static str, ) -> Result<Function, MetalKernelError> { - let mut funcs = self.funcs.write()?; - if let Some(func) = funcs.get(name) { - Ok(func.clone()) + let func = self + .load_library(device, source)? + .get_function(name, None) + .map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?; + Ok(func) + // let mut funcs = self.funcs.write()?; + // if let Some(func) = funcs.get(name) { + // Ok(func.clone()) + // } else { + // funcs.insert(name, func.clone()); + // Ok(func) + // } + } + + pub fn load_pipeline( + &self, + device: &Device, + source: Source, + name: &'static str, + ) -> Result<ComputePipelineState, MetalKernelError> { + let mut pipelines = self.pipelines.write()?; + if let Some(pipeline) = pipelines.get(name) { + Ok(pipeline.clone()) } else { - let func = self - .load_library(device, source)? - .get_function(name, None) - .map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?; - funcs.insert(name, func.clone()); - Ok(func) + let func = self.load_function(device, source, name)?; + let pipeline = device + .new_compute_pipeline_state_with_function(&func) + .map_err(|e| MetalKernelError::FailedToCreatePipeline(e.to_string()))?; + pipelines.insert(name, pipeline.clone()); + + Ok(pipeline) } } } @@ -246,18 +278,9 @@ pub fn call_unary_contiguous( kernel_name: unary::contiguous::Kernel, length: usize, input: &Buffer, - output: &mut Buffer, + output: &Buffer, ) -> Result<(), MetalKernelError> { - let func = kernels.load_function(device, Source::Unary, kernel_name.0)?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); - - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); - + let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?; let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); @@ -279,18 +302,10 @@ pub fn call_unary_strided( input: &Buffer, strides: &[usize], offset: usize, - output: &mut Buffer, + output: &Buffer, output_offset: usize, ) -> Result<(), MetalKernelError> { - let func = kernels.load_function(device, Source::Unary, name.0)?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); - - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); + let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?; let num_dims: usize = shape.len(); let encoder = command_buffer.new_compute_command_encoder(); @@ -326,17 +341,9 @@ pub fn call_binary_contiguous( length: usize, left: &Buffer, right: &Buffer, - output: &mut Buffer, + output: &Buffer, ) -> Result<(), MetalKernelError> { - let func = kernels.load_function(device, Source::Binary, kernel_name.0)?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); - - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); + let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?; let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); @@ -363,17 +370,9 @@ pub fn call_binary_strided( right_input: &Buffer, right_strides: &[usize], right_offset: usize, - output: &mut Buffer, + output: &Buffer, ) -> Result<(), MetalKernelError> { - let func = kernels.load_function(device, Source::Binary, name.0)?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); - - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); + let pipeline = kernels.load_pipeline(device, Source::Binary, name.0)?; let num_dims: usize = shape.len(); let encoder = command_buffer.new_compute_command_encoder(); @@ -411,22 +410,52 @@ pub fn call_cast_contiguous( kernel_name: &'static str, length: usize, input: &Buffer, - output: &mut Buffer, + input_offset: usize, + output: &Buffer, ) -> Result<(), MetalKernelError> { - let func = kernels.load_function(device, Source::Cast, kernel_name)?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); + let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?; - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!(encoder, (length, (input, input_offset), output)); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_cast_strided( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + kernel_name: &'static str, + shape: &[usize], + input: &Buffer, + input_strides: &[usize], + input_offset: usize, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?; let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); - set_params!(encoder, (length, input, output)); + let length: usize = shape.iter().product(); + + set_params!( + encoder, + ( + length, + shape.len(), + shape, + input_strides, + (input, input_offset), + output + ) + ); let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); @@ -435,7 +464,6 @@ pub fn call_cast_contiguous( Ok(()) } -#[allow(clippy::too_many_arguments)] pub fn call_reduce_contiguous( device: &Device, command_buffer: &CommandBufferRef, @@ -444,24 +472,19 @@ pub fn call_reduce_contiguous( length: usize, out_length: usize, input: &Buffer, - output: &mut Buffer, + input_offset: usize, + output: &Buffer, ) -> Result<(), MetalKernelError> { - let func = kernels.load_function(device, Source::Reduce, kernel_name)?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); - - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); - + let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; let elements_to_sum = length / out_length; let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); - set_params!(encoder, (length, elements_to_sum, input, output)); + set_params!( + encoder, + (length, elements_to_sum, (input, input_offset), output) + ); let thread_group_count = MTLSize { width: out_length as u64, @@ -495,18 +518,9 @@ pub fn call_last_softmax( length: usize, elements_to_sum: usize, input: &Buffer, - output: &mut Buffer, + output: &Buffer, ) -> Result<(), MetalKernelError> { - let func = kernels.load_function(device, Source::Reduce, kernel_name)?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); - - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); - + let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); @@ -542,21 +556,14 @@ pub fn call_affine( device: &Device, command_buffer: &CommandBufferRef, kernels: &Kernels, + name: &'static str, size: usize, input: &Buffer, - output: &mut Buffer, + output: &Buffer, mul: f32, add: f32, ) -> Result<(), MetalKernelError> { - let func = kernels.load_function(device, Source::Affine, "affine_float")?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); - - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); + let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); @@ -570,6 +577,45 @@ pub fn call_affine( } #[allow(clippy::too_many_arguments)] +pub fn call_affine_strided( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + shape: &[usize], + input: &Buffer, + input_stride: &[usize], + input_offset: usize, + output: &Buffer, + mul: f32, + add: f32, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; + let size: usize = shape.iter().product(); + + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!( + encoder, + ( + size, + shape.len(), + shape, + input_stride, + mul, + add, + (input, input_offset), + output + ) + ); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + pub fn call_where_cond_strided( device: &Device, command_buffer: &CommandBufferRef, @@ -582,17 +628,9 @@ pub fn call_where_cond_strided( (left_stride, left_offset): (&[usize], usize), right: &Buffer, (right_stride, right_offset): (&[usize], usize), - output: &mut Buffer, + output: &Buffer, ) -> Result<(), MetalKernelError> { - let func = kernels.load_function(device, Source::Ternary, name)?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); - - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); + let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?; let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); @@ -634,17 +672,14 @@ pub fn call_index_select( dim: usize, input: &Buffer, ids: &Buffer, - output: &mut Buffer, + output: &Buffer, ) -> Result<(), MetalKernelError> { let left_size: usize = shape[..dim].iter().product(); let right_size: usize = shape[dim + 1..].iter().product(); let src_dim_size = shape[dim]; let dst_el = ids_size * left_size * right_size; - let func = kernels.load_function(device, Source::Indexing, name)?; - let pipeline = device - .new_compute_pipeline_state_with_function(&func) - .unwrap(); + let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?; let encoder = command_buffer.new_compute_command_encoder(); diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal index c6984474..867877fb 100644 --- a/candle-metal-kernels/src/reduce.metal +++ b/candle-metal-kernels/src/reduce.metal @@ -1,6 +1,8 @@ #include <metal_stdlib> using namespace metal; +#define MAX(x, y) ((x) > (y) ? (x) : (y)) + METAL_FUNC uint get_strided_index( uint idx, constant size_t &num_dims, @@ -16,18 +18,18 @@ METAL_FUNC uint get_strided_index( return strided_i; } -constant int THREADGROUP_SIZE = 256; +constant int THREADGROUP_SIZE = 1024; -# define REDUCE(FN, NAME, TYPENAME) \ +# define REDUCE(FN, NAME, T) \ kernel void NAME( \ constant size_t &src_numel, \ constant size_t &el_to_sum_per_block, \ - device const TYPENAME *src, \ - device TYPENAME *dst, \ + device const T *src, \ + device T *dst, \ uint id [[ thread_position_in_grid ]], \ uint tid [[ thread_index_in_threadgroup ]], \ uint dst_id [[ threadgroup_position_in_grid ]], \ - uint blockDim [[ threads_per_threadgroup ]] \ + uint block_dim [[ threads_per_threadgroup ]] \ ) { \ \ threadgroup float shared_memory[THREADGROUP_SIZE]; \ @@ -45,10 +47,10 @@ kernel void NAME( \ // TODO: Fast version for the contiguous case. \ // size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \ */ \ - TYPENAME x = shared_memory[tid]; \ - TYPENAME y = src[idx]; \ + T x = shared_memory[tid]; \ + T y = src[idx]; \ shared_memory[tid] = FN; \ - idx += blockDim; \ + idx += block_dim; \ } \ \ threadgroup_barrier(mem_flags::mem_none); \ @@ -56,10 +58,10 @@ kernel void NAME( \ /* \ // reduction in shared memory \ */ \ - for (uint s = blockDim / 2; s > 0; s >>= 1) { \ + for (uint s = block_dim / 2; s > 0; s >>= 1) { \ if (tid < s) { \ - TYPENAME x = shared_memory[tid]; \ - TYPENAME y = shared_memory[tid + s]; \ + T x = shared_memory[tid]; \ + T y = shared_memory[tid + s]; \ shared_memory[tid] = FN; \ } \ threadgroup_barrier(mem_flags::mem_none); \ @@ -68,72 +70,74 @@ kernel void NAME( \ dst[dst_id] = shared_memory[0]; \ } \ -kernel void softmax_float( - constant size_t &src_numel, - constant size_t &el_to_sum_per_block, - device const float *src, - device float *dst, - uint id [[ thread_position_in_grid ]], - uint tid [[ thread_index_in_threadgroup ]], - uint dst_id [[ threadgroup_position_in_grid ]], - uint blockDim [[ threads_per_threadgroup ]] -) { - - threadgroup float shared_memory[THREADGROUP_SIZE]; - - shared_memory[tid] = -INFINITY; - // Elements summed in this block range from dst_id * el_to_sum_per_block - // to (dst_id + 1) * el_to_sum_per_block. - size_t start_idx = dst_id * el_to_sum_per_block; - size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); - size_t idx = start_idx + tid; - - while (idx < stop_idx) { - // TODO: Fast version for the contiguous case. - shared_memory[tid] = max(shared_memory[tid], src[idx]); - idx += blockDim; - } - - threadgroup_barrier(mem_flags::mem_none); - - // reduction in shared memory - for (uint s = blockDim / 2; s > 0; s >>= 1) { - if (tid < s) { - shared_memory[tid] = max(shared_memory[tid], shared_memory[tid + s]); - } - threadgroup_barrier(mem_flags::mem_none); - } - - float max = shared_memory[0]; - - shared_memory[tid] = 0; - - // Restart - idx = start_idx + tid; - while (idx < stop_idx) { - // TODO: Fast version for the contiguous case. - const float val = exp(src[idx] - max); - dst[idx] = val; - shared_memory[tid] += val; - idx += blockDim; - } - // reduction in shared memory - for (uint s = blockDim / 2; s > 0; s >>= 1) { - if (tid < s) { - shared_memory[tid] += shared_memory[tid + s]; - } - threadgroup_barrier(mem_flags::mem_none); - } - - const float inv_acc = 1/shared_memory[0]; - idx = start_idx + tid; - while (idx < stop_idx) { - dst[idx] *= inv_acc; - idx += blockDim; - } -} - REDUCE(x + y, fast_sum_float, float) REDUCE(x * y, fast_mul_float, float) REDUCE(max(x, y), fast_max_float, float) + +#define SOFTMAX(NAME, T) \ +kernel void NAME( \ + constant size_t &src_numel, \ + constant size_t &el_to_sum_per_block, \ + device const T *src, \ + device T *dst, \ + \ + uint id [[ thread_position_in_grid ]], \ + uint tid [[ thread_index_in_threadgroup ]], \ + uint dst_id [[ threadgroup_position_in_grid ]], \ + uint block_dim [[ threads_per_threadgroup ]] \ +) { \ + threadgroup float shared_memory[THREADGROUP_SIZE]; \ + shared_memory[tid] = -INFINITY; \ + size_t start_idx = dst_id * el_to_sum_per_block; \ + size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \ + size_t idx = start_idx + tid; \ + \ + threadgroup_barrier(mem_flags::mem_threadgroup); \ + \ + while (idx < stop_idx) { \ + shared_memory[tid] = MAX(shared_memory[tid], src[idx]); \ + idx += block_dim; \ + } \ + \ + threadgroup_barrier(mem_flags::mem_threadgroup); \ + \ + for (uint s = block_dim / 2; s > 0; s >>= 1) { \ + if (tid < s) { \ + shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]); \ + } \ + } \ + \ + threadgroup_barrier(mem_flags::mem_threadgroup); \ + \ + float _max = shared_memory[0]; \ + \ + shared_memory[tid] = 0; \ + \ + idx = start_idx + tid; \ + while (idx < stop_idx) { \ + const T val = T(exp(src[idx] - _max)); \ + dst[idx] = val; \ + shared_memory[tid] += val; \ + idx += block_dim; \ + } \ + for (uint s = block_dim / 2; s > 0; s >>= 1) { \ + if (tid < s) { \ + shared_memory[tid] += shared_memory[tid + s]; \ + } \ + threadgroup_barrier(mem_flags::mem_threadgroup); \ + } \ + \ + const T inv_acc = T(1/shared_memory[0]); \ + idx = start_idx + tid; \ + while (idx < stop_idx) { \ + dst[idx] *= inv_acc; \ + idx += block_dim; \ + } \ +} \ + +SOFTMAX(softmax_float, float) +SOFTMAX(softmax_half, half) +#if __METAL_VERSION__ >= 310 +SOFTMAX(softmax_bfloat, bfloat) +#endif diff --git a/candle-metal-kernels/src/ternary.metal b/candle-metal-kernels/src/ternary.metal index 0945b355..1f9cb38a 100644 --- a/candle-metal-kernels/src/ternary.metal +++ b/candle-metal-kernels/src/ternary.metal @@ -32,6 +32,9 @@ kernel void FN_NAME( \ device TYPENAME *out ,\ uint i [[ thread_position_in_grid ]] \ ) { \ + if (i >= numel){ \ + return; \ + } \ uint strided_i = get_strided_index(i, num_dims, dims, strides); \ uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t); \ uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f); \ diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 2330d48d..66dc8d01 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -1,5 +1,5 @@ use super::*; -use half::f16; +use half::{bf16, f16}; use metal::{CompileOptions, Device, MTLResourceOptions, MTLSize, NSUInteger}; fn new_buffer<T>(device: &Device, data: &[T]) -> Buffer { @@ -23,13 +23,18 @@ fn approx_f16(v: Vec<f16>, digits: i32) -> Vec<f32> { v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect() } +fn approx_bf16(v: Vec<bf16>, digits: i32) -> Vec<f32> { + let b = 10f32.powi(digits); + v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect() +} + fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> { let device = device(); let kernels = Kernels::new(); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); - let mut output = new_buffer(&device, v); + let output = new_buffer(&device, v); call_unary_contiguous( &device, command_buffer, @@ -37,7 +42,7 @@ fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> { name, v.len(), &input, - &mut output, + &output, ) .unwrap(); command_buffer.commit(); @@ -53,7 +58,7 @@ fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> V let options = MTLResourceOptions::StorageModeManaged; let left = new_buffer(&device, x); let right = new_buffer(&device, y); - let mut output = device.new_buffer(std::mem::size_of_val(x) as u64, options); + let output = device.new_buffer(std::mem::size_of_val(x) as u64, options); call_binary_contiguous( &device, command_buffer, @@ -62,7 +67,7 @@ fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> V x.len(), &left, &right, - &mut output, + &output, ) .unwrap(); command_buffer.commit(); @@ -81,7 +86,7 @@ fn run_strided<T: Clone>( let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); - let mut output = new_buffer(&device, v); + let output = new_buffer(&device, v); let kernels = Kernels::new(); call_unary_strided( &device, @@ -92,7 +97,7 @@ fn run_strided<T: Clone>( &input, strides, offset, - &mut output, + &output, 0, ) .unwrap(); @@ -220,7 +225,9 @@ fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> { let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); - let mut output = new_buffer(&device, v); + let options = MTLResourceOptions::StorageModeManaged; + let size = (v.len() * std::mem::size_of::<U>()) as u64; + let output = device.new_buffer(size, options); call_cast_contiguous( &device, @@ -229,7 +236,8 @@ fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> { name, v.len(), &input, - &mut output, + 0, + &output, ) .unwrap(); command_buffer.commit(); @@ -245,11 +253,17 @@ fn cast_u32_f32() { assert_eq!(approx(results, 4), vec![1.0f32, 2.0, 3.0]); assert_eq!(approx(expected, 4), vec![1.0f32, 2.0, 3.0]); + let v = vec![1.0f32, 2.0, 3.0]; + let input: Vec<f16> = v.iter().map(|v| f16::from_f32(*v)).collect(); + let results: Vec<f32> = cast(&input, "cast_f16_f32"); + assert_eq!(results, vec![1.0f32, 2.0, 3.0]); + let v = vec![1.0f32; 10_000]; - let results = run(&v, unary::contiguous::cos::FLOAT); - let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); - assert_eq!(approx(results, 4), vec![0.5403; 10_000]); - assert_eq!(approx(expected, 4), vec![0.5403; 10_000]); + let input: Vec<f16> = v.iter().map(|v| f16::from_f32(*v)).collect(); + let results: Vec<f32> = cast(&input, "cast_f16_f32"); + assert_eq!(results.len(), 10_000); + assert_eq!(&results[..10], vec![1.0f32; 10]); + assert_eq!(results, vec![1.0f32; 10_000]); } fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> { @@ -259,7 +273,7 @@ fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> { let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); - let mut output = new_buffer(&device, v); + let output = new_buffer(&device, v); let size = v.len(); @@ -267,9 +281,45 @@ fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> { &device, command_buffer, &kernels, + "affine_float", size, &input, - &mut output, + &output, + mul as f32, + add as f32, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + output.read_to_vec::<T>(v.len()) +} + +fn _run_affine_strided<T: Clone>( + v: &[T], + shape: &[usize], + strides: &[usize], + mul: f64, + add: f64, +) -> Vec<T> { + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + + let input = new_buffer(&device, v); + let output = new_buffer(&device, v); + + call_affine_strided( + &device, + command_buffer, + &kernels, + "affine_float", + shape, + &input, + strides, + 0, + &output, mul as f32, add as f32, ) @@ -295,6 +345,16 @@ fn affine() { assert_eq!(result, vec![2.6; 40_000]); } +// #[test] +// fn affine_strided() { +// let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]; +// let mul = 1.5; +// let add = 1.1; +// let result = run_affine_(&input, mul, add); +// assert_eq!(result, vec![2.6, 4.1, 5.6, 7.1, 8.6, 10.1, 11.6, 13.1]); + +// } + #[test] fn index_select() { let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; @@ -313,7 +373,26 @@ fn index_select() { result, vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0] ); +} + +#[test] +fn index_select_f16() { + let embedding: Vec<_> = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + .into_iter() + .map(|x| f16::from_f32(x)) + .collect(); + let shape = [5, 2]; + let ids = [0u32, 4, 2]; + let dim = 0; + let result = run_index_select(&embedding, &shape, &ids, dim); + assert_eq!( + approx_f16(result, 4), + vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0] + ); +} +#[test] +fn index_select_dim1() { let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; let shape = [5, 2]; let ids = [0u32, 1, 0]; @@ -321,7 +400,7 @@ fn index_select() { let result = run_index_select(&embedding, &shape, &ids, dim); assert_eq!( result, - vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0] + vec![1.0f32, 2.0, 1.0, 3.0, 4.0, 3.0, 5.0, 6.0, 5.0, 7.0, 8.0f32, 7.0, 9.0, 10.0, 9.0] ); } @@ -341,20 +420,26 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>( let left_size: usize = shape[..dim].iter().product(); let right_size: usize = shape[dim + 1..].iter().product(); let dst_el = ids.len() * left_size * right_size; - let mut dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]); + let dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]); + + let name = match core::mem::size_of::<T>() { + 4 => "is_u32_f32", + 2 => "is_u32_f16", + _ => unimplemented!(), + }; let kernels = Kernels::new(); call_index_select( &device, &command_buffer, &kernels, - "is_u32_f32", + name, shape, ids.len(), dim, &embeddings_buffer, &ids_buffer, - &mut dst_buffer, + &dst_buffer, ) .unwrap(); @@ -451,7 +536,7 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T let input = new_buffer(&device, v); let options = MTLResourceOptions::StorageModeManaged; - let mut output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options); + let output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options); call_reduce_contiguous( &device, command_buffer, @@ -460,7 +545,8 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T v.len(), out_length, &input, - &mut output, + 0, + &output, ) .unwrap(); command_buffer.commit(); @@ -475,7 +561,7 @@ fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'sta let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); - let mut output = new_buffer(&device, v); + let output = new_buffer(&device, v); call_last_softmax( &device, command_buffer, @@ -484,7 +570,7 @@ fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'sta v.len(), last_dim, &input, - &mut output, + &output, ) .unwrap(); command_buffer.commit(); @@ -536,6 +622,28 @@ fn softmax() { approx(results, 4), vec![0.0900, 0.2447, 0.6652, 0.0900, 0.2447, 0.6652] ); + + let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0] + .iter() + .map(|v| f16::from_f32(*v)) + .collect::<Vec<_>>(); + let last_dim = 6; + let results = run_softmax(&v, last_dim, "softmax_half"); + assert_eq!( + approx_f16(results, 4), + vec![0.0043, 0.0116, 0.0316, 0.0858, 0.2332, 0.6338] + ); + + let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0] + .iter() + .map(|v| bf16::from_f32(*v)) + .collect::<Vec<_>>(); + let last_dim = 6; + let results = run_softmax(&v, last_dim, "softmax_bfloat"); + assert_eq!( + approx_bf16(results, 4), + vec![0.0043, 0.0116, 0.0315, 0.0859, 0.2324, 0.6328] + ); } fn run_where_cond<I: Clone, T: Clone>( @@ -571,7 +679,7 @@ fn run_where_cond<I: Clone, T: Clone>( options, ); - let mut output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options); + let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options); call_where_cond_strided( &device, command_buffer, @@ -584,7 +692,7 @@ fn run_where_cond<I: Clone, T: Clone>( (&left_stride, left_offset), &right, (&cond_stride, cond_offset), - &mut output, + &output, ) .unwrap(); command_buffer.commit(); diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index eb6424e8..88139af9 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -1,4 +1,7 @@ #include <metal_stdlib> +#include <metal_math> +# +using namespace metal; METAL_FUNC uint get_strided_index( uint idx, @@ -17,10 +20,39 @@ METAL_FUNC uint get_strided_index( template <typename T> METAL_FUNC T sqr(T in){ return in * in; } template <typename T> METAL_FUNC T neg(T in){ return -in; } +template <typename T> METAL_FUNC T erf(T in){ + float x = (float) in; + // constants + float a1 = 0.254829592; + float a2 = -0.284496736; + float a3 = 1.421413741; + float a4 = -1.453152027; + float a5 = 1.061405429; + float p = 0.3275911; + + // Save the sign of x + int sign = 1; + if (x < 0) + sign = -1; + x = fabs(x); + + // A&S formula 7.1.26 + float t = 1.0/(1.0 + p*x); + float y = 1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*exp(-x*x); + + return T(sign*y); +} template <typename T> METAL_FUNC T id(T in){ return in; } +template <typename T> METAL_FUNC T gelu_erf(T x){ return T(x * (1 + erf(x * M_SQRT1_2_F)) / 2); } +template <typename T> METAL_FUNC T gelu(T x){ + T x_sq = x * x; + T x_cube = x_sq * x; + T alpha = x + static_cast<T>(0.044715) * x_cube; + T beta = (static_cast<T>(M_2_SQRTPI_F * M_SQRT1_2_F) * alpha); + return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(tanh(beta))); +} -using namespace metal; #define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \ kernel void FN_NAME( \ @@ -64,8 +96,16 @@ UNARY_OP(sqrt) UNARY_OP(neg) UNARY_OP(exp) UNARY_OP(log) +UNARY_OP(gelu) +UNARY_OP(ceil) +UNARY_OP(floor) +UNARY_OP(round) +UNARY_OP(gelu_erf) +UNARY_OP(erf) UNARY(id, float, copy_float, copy_float_strided) UNARY(id, half, copy_half, copy_half_strided) +UNARY(id, uint8_t, copy_u8, copy_u8_strided) +UNARY(id, uint32_t, copy_u32, copy_u32_strided) #if __METAL_VERSION__ >= 310 BFLOAT_UNARY_OP(cos) @@ -75,6 +115,12 @@ BFLOAT_UNARY_OP(sqrt) BFLOAT_UNARY_OP(neg) BFLOAT_UNARY_OP(exp) BFLOAT_UNARY_OP(log) +BFLOAT_UNARY_OP(gelu) +BFLOAT_UNARY_OP(ceil) +BFLOAT_UNARY_OP(floor) +BFLOAT_UNARY_OP(round) +BFLOAT_UNARY_OP(gelu_erf) +BFLOAT_UNARY_OP(erf) UNARY(id, bfloat, copy_bfloat, copy_bfloat_strided) #endif diff --git a/candle-metal-kernels/examples/affine.rs b/candle-metal-kernels/tmp/affine.rs index b8005dc0..cd019056 100644 --- a/candle-metal-kernels/examples/affine.rs +++ b/candle-metal-kernels/tmp/affine.rs @@ -50,6 +50,7 @@ fn run_affine_bench<T: Clone>(device: &Device, kernels: &Kernels, v: &[T]) { &device, command_buffer, &kernels, + "affine_float", v.len(), &input, &mut output, diff --git a/candle-metal-kernels/examples/binary.rs b/candle-metal-kernels/tmp/binary.rs index af5a8bdc..af5a8bdc 100644 --- a/candle-metal-kernels/examples/binary.rs +++ b/candle-metal-kernels/tmp/binary.rs diff --git a/candle-metal-kernels/examples/cast.rs b/candle-metal-kernels/tmp/cast.rs index 090f510d..090f510d 100644 --- a/candle-metal-kernels/examples/cast.rs +++ b/candle-metal-kernels/tmp/cast.rs diff --git a/candle-metal-kernels/examples/unary.rs b/candle-metal-kernels/tmp/unary.rs index 7039c098..66cf25c0 100644 --- a/candle-metal-kernels/examples/unary.rs +++ b/candle-metal-kernels/tmp/unary.rs @@ -147,7 +147,7 @@ fn run_unary_bench<T: Clone>( println!( "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}", type_name::<T>().split("::").last().unwrap(), - kernel_name.to_string(), + kernel_name.0, v.len(), iterations, total_time, @@ -159,7 +159,7 @@ fn run_unary_bench<T: Clone>( let shape = vec![2, 5_000]; let strides = vec![2, 1]; let offset = 0; - for kernel_name in strided { + for kernel_name in &strided { let total_time = autoreleasepool(|| { let command_buffer = command_queue.new_command_buffer(); let start = Instant::now(); @@ -187,7 +187,7 @@ fn run_unary_bench<T: Clone>( println!( "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}", type_name::<T>().split("::").last().unwrap(), - kernel_name.to_string(), + kernel_name.0, v.len(), iterations, total_time, diff --git a/candle-nn/Cargo.toml b/candle-nn/Cargo.toml index d3f43c73..45298907 100644 --- a/candle-nn/Cargo.toml +++ b/candle-nn/Cargo.toml @@ -19,6 +19,7 @@ num-traits = { workspace = true } rayon = { workspace = true } safetensors = { workspace = true } serde = { workspace = true } +candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true } [dev-dependencies] anyhow = { workspace = true } @@ -29,3 +30,4 @@ default = [] accelerate = ["dep:accelerate-src", "candle/accelerate"] cuda = ["candle/cuda"] mkl = ["dep:intel-mkl-src", "candle/mkl"] +metal = ["candle/metal", "dep:candle-metal-kernels"] diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index a0269e59..350bc663 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -201,6 +201,46 @@ impl candle::CustomOp1 for SoftmaxLastDim { }; Ok((dst, layout.shape().clone())) } + + #[cfg(feature = "metal")] + fn metal_fwd( + &self, + storage: &candle::MetalStorage, + layout: &Layout, + ) -> Result<(candle::MetalStorage, Shape)> { + use candle::{backend::BackendStorage, DType}; + let device = storage.device(); + let command_buffer = device.command_buffer(); + let kernels = device.kernels(); + let name = match storage.dtype() { + DType::F32 => "softmax_float", + DType::F16 => "softmax_half", + DType::BF16 => "softmax_bfloat", + dtype => candle::bail!("softmax-last-dim is not implemented for {dtype:?}"), + }; + + let n = layout.stride().len(); + if !(layout.stride()[n - 1] == 1 && layout.start_offset() == 0) { + candle::bail!("Non contiguous softmax-last-dim is not implemented"); + } + + let last_dim = layout.dims()[layout.shape().rank() - 1]; + let elem_count = layout.shape().elem_count(); + let mut output = device.new_buffer(elem_count, storage.dtype()); + candle_metal_kernels::call_last_softmax( + device.metal_device(), + &command_buffer, + &kernels, + name, + elem_count, + last_dim, + storage.buffer(), + &mut output, + ) + .unwrap(); + let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype()); + Ok((newstorage, layout.shape().clone())) + } } pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> { |