From 4349ff1fc29a1a25b2ccdf56fbf68a98f5364c0a Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Sat, 11 Nov 2023 01:02:15 +0100 Subject: Starting to fix some tests. Few fixes. Going back on remote metal-rs. Reusing a single buffer (for now) to speed things up. Adding some half kernels. All tests are panicking instead of random failure. Putting back f16 index select. Add erf. Working version for llama2-c. Fixes + cache compute_pipeline_state. BF16 metal fix. Remove some prints. new_owned -> new()..to_owned(). Better batched matmul. Metal operational. Reuse buffers on our own reference counts. Tmp gemm. Revert "Tmp gemm." This reverts commit c65f68e98814b65daa596696bda076a73303dd82. Interleave committing. Speeding up copies using blit. Fmt. Fmt. Remove the assert! Fmt all. Fixes after big rebase. Add softmax for half and bfloat + tests Fixing Llama example + accumulate softmax in float. --- candle-nn/Cargo.toml | 2 ++ candle-nn/src/ops.rs | 40 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+) (limited to 'candle-nn') diff --git a/candle-nn/Cargo.toml b/candle-nn/Cargo.toml index d3f43c73..45298907 100644 --- a/candle-nn/Cargo.toml +++ b/candle-nn/Cargo.toml @@ -19,6 +19,7 @@ num-traits = { workspace = true } rayon = { workspace = true } safetensors = { workspace = true } serde = { workspace = true } +candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true } [dev-dependencies] anyhow = { workspace = true } @@ -29,3 +30,4 @@ default = [] accelerate = ["dep:accelerate-src", "candle/accelerate"] cuda = ["candle/cuda"] mkl = ["dep:intel-mkl-src", "candle/mkl"] +metal = ["candle/metal", "dep:candle-metal-kernels"] diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index a0269e59..350bc663 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -201,6 +201,46 @@ impl candle::CustomOp1 for SoftmaxLastDim { }; Ok((dst, layout.shape().clone())) } + + #[cfg(feature = "metal")] + fn metal_fwd( + &self, + storage: &candle::MetalStorage, + layout: &Layout, + ) -> Result<(candle::MetalStorage, Shape)> { + use candle::{backend::BackendStorage, DType}; + let device = storage.device(); + let command_buffer = device.command_buffer(); + let kernels = device.kernels(); + let name = match storage.dtype() { + DType::F32 => "softmax_float", + DType::F16 => "softmax_half", + DType::BF16 => "softmax_bfloat", + dtype => candle::bail!("softmax-last-dim is not implemented for {dtype:?}"), + }; + + let n = layout.stride().len(); + if !(layout.stride()[n - 1] == 1 && layout.start_offset() == 0) { + candle::bail!("Non contiguous softmax-last-dim is not implemented"); + } + + let last_dim = layout.dims()[layout.shape().rank() - 1]; + let elem_count = layout.shape().elem_count(); + let mut output = device.new_buffer(elem_count, storage.dtype()); + candle_metal_kernels::call_last_softmax( + device.metal_device(), + &command_buffer, + &kernels, + name, + elem_count, + last_dim, + storage.buffer(), + &mut output, + ) + .unwrap(); + let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype()); + Ok((newstorage, layout.shape().clone())) + } } pub fn softmax_last_dim(xs: &Tensor) -> Result { -- cgit v1.2.3 From 87dc559817db11f8d8c409cda959528e57e1db31 Mon Sep 17 00:00:00 2001 From: nicolas Date: Tue, 12 Dec 2023 17:41:56 +0100 Subject: Lots of updates including some stack of command buffers. --- candle-core/src/metal_backend.rs | 389 +++++++++++++++++++++------- candle-core/src/tensor.rs | 2 +- candle-metal-kernels/src/affine.metal | 75 +++++- candle-metal-kernels/src/lib.rs | 126 ++++++++- candle-metal-kernels/src/reduce.metal | 2 +- candle-metal-kernels/src/unary.metal | 6 +- candle-nn/Cargo.toml | 3 +- candle-nn/src/ops.rs | 4 +- candle-transformers/Cargo.toml | 1 + candle-transformers/src/models/mixformer.rs | 46 +++- 10 files changed, 537 insertions(+), 117 deletions(-) (limited to 'candle-nn') diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 12f56d50..4354422c 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -38,7 +38,8 @@ impl From for MetalError { pub struct MetalDevice { device: metal::Device, command_queue: metal::CommandQueue, - command_buffer: Arc>, + command_buffers: Arc>>, + command_buffer_index: Arc>, kernels: Arc, buffers: Arc>>>>, } @@ -70,38 +71,69 @@ impl MetalDevice { &self.command_queue } - pub fn command_buffer(&self) -> std::sync::RwLockReadGuard { - self.command_buffer.try_read().unwrap() - } - - pub fn commit(&self) { - let mut old = self.command_buffer.try_write().unwrap(); - match old.status() { - metal::MTLCommandBufferStatus::NotEnqueued - | metal::MTLCommandBufferStatus::Enqueued => { - old.commit(); - let command_buffer = self.command_queue.new_command_buffer().to_owned(); - *old = command_buffer; + pub fn command_buffer(&self) -> CommandBuffer { + let mut command_buffers = self.command_buffers.try_write().unwrap(); + let mut index = self.command_buffer_index.try_write().unwrap(); + let n = command_buffers.len(); + if *index == n { + // todo!("Cycle buffers"); + for i in 0..n { + let command_buffer = &command_buffers[i]; + match command_buffer.status() { + metal::MTLCommandBufferStatus::Committed + | metal::MTLCommandBufferStatus::Scheduled => { + // println!("Wait during cycling {i}"); + // println!("Command {i} / {n}: {:?}", command_buffer.status()); + command_buffer.wait_until_completed(); + } + metal::MTLCommandBufferStatus::Completed => {} + _ => { + panic!("Command buffer {i} not committed during cycling"); + } + } } - _ => {} + let new_buffers = (0..n) + .map(|i| { + // println!("Creating command buffer {i}"); + let command_buffer = self.command_queue.new_command_buffer().to_owned(); + command_buffer.enqueue(); + command_buffer + }) + .collect(); + *command_buffers = new_buffers; + *index = 0; + // println!("Reset"); } + // println!("Giving buffer {} / {n}", *index); + let out = &command_buffers[*index]; + assert_eq!(out.status(), metal::MTLCommandBufferStatus::Enqueued); + *index += 1; + out.to_owned() } pub fn wait_until_completed(&self) { - let mut old = self.command_buffer.try_write().unwrap(); - match old.status() { - metal::MTLCommandBufferStatus::NotEnqueued - | metal::MTLCommandBufferStatus::Enqueued => { - old.commit(); - old.wait_until_completed(); + let command_buffers = self.command_buffers.try_write().unwrap(); + let index = self.command_buffer_index.try_write().unwrap(); + let n = command_buffers.len(); + // for i in 0..*index { + // let command_buffer = &command_buffers[i]; + // println!("Command {i} / {n}: {:?}", command_buffer.status()); + // } + for i in 0..*index { + let command_buffer = &command_buffers[i]; + match command_buffer.status() { + metal::MTLCommandBufferStatus::Committed + | metal::MTLCommandBufferStatus::Scheduled => {} + metal::MTLCommandBufferStatus::Completed => {} + _ => { + panic!("Command buffer not committed"); + } } - metal::MTLCommandBufferStatus::Committed | metal::MTLCommandBufferStatus::Scheduled => { - old.wait_until_completed(); - } - _ => {} + // println!("Wait {i}"); + command_buffer.wait_until_completed(); + // println!("Ok {i}"); + // command_buffer.wait_until_completed(); } - let command_buffer = self.command_queue.new_command_buffer().to_owned(); - *old = command_buffer; } pub fn kernels(&self) -> &Kernels { @@ -112,28 +144,40 @@ impl MetalDevice { &self.device } - pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Arc { + pub fn new_buffer(&self, element_count: usize, dtype: DType, name: &str) -> Arc { let size = (element_count * dtype.size_in_bytes()) as NSUInteger; - self._new_buffer(size, MTLResourceOptions::StorageModePrivate) + self._new_buffer(size, MTLResourceOptions::StorageModePrivate, name) } - fn _new_buffer(&self, size: NSUInteger, option: MTLResourceOptions) -> Arc { + fn _new_buffer(&self, size: NSUInteger, option: MTLResourceOptions, name: &str) -> Arc { + // println!("Creating new buffer {name}"); let mut buffers = self.buffers.try_write().unwrap(); let subbuffers = buffers.entry((size, option)).or_insert(vec![]); for sub in &mut *subbuffers { if Arc::strong_count(sub) == 1 { - return sub.clone(); + // println!("Reusing tensor {size} {name}"); + // return sub.clone(); } } let new_buffer = self.device.new_buffer(size as NSUInteger, option); let new_buffer = Arc::new(new_buffer); - subbuffers.push(new_buffer.clone()); + // subbuffers.push(new_buffer.clone()); + // println!("Created tensor {size} {name}"); + for subbuffers in buffers.values_mut() { + let newbuffers = subbuffers + .iter() + .filter(|s| Arc::strong_count(s) > 1) + .map(|s| Arc::clone(s)) + .collect(); + *subbuffers = newbuffers; + } + new_buffer } pub fn new_buffer_managed(&self, size: NSUInteger) -> Arc { - self._new_buffer(size, MTLResourceOptions::StorageModeManaged) + self._new_buffer(size, MTLResourceOptions::StorageModeManaged, "managed") } pub fn new_buffer_with_data(&self, data: &[T]) -> Arc { @@ -143,13 +187,20 @@ impl MetalDevice { size, metal::MTLResourceOptions::StorageModeManaged, ); - let real = self._new_buffer(size, metal::MTLResourceOptions::StorageModePrivate); - { - let command = self.command_buffer(); - let blit = command.new_blit_command_encoder(); - blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length()); - blit.end_encoding(); - } + let real = self._new_buffer( + size, + metal::MTLResourceOptions::StorageModePrivate, + "with_data", + ); + let command = self.command_buffer(); + let blit = command.new_blit_command_encoder(); + blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length()); + blit.end_encoding(); + command.commit(); + real.did_modify_range(metal::NSRange::new(0, real.length())); + // println!("Command {:?}", command.status()); + + // self.commit(); // This is necessary, for mmaped safetensors // Because of the unsafe slice cast we're doing. // The slice might not live long enough for metal @@ -169,7 +220,7 @@ impl MetalDevice { dtype: DType, ) -> Result<(Matrix, Arc)> { let elem_count = (b * m * n) as usize; - let out_buffer = self.new_buffer(elem_count, dtype); + let out_buffer = self.new_buffer(elem_count, dtype, "matrix"); let result_descriptor = MatrixDescriptor::init_multiple(m, n, b, n * size, m * n * size, type_id); @@ -241,13 +292,18 @@ impl BackendStorage for MetalStorage { self.dtype ); } - + self.device.wait_until_completed(); + self.buffer + .did_modify_range(metal::NSRange::new(0, self.buffer.length())); let buffer = self.device.new_buffer_managed(self.buffer.length()); - let command_buffer = self.device.command_buffer(); - let blit = command_buffer.new_blit_command_encoder(); - blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length()); - blit.end_encoding(); - drop(command_buffer); + { + let command_buffer = self.device.command_buffer(); + let blit = command_buffer.new_blit_command_encoder(); + blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length()); + blit.end_encoding(); + command_buffer.commit(); + buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); + } self.device.wait_until_completed(); match self.dtype { @@ -256,7 +312,11 @@ impl BackendStorage for MetalStorage { DType::I64 => Ok(CpuStorage::I64(buffer.read_to_vec(length / size))), DType::F16 => Ok(CpuStorage::F16(buffer.read_to_vec(length / size))), DType::BF16 => Ok(CpuStorage::BF16(buffer.read_to_vec(length / size))), - DType::F32 => Ok(CpuStorage::F32(buffer.read_to_vec(length / size))), + DType::F32 => { + let vec = buffer.read_to_vec(length / size); + // println!("Got back {:?}", &vec[..1]); + Ok(CpuStorage::F32(vec)) + } DType::F64 => Ok(CpuStorage::F64(buffer.read_to_vec(length / size))), } } @@ -268,7 +328,7 @@ impl BackendStorage for MetalStorage { let el = shape.elem_count(); let dtype = self.dtype; - let buffer = device.new_buffer(el, self.dtype); + let buffer = device.new_buffer(el, self.dtype, "affine"); let command_buffer = self.device.command_buffer(); if layout.is_contiguous() && layout.start_offset() == 0 { let name = match self.dtype { @@ -309,15 +369,111 @@ impl BackendStorage for MetalStorage { ) .map_err(MetalError::from)?; } + command_buffer.commit(); + buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(Self::new(buffer, device.clone(), dtype)) } - fn powf(&self, _: &Layout, _: f64) -> Result { - crate::bail!("powf metal") + fn powf(&self, layout: &Layout, pow: f64) -> Result { + let device = self.device().clone(); + + let shape = layout.shape(); + let el = shape.elem_count(); + let dtype = self.dtype; + + let buffer = device.new_buffer(el, self.dtype, "powf"); + let command_buffer = self.device.command_buffer(); + if layout.is_contiguous() && layout.start_offset() == 0 { + let name = match self.dtype { + DType::F32 => "powf_float", + DType::F16 => "powf_half", + dtype => crate::bail!("Powf {dtype:?}"), + }; + candle_metal_kernels::call_powf( + &device.device, + &command_buffer, + &device.kernels, + name, + el, + &self.buffer, + &buffer, + pow as f32, + ) + .map_err(MetalError::from)?; + } else { + let name = match self.dtype { + DType::F32 => "powf_float_strided", + DType::F16 => "powf_half_strided", + dtype => crate::bail!("Powf {dtype:?}"), + }; + candle_metal_kernels::call_powf_strided( + &device.device, + &command_buffer, + &device.kernels, + name, + layout.dims(), + &self.buffer, + layout.stride(), + layout.start_offset() * dtype.size_in_bytes(), + &buffer, + pow as f32, + ) + .map_err(MetalError::from)?; + } + command_buffer.commit(); + buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); + Ok(Self::new(buffer, device.clone(), dtype)) } - fn elu(&self, _: &Layout, _: f64) -> Result { - crate::bail!("elu metal") + fn elu(&self, layout: &Layout, alpha: f64) -> Result { + let device = self.device().clone(); + + let shape = layout.shape(); + let el = shape.elem_count(); + let dtype = self.dtype; + + let buffer = device.new_buffer(el, self.dtype, "elu"); + let command_buffer = self.device.command_buffer(); + if layout.is_contiguous() && layout.start_offset() == 0 { + let name = match self.dtype { + DType::F32 => "elu_float", + DType::F16 => "elu_half", + dtype => crate::bail!("Powf {dtype:?}"), + }; + candle_metal_kernels::call_elu( + &device.device, + &command_buffer, + &device.kernels, + name, + el, + &self.buffer, + &buffer, + alpha as f32, + ) + .map_err(MetalError::from)?; + } else { + let name = match self.dtype { + DType::F32 => "elu_float_strided", + DType::F16 => "elu_half_strided", + dtype => crate::bail!("Powf {dtype:?}"), + }; + candle_metal_kernels::call_elu_strided( + &device.device, + &command_buffer, + &device.kernels, + name, + layout.dims(), + &self.buffer, + layout.stride(), + layout.start_offset() * dtype.size_in_bytes(), + &buffer, + alpha as f32, + ) + .map_err(MetalError::from)?; + } + command_buffer.commit(); + buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); + Ok(Self::new(buffer, device.clone(), dtype)) } fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result { @@ -365,7 +521,7 @@ impl BackendStorage for MetalStorage { if dtype == DType::U32 { crate::bail!("Implement return index reduce op"); } - let buffer = device.new_buffer(dst_el, dtype); + let buffer = device.new_buffer(dst_el, dtype, "reduce"); let command_buffer = self.device.command_buffer(); candle_metal_kernels::call_reduce_contiguous( &device.device, @@ -379,6 +535,8 @@ impl BackendStorage for MetalStorage { &buffer, ) .map_err(MetalError::from)?; + command_buffer.commit(); + buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(Self::new(buffer, device, dtype)) } @@ -391,7 +549,7 @@ impl BackendStorage for MetalStorage { let device = self.device(); let shape = layout.shape(); let el_count = shape.elem_count(); - let buffer = device.new_buffer(el_count, dtype); + let buffer = device.new_buffer(el_count, dtype, "todtype"); let command_buffer = device.command_buffer(); if layout.is_contiguous() { let kernel_name = match (self.dtype, dtype) { @@ -435,6 +593,8 @@ impl BackendStorage for MetalStorage { ) .map_err(MetalError::from)?; } + command_buffer.commit(); + buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(Self::new(buffer, device.clone(), dtype)) } @@ -444,7 +604,7 @@ impl BackendStorage for MetalStorage { let dtype = self.dtype; let shape = layout.shape(); let el_count = shape.elem_count(); - let buffer = device.new_buffer(el_count, dtype); + let buffer = device.new_buffer(el_count, dtype, B::KERNEL); let command_buffer = device.command_buffer(); if layout.is_contiguous() && layout.start_offset() == 0 { use candle_metal_kernels::unary::contiguous; @@ -463,6 +623,7 @@ impl BackendStorage for MetalStorage { ("uceil", DType::F32) => contiguous::ceil::FLOAT, ("ufloor", DType::F32) => contiguous::floor::FLOAT, ("uround", DType::F32) => contiguous::round::FLOAT, + ("utanh", DType::F32) => contiguous::tanh::FLOAT, ("ucos", DType::F16) => contiguous::cos::HALF, ("usin", DType::F16) => contiguous::sin::HALF, ("usqr", DType::F16) => contiguous::sqr::HALF, @@ -476,6 +637,7 @@ impl BackendStorage for MetalStorage { ("uceil", DType::F16) => contiguous::ceil::HALF, ("ufloor", DType::F16) => contiguous::floor::HALF, ("uround", DType::F16) => contiguous::round::HALF, + ("utanh", DType::F16) => contiguous::tanh::HALF, (name, dtype) => crate::bail!("Match {name} - {dtype:?}"), }; candle_metal_kernels::call_unary_contiguous( @@ -534,8 +696,8 @@ impl BackendStorage for MetalStorage { .map_err(MetalError::from)?; } command_buffer.set_label("unary"); - drop(command_buffer); - self.device.commit(); + command_buffer.commit(); + buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(Self::new(buffer, device.clone(), dtype)) } @@ -549,30 +711,31 @@ impl BackendStorage for MetalStorage { let dtype = self.dtype; let shape = lhs_l.shape(); let el_count = shape.elem_count(); - let buffer = device.new_buffer(el_count, dtype); + let buffer = device.new_buffer(el_count, dtype, B::KERNEL); let command_buffer = device.command_buffer(); if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0) && (rhs_l.is_contiguous() && rhs_l.start_offset() == 0) + && &B::KERNEL[..1] != "b" { use candle_metal_kernels::binary::contiguous; let kernel_name = match (B::KERNEL, dtype) { ("add", DType::F32) => contiguous::add::FLOAT, - ("badd", DType::F32) => contiguous::add::FLOAT, + // ("badd", DType::F32) => contiguous::add::FLOAT, ("sub", DType::F32) => contiguous::sub::FLOAT, - ("bsub", DType::F32) => contiguous::sub::FLOAT, + //("bsub", DType::F32) => contiguous::sub::FLOAT, ("mul", DType::F32) => contiguous::mul::FLOAT, - ("bmul", DType::F32) => contiguous::mul::FLOAT, + // ("bmul", DType::F32) => contiguous::mul::FLOAT, ("div", DType::F32) => contiguous::div::FLOAT, - ("bdiv", DType::F32) => contiguous::div::FLOAT, + // ("bdiv", DType::F32) => contiguous::div::FLOAT, ("add", DType::F16) => contiguous::add::HALF, - ("badd", DType::F16) => contiguous::add::HALF, + // ("badd", DType::F16) => contiguous::add::HALF, ("sub", DType::F16) => contiguous::sub::HALF, - ("bsub", DType::F16) => contiguous::sub::HALF, + // ("bsub", DType::F16) => contiguous::sub::HALF, ("mul", DType::F16) => contiguous::mul::HALF, - ("bmul", DType::F16) => contiguous::mul::HALF, + // ("bmul", DType::F16) => contiguous::mul::HALF, ("div", DType::F16) => contiguous::div::HALF, - ("bdiv", DType::F16) => contiguous::div::HALF, + // ("bdiv", DType::F16) => contiguous::div::HALF, (name, dtype) => crate::bail!("Match {name} - {dtype:?}"), }; candle_metal_kernels::call_binary_contiguous( @@ -617,8 +780,8 @@ impl BackendStorage for MetalStorage { .map_err(MetalError::from)?; } command_buffer.set_label("binary"); - drop(command_buffer); - self.device.commit(); + command_buffer.commit(); + buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(Self::new(buffer, device.clone(), dtype)) } @@ -635,7 +798,7 @@ impl BackendStorage for MetalStorage { let dims = shape.dims(); let el = shape.elem_count(); let dtype = t.dtype; - let buffer = self.device.new_buffer(el, dtype); + let buffer = self.device.new_buffer(el, dtype, "where"); let command_buffer = self.device.command_buffer(); if t.dtype() != f.dtype() { crate::bail!("Invalid ternary different dtypes for values"); @@ -663,6 +826,8 @@ impl BackendStorage for MetalStorage { &buffer, ) .map_err(MetalError::from)?; + command_buffer.commit(); + buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(Self::new(buffer, device, dtype)) } @@ -752,7 +917,7 @@ impl BackendStorage for MetalStorage { let dst_el = ids_el * left_size * right_size; let dtype = self.dtype; let device = self.device(); - let buffer = device.new_buffer(dst_el, dtype); + let buffer = device.new_buffer(dst_el, dtype, "index_select"); let name = match (ids.dtype, self.dtype) { (DType::U32, DType::F32) => "is_u32_f32", (DType::U32, DType::F16) => "is_u32_f16", @@ -772,6 +937,8 @@ impl BackendStorage for MetalStorage { &buffer, ) .map_err(MetalError::from)?; + command_buffer.commit(); + buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(Self::new(buffer, device.clone(), dtype)) } @@ -887,9 +1054,9 @@ impl BackendStorage for MetalStorage { &result_matrix, ); command_buffer.set_label("matmul"); - drop(command_buffer); - self.device.commit(); - + command_buffer.commit(); + out_buffer.did_modify_range(metal::NSRange::new(0, out_buffer.length())); + // println!("========= MATMUL {:?}", Arc::strong_count(&out_buffer)); Ok(Self::new(out_buffer, self.device.clone(), self.dtype())) } @@ -899,14 +1066,9 @@ impl BackendStorage for MetalStorage { command_buffer.set_label("copy_contiguous"); let blit = command_buffer.new_blit_command_encoder(); let src_offset = (src_l.start_offset() * self.dtype.size_in_bytes()) as NSUInteger; + let length = (src_l.shape().elem_count() * self.dtype.size_in_bytes()) as NSUInteger; let dst_offset = (dst_offset * dst.dtype().size_in_bytes()) as NSUInteger; - blit.copy_from_buffer( - &self.buffer, - src_offset, - dst.buffer(), - dst_offset, - self.buffer.length() - src_offset, - ); + blit.copy_from_buffer(&self.buffer, src_offset, dst.buffer(), dst_offset, length); blit.end_encoding(); } else { let src_shape = src_l.shape(); @@ -937,8 +1099,9 @@ impl BackendStorage for MetalStorage { .map_err(MetalError::from)?; command_buffer.set_label("copy_strided"); } - drop(command_buffer); - self.device.commit(); + command_buffer.commit(); + dst.buffer + .did_modify_range(metal::NSRange::new(0, dst.buffer.length())); Ok(()) } } @@ -968,22 +1131,22 @@ impl MetalStorage { ) -> Result { let key = (b, m, n, transpose, size, offset, type_id); - let mut matrices = self.matrices.try_write().unwrap(); - if let Some(matrix) = matrices.get(&key) { - Ok(matrix.clone()) + // let mut matrices = self.matrices.try_write().unwrap(); + // if let Some(matrix) = matrices.get(&key) { + // Ok(matrix.clone()) + // } else { + let descriptor = if transpose { + MatrixDescriptor::init_multiple(n, m, b, m * size, m * n * size, type_id) } else { - let descriptor = if transpose { - MatrixDescriptor::init_multiple(n, m, b, m * size, m * n * size, type_id) - } else { - MatrixDescriptor::init_multiple(m, n, b, n * size, m * n * size, type_id) - }; - let matrix = Matrix::init_with_buffer_descriptor(&self.buffer, offset, &descriptor) - .ok_or_else(|| { - MetalError::from("Failed to create matrix multiplication kernel".to_string()) - })?; - matrices.insert(key, matrix.clone()); - Ok(matrix) - } + MatrixDescriptor::init_multiple(m, n, b, n * size, m * n * size, type_id) + }; + let matrix = Matrix::init_with_buffer_descriptor(&self.buffer, offset, &descriptor) + .ok_or_else(|| { + MetalError::from("Failed to create matrix multiplication kernel".to_string()) + })?; + // matrices.insert(key, matrix.clone()); + Ok(matrix) + // } } } @@ -991,16 +1154,28 @@ impl BackendDevice for MetalDevice { type Storage = MetalStorage; fn new(ordinal: usize) -> Result { + // println!("CREATING DEVICE"); let device = metal::Device::all().swap_remove(ordinal); + let n = 50; let command_queue = device.new_command_queue(); - let command_buffer = Arc::new(RwLock::new(command_queue.new_command_buffer().to_owned())); + + let command_buffers = (0..n) + .map(|_| { + let command_buffer = command_queue.new_command_buffer().to_owned(); + command_buffer.enqueue(); + command_buffer + }) + .collect(); + let command_buffers = Arc::new(RwLock::new(command_buffers)); + let command_buffer_index = Arc::new(RwLock::new(0)); let kernels = Arc::new(Kernels::new()); let buffers = Arc::new(RwLock::new(HashMap::new())); Ok(Self { device, command_queue, - command_buffer, + command_buffers, + command_buffer_index, buffers, kernels, }) @@ -1021,7 +1196,21 @@ impl BackendDevice for MetalDevice { } fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result { - let buffer = self.new_buffer(shape.elem_count(), dtype); + let buffer = self.new_buffer(shape.elem_count(), dtype, "zeros"); + let command_buffer = self.command_buffer(); + let blit = command_buffer.new_blit_command_encoder(); + blit.fill_buffer( + &buffer, + metal::NSRange { + location: 0, + length: buffer.length(), + }, + 0, + ); + blit.end_encoding(); + + command_buffer.commit(); + buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(MetalStorage::new(buffer, self.clone(), dtype)) } diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 87323a84..73a0cc7a 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -1864,7 +1864,7 @@ impl Tensor { } (Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?), (Storage::Metal(storage), Device::Cpu) => { - println!("{storage:?} - {:?}", storage.to_cpu_storage()?); + // println!("{storage:?} - {:?}", storage.to_cpu_storage()?); Storage::Cpu(storage.to_cpu_storage()?) } (Storage::Cuda(storage), Device::Cuda(cuda)) => { diff --git a/candle-metal-kernels/src/affine.metal b/candle-metal-kernels/src/affine.metal index a08bfbc0..18adb457 100644 --- a/candle-metal-kernels/src/affine.metal +++ b/candle-metal-kernels/src/affine.metal @@ -29,9 +29,7 @@ kernel void FN_NAME( \ if (id >= dim) { \ return; \ } \ - const TYPENAME m = TYPENAME(mul); \ - const TYPENAME a = TYPENAME(add); \ - output[id] = input[id] * m + a; \ + output[id] = TYPENAME(float(input[id]) * mul + add); \ } \ kernel void FN_NAME##_strided( \ constant size_t &dim, \ @@ -47,15 +45,80 @@ kernel void FN_NAME##_strided( \ if (id >= dim) { \ return; \ } \ - const TYPENAME m = TYPENAME(mul); \ - const TYPENAME a = TYPENAME(add); \ - output[id] = input[get_strided_index(id, num_dims, dims, strides)] * m + a; \ + output[id] = TYPENAME(float(input[get_strided_index(id, num_dims, dims, strides)]) * mul + add); \ +} + +#define POWF(FN_NAME, TYPENAME) \ +kernel void FN_NAME( \ + constant size_t &dim, \ + constant float &mul, \ + device const TYPENAME *input, \ + device TYPENAME *output, \ + uint id [[ thread_position_in_grid ]] \ +) { \ + if (id >= dim) { \ + return; \ + } \ + output[id] = TYPENAME(pow(input[id], TYPENAME(mul))); \ } \ +kernel void FN_NAME##_strided( \ + constant size_t &dim, \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *strides, \ + constant float &mul, \ + device const TYPENAME *input, \ + device TYPENAME *output, \ + uint id [[ thread_position_in_grid ]] \ +) { \ + if (id >= dim) { \ + return; \ + } \ + output[id] = TYPENAME(pow(input[get_strided_index(id, num_dims, dims, strides)], TYPENAME(mul))); \ +} + +#define ELU(FN_NAME, TYPENAME) \ +kernel void FN_NAME( \ + constant size_t &dim, \ + constant float &mul, \ + device const TYPENAME *input, \ + device TYPENAME *output, \ + uint id [[ thread_position_in_grid ]] \ +) { \ + if (id >= dim) { \ + return; \ + } \ + const TYPENAME x = input[id]; \ + output[id] = TYPENAME((x > 0)?x: mul * exp(x - 1)); \ +} \ +kernel void FN_NAME##_strided( \ + constant size_t &dim, \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *strides, \ + constant float &mul, \ + device const TYPENAME *input, \ + device TYPENAME *output, \ + uint id [[ thread_position_in_grid ]] \ +) { \ + if (id >= dim) { \ + return; \ + } \ + const TYPENAME x = input[get_strided_index(id, num_dims, dims, strides)]; \ + output[id] = TYPENAME((x > 0)?x: mul * exp(x - 1)); \ +} \ + AFFINE(affine_float, float) AFFINE(affine_half, half) +POWF(powf_float, float) +POWF(powf_half, half) +ELU(elu_float, float) +ELU(elu_half, half) #if __METAL_VERSION__ >= 310 AFFINE(affine_bfloat, bfloat); +POWF(powf_bfloat, bfloat); +ELU(elu_bfloat, bfloat); #endif diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index a0b852a4..237bd858 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -153,7 +153,7 @@ macro_rules! ops{ } pub mod unary { - ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf); + ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf, tanh); } pub mod binary { ops!(add, sub, mul, div); @@ -616,6 +616,130 @@ pub fn call_affine_strided( Ok(()) } +#[allow(clippy::too_many_arguments)] +pub fn call_powf( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + size: usize, + input: &Buffer, + output: &Buffer, + mul: f32, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; + + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!(encoder, (size, mul, input, output)); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_powf_strided( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + shape: &[usize], + input: &Buffer, + input_stride: &[usize], + input_offset: usize, + output: &Buffer, + mul: f32, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; + let size: usize = shape.iter().product(); + + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!( + encoder, + ( + size, + shape.len(), + shape, + input_stride, + mul, + (input, input_offset), + output + ) + ); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_elu( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + size: usize, + input: &Buffer, + output: &Buffer, + mul: f32, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; + + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!(encoder, (size, mul, input, output)); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_elu_strided( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + shape: &[usize], + input: &Buffer, + input_stride: &[usize], + input_offset: usize, + output: &Buffer, + mul: f32, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; + let size: usize = shape.iter().product(); + + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!( + encoder, + ( + size, + shape.len(), + shape, + input_stride, + mul, + (input, input_offset), + output + ) + ); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + pub fn call_where_cond_strided( device: &Device, command_buffer: &CommandBufferRef, diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal index 867877fb..3a402427 100644 --- a/candle-metal-kernels/src/reduce.metal +++ b/candle-metal-kernels/src/reduce.metal @@ -18,7 +18,7 @@ METAL_FUNC uint get_strided_index( return strided_i; } -constant int THREADGROUP_SIZE = 1024; +constant int THREADGROUP_SIZE = 2048; # define REDUCE(FN, NAME, T) \ kernel void NAME( \ diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index 529162bd..765b14a5 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -69,7 +69,7 @@ kernel void FN_NAME( \ if (thread_position_in_grid >= dim) { \ return; \ } \ - output[thread_position_in_grid] = TYPENAME(FN(input[thread_position_in_grid])); \ + output[thread_position_in_grid] = TYPENAME(FN(float(input[thread_position_in_grid]))); \ }\ kernel void FN_NAME_STRIDED( \ constant size_t &dim, \ @@ -83,7 +83,7 @@ kernel void FN_NAME_STRIDED( \ if (thread_position_in_grid >= dim) { \ return; \ } \ - output[thread_position_in_grid] = TYPENAME(FN(input[get_strided_index(thread_position_in_grid, num_dims, dims, strides)])); \ + output[thread_position_in_grid] = TYPENAME(FN(float(input[get_strided_index(thread_position_in_grid, num_dims, dims, strides)]))); \ } #define UNARY_OP(NAME) \ @@ -107,6 +107,7 @@ UNARY_OP(floor) UNARY_OP(round) UNARY_OP(gelu_erf) UNARY_OP(erf) +UNARY_OP(tanh) UNARY(id, float, copy_float, copy_float_strided) UNARY(id, half, copy_half, copy_half_strided) UNARY(id, uint8_t, copy_u8, copy_u8_strided) @@ -126,6 +127,7 @@ BFLOAT_UNARY_OP(floor) BFLOAT_UNARY_OP(round) BFLOAT_UNARY_OP(gelu_erf) BFLOAT_UNARY_OP(erf) +BFLOAT_UNARY_OP(tanh) UNARY(id, bfloat, copy_bfloat, copy_bfloat_strided) #endif diff --git a/candle-nn/Cargo.toml b/candle-nn/Cargo.toml index 45298907..03622752 100644 --- a/candle-nn/Cargo.toml +++ b/candle-nn/Cargo.toml @@ -19,6 +19,7 @@ num-traits = { workspace = true } rayon = { workspace = true } safetensors = { workspace = true } serde = { workspace = true } +metal = { workspace = true, optional = true } candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true } [dev-dependencies] @@ -30,4 +31,4 @@ default = [] accelerate = ["dep:accelerate-src", "candle/accelerate"] cuda = ["candle/cuda"] mkl = ["dep:intel-mkl-src", "candle/mkl"] -metal = ["candle/metal", "dep:candle-metal-kernels"] +metal = ["candle/metal", "dep:candle-metal-kernels", "dep:metal"] diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index 350bc663..14dd10de 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -226,7 +226,7 @@ impl candle::CustomOp1 for SoftmaxLastDim { let last_dim = layout.dims()[layout.shape().rank() - 1]; let elem_count = layout.shape().elem_count(); - let mut output = device.new_buffer(elem_count, storage.dtype()); + let mut output = device.new_buffer(elem_count, storage.dtype(), "softmax"); candle_metal_kernels::call_last_softmax( device.metal_device(), &command_buffer, @@ -238,6 +238,8 @@ impl candle::CustomOp1 for SoftmaxLastDim { &mut output, ) .unwrap(); + command_buffer.commit(); + output.did_modify_range(metal::NSRange::new(0, output.length())); let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype()); Ok((newstorage, layout.shape().clone())) } diff --git a/candle-transformers/Cargo.toml b/candle-transformers/Cargo.toml index af4e04b7..e72cab69 100644 --- a/candle-transformers/Cargo.toml +++ b/candle-transformers/Cargo.toml @@ -31,3 +31,4 @@ accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate"] cuda = ["candle/cuda", "candle-nn/cuda"] flash-attn = ["cuda", "dep:candle-flash-attn"] mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl"] +metal = ["candle/metal", "candle-nn/metal"] diff --git a/candle-transformers/src/models/mixformer.rs b/candle-transformers/src/models/mixformer.rs index e822ca14..c8dae511 100644 --- a/candle-transformers/src/models/mixformer.rs +++ b/candle-transformers/src/models/mixformer.rs @@ -142,10 +142,9 @@ impl RotaryEmbedding { .to_dtype(DType::F32)? .reshape((max_seq_len, 1))?; let freqs = t.matmul(&inv_freq)?; - Ok(Self { - sin: freqs.sin()?, - cos: freqs.cos()?, - }) + let sin = freqs.sin()?; + let cos = freqs.cos()?; + Ok(Self { sin, cos }) } fn apply_rotary_emb_qkv( @@ -273,6 +272,10 @@ impl MHA { } fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result { + let view = xs.to_string(); + if view.contains("NaN") { + panic!("NaN"); + } let _enter = self.span.enter(); let (b_size, seq_len, _n_embd) = xs.dims3()?; let qkv = self @@ -408,3 +411,38 @@ impl MixFormerSequentialForCausalLM { self.blocks.iter_mut().for_each(|b| b.clear_kv_cache()) } } + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn test_rotary() { + let dev = Device::new_metal(0).unwrap(); + for i in 0..10000 { + let dim = 8; + let max_seq_len = 12; + let inv_freq: Vec<_> = (0..dim) + .step_by(2) + .map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32)) + .collect(); + let inv_freq_len = inv_freq.len(); + let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), &dev).unwrap(); + let t = Tensor::arange(0u32, max_seq_len as u32, &dev) + .unwrap() + .to_dtype(DType::F32) + .unwrap() + .reshape((max_seq_len, 1)) + .unwrap(); + let x: f32 = t.i((1, 0)).unwrap().to_scalar().unwrap(); + assert_eq!(x, 1.0); + let x: f32 = inv_freq.i((0, 1)).unwrap().to_scalar().unwrap(); + assert_eq!(x, 0.1); + let freqs = t.matmul(&inv_freq).unwrap(); + let x: f32 = freqs.i((1, 1)).unwrap().to_scalar().unwrap(); + assert_eq!(x, 0.1); + let sin = freqs.sin().unwrap().contiguous().unwrap(); + let x: f32 = sin.i((1, 1)).unwrap().to_scalar().unwrap(); + assert_eq!(x, 0.099833414); + } + } +} -- cgit v1.2.3 From 361f2ad2af52ccf1750e274f1649fb8c90f80a86 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 14 Dec 2023 16:05:33 +0100 Subject: Working with merging encoders and using fences. --- candle-core/src/metal_backend.rs | 120 ++++++--------------- candle-core/tests/tensor_tests.rs | 2 + candle-metal-kernels/src/lib.rs | 40 ++++++- candle-metal-kernels/src/test.swift | 209 ++++++++++++++++++++++++++++++++++++ candle-nn/src/ops.rs | 2 - 5 files changed, 279 insertions(+), 94 deletions(-) create mode 100644 candle-metal-kernels/src/test.swift (limited to 'candle-nn') diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 9866f1ca..4bc80823 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -38,6 +38,7 @@ pub struct MetalDevice { command_queue: metal::CommandQueue, command_buffers: Arc>>, command_buffer_index: Arc>, + fence: metal::Fence, kernels: Arc, buffers: Arc>>>>, } @@ -71,68 +72,32 @@ impl MetalDevice { pub fn command_buffer(&self) -> CommandBuffer { let mut command_buffers = self.command_buffers.try_write().unwrap(); + let mut command_buffer = command_buffers[0].to_owned(); let mut index = self.command_buffer_index.try_write().unwrap(); - let n = command_buffers.len(); - if *index == n { - // todo!("Cycle buffers"); - for i in 0..n { - let command_buffer = &command_buffers[i]; - match command_buffer.status() { - metal::MTLCommandBufferStatus::Committed - | metal::MTLCommandBufferStatus::Scheduled => { - // println!("Wait during cycling {i}"); - // println!("Command {i} / {n}: {:?}", command_buffer.status()); - command_buffer.wait_until_completed(); - } - metal::MTLCommandBufferStatus::Completed => {} - _ => { - panic!("Command buffer {i} not committed during cycling"); - } - } - } - let new_buffers = (0..n) - .map(|i| { - // println!("Creating command buffer {i}"); - let command_buffer = self.command_queue.new_command_buffer().to_owned(); - command_buffer.set_label(&format!("num {i}")); - command_buffer.enqueue(); - command_buffer - }) - .collect(); - *command_buffers = new_buffers; + if *index > 20 { + command_buffer.commit(); + command_buffer = self.command_queue.new_command_buffer().to_owned(); + *command_buffers = vec![command_buffer.clone()]; *index = 0; - // println!("Reset"); } - // println!("Giving buffer {} / {n}", *index); - let out = &command_buffers[*index]; - assert_eq!(out.status(), metal::MTLCommandBufferStatus::Enqueued); *index += 1; - out.to_owned() + command_buffer } pub fn wait_until_completed(&self) { - let command_buffers = self.command_buffers.try_write().unwrap(); - let index = self.command_buffer_index.try_write().unwrap(); - // let n = command_buffers.len(); - // for i in 0..*index { - // let command_buffer = &command_buffers[i]; - // println!("Command {i} / {n}: {:?}", command_buffer.status()); - // } - for i in 0..*index { - let command_buffer = &command_buffers[i]; - match command_buffer.status() { - metal::MTLCommandBufferStatus::Committed - | metal::MTLCommandBufferStatus::Scheduled => {} - metal::MTLCommandBufferStatus::Completed => {} - _ => { - panic!("Command buffer not committed"); - } + let mut command_buffers = self.command_buffers.try_write().unwrap(); + let command_buffer = &command_buffers[0]; + match command_buffer.status() { + metal::MTLCommandBufferStatus::Committed + | metal::MTLCommandBufferStatus::Scheduled + | metal::MTLCommandBufferStatus::Completed => { + panic!("Alredy committed"); } - // println!("Wait {i}"); - command_buffer.wait_until_completed(); - // println!("Ok {i}"); - // command_buffer.wait_until_completed(); + _ => {} } + command_buffer.commit(); + command_buffer.wait_until_completed(); + *command_buffers = vec![self.command_queue.new_command_buffer().to_owned()]; } pub fn kernels(&self) -> &Kernels { @@ -176,7 +141,7 @@ impl MetalDevice { } pub fn new_buffer_managed(&self, size: NSUInteger) -> Arc { - self._new_buffer(size, MTLResourceOptions::StorageModeShared, "managed") + self._new_buffer(size, MTLResourceOptions::StorageModeManaged, "managed") } pub fn new_buffer_with_data(&self, data: &[T]) -> Arc { @@ -184,7 +149,7 @@ impl MetalDevice { let tmp = self.device.new_buffer_with_data( data.as_ptr() as *const core::ffi::c_void, size, - metal::MTLResourceOptions::StorageModeShared, + metal::MTLResourceOptions::StorageModeManaged, ); let real = self._new_buffer( size, @@ -194,15 +159,15 @@ impl MetalDevice { let command_buffer = self.command_buffer(); command_buffer.set_label("with_data"); let blit = command_buffer.new_blit_command_encoder(); + blit.wait_for_fence(&self.fence); blit.set_label("with_data_blit"); blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length()); + blit.update_fence(&self.fence); blit.end_encoding(); - command_buffer.commit(); - drop(command_buffer); + // drop(command_buffer); // real.did_modify_range(metal::NSRange::new(0, real.length())); // println!("Command {:?}", command.status()); - // self.commit(); // This is necessary, for mmaped safetensors // Because of the unsafe slice cast we're doing. // The slice might not live long enough for metal @@ -259,19 +224,16 @@ impl BackendStorage for MetalStorage { self.dtype ); } - self.device.wait_until_completed(); - self.buffer - .did_modify_range(metal::NSRange::new(0, self.buffer.length())); let buffer = self.device.new_buffer_managed(self.buffer.length()); { let command_buffer = self.device.command_buffer(); command_buffer.set_label("to_cpu"); let blit = command_buffer.new_blit_command_encoder(); blit.set_label("blit_to_cpu"); + blit.wait_for_fence(&self.device.fence); blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length()); + blit.update_fence(&self.device.fence); blit.end_encoding(); - - command_buffer.commit(); } self.device.wait_until_completed(); @@ -338,8 +300,7 @@ impl BackendStorage for MetalStorage { ) .map_err(MetalError::from)?; } - command_buffer.commit(); - buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); + // buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(Self::new(buffer, device.clone(), dtype)) } @@ -389,8 +350,6 @@ impl BackendStorage for MetalStorage { ) .map_err(MetalError::from)?; } - command_buffer.commit(); - buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(Self::new(buffer, device.clone(), dtype)) } @@ -440,7 +399,6 @@ impl BackendStorage for MetalStorage { ) .map_err(MetalError::from)?; } - command_buffer.commit(); buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(Self::new(buffer, device.clone(), dtype)) } @@ -504,8 +462,6 @@ impl BackendStorage for MetalStorage { &buffer, ) .map_err(MetalError::from)?; - command_buffer.commit(); - buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(Self::new(buffer, device, dtype)) } @@ -519,7 +475,6 @@ impl BackendStorage for MetalStorage { let shape = layout.shape(); let el_count = shape.elem_count(); let buffer = device.new_buffer(el_count, dtype, "todtype"); - device.wait_until_completed(); let command_buffer = device.command_buffer(); if layout.is_contiguous() && layout.start_offset() == 0 { let kernel_name = match (self.dtype, dtype) { @@ -564,10 +519,6 @@ impl BackendStorage for MetalStorage { .map_err(MetalError::from)?; } command_buffer.set_label("to_dtype"); - command_buffer.commit(); - buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); - device.wait_until_completed(); - Ok(Self::new(buffer, device.clone(), dtype)) } @@ -668,8 +619,6 @@ impl BackendStorage for MetalStorage { ) .map_err(MetalError::from)?; } - command_buffer.commit(); - buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(Self::new(buffer, device.clone(), dtype)) } @@ -752,8 +701,6 @@ impl BackendStorage for MetalStorage { .map_err(MetalError::from)?; } command_buffer.set_label("binary"); - command_buffer.commit(); - buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(Self::new(buffer, device.clone(), dtype)) } @@ -798,8 +745,6 @@ impl BackendStorage for MetalStorage { &buffer, ) .map_err(MetalError::from)?; - command_buffer.commit(); - buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(Self::new(buffer, device, dtype)) } @@ -909,8 +854,6 @@ impl BackendStorage for MetalStorage { &buffer, ) .map_err(MetalError::from)?; - command_buffer.commit(); - buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(Self::new(buffer, device.clone(), dtype)) } @@ -963,8 +906,6 @@ impl BackendStorage for MetalStorage { ) .map_err(MetalError::from)?; // Create kernel - command_buffer.commit(); - self.device.wait_until_completed(); Ok(Self::new(buffer, self.device.clone(), self.dtype())) } @@ -1010,7 +951,6 @@ impl BackendStorage for MetalStorage { .map_err(MetalError::from)?; command_buffer.set_label("copy_strided"); } - command_buffer.commit(); Ok(()) } } @@ -1036,7 +976,7 @@ impl BackendDevice for MetalDevice { // println!("CREATING DEVICE"); let device = metal::Device::all().swap_remove(ordinal); - let n = 64; + let n = 1; let command_queue = device.new_command_queue(); let command_buffers = (0..n) @@ -1049,10 +989,12 @@ impl BackendDevice for MetalDevice { .collect(); let command_buffers = Arc::new(RwLock::new(command_buffers)); let command_buffer_index = Arc::new(RwLock::new(0)); - let kernels = Arc::new(Kernels::new()); + let fence = device.new_fence(); + let kernels = Arc::new(Kernels::new(fence.clone())); let buffers = Arc::new(RwLock::new(HashMap::new())); Ok(Self { device, + fence, command_queue, command_buffers, command_buffer_index, @@ -1089,8 +1031,6 @@ impl BackendDevice for MetalDevice { 0, ); blit.end_encoding(); - command_buffer.commit(); - buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(MetalStorage::new(buffer, self.clone(), dtype)) } diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index c871dc96..a77f9c3a 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -900,7 +900,9 @@ fn matmul(device: &Device) -> Result<()> { let b = Tensor::from_slice(&data, (2, 2), device)?; let c = a.matmul(&b)?; + let d = a.matmul(&c)?; assert_eq!(c.to_vec2::()?, &[[7.0f32, 10.0], [15.0, 22.0]]); + assert_eq!(d.to_vec2::()?, &[[37.0, 54.0], [81.0, 118.0]]); let data = vec![1.0f32, 2.0]; let a = Tensor::from_slice(&data, (2, 1), device)?; diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index b80dcb79..01432ccb 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -184,19 +184,21 @@ impl From> for MetalKernelError { type Libraries = HashMap; type Pipelines = HashMap<(&'static str, Option), ComputePipelineState>; -#[derive(Debug, Default)] +#[derive(Debug)] pub struct Kernels { libraries: RwLock, pipelines: RwLock, + fence: metal::Fence, } impl Kernels { - pub fn new() -> Self { + pub fn new(fence: metal::Fence) -> Self { let libraries = RwLock::new(Libraries::new()); let pipelines = RwLock::new(Pipelines::new()); Self { libraries, pipelines, + fence, } } @@ -304,12 +306,14 @@ pub fn call_unary_contiguous( ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?; let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (length, input, output)); let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -331,6 +335,7 @@ pub fn call_unary_strided( let num_dims: usize = shape.len(); let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); let length: usize = shape.iter().product(); @@ -350,6 +355,7 @@ pub fn call_unary_strided( let (thread_group_count, thread_group_size) = linear_split(&pipeline, width); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -368,6 +374,7 @@ pub fn call_binary_contiguous( let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?; let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (length, left, right, output)); @@ -375,6 +382,7 @@ pub fn call_binary_contiguous( let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -399,6 +407,7 @@ pub fn call_binary_strided( let num_dims: usize = shape.len(); let encoder = command_buffer.new_compute_command_encoder(); let width: usize = shape.iter().product(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); let length: usize = shape.iter().product(); @@ -420,6 +429,7 @@ pub fn call_binary_strided( let (thread_group_count, thread_group_size) = linear_split(&pipeline, width); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -438,12 +448,14 @@ pub fn call_cast_contiguous( let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?; let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (length, (input, input_offset), output)); let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -463,6 +475,7 @@ pub fn call_cast_strided( let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?; let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); let length: usize = shape.iter().product(); @@ -482,6 +495,7 @@ pub fn call_cast_strided( let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -501,6 +515,7 @@ pub fn call_reduce_contiguous( let elements_to_sum = length / out_length; let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -527,6 +542,7 @@ pub fn call_reduce_contiguous( }; encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -544,6 +560,7 @@ pub fn call_last_softmax( ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (length, elements_to_sum, input, output)); @@ -569,6 +586,7 @@ pub fn call_last_softmax( }; encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -588,12 +606,14 @@ pub fn call_affine( let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (size, mul, add, input, output)); let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -616,6 +636,7 @@ pub fn call_affine_strided( let size: usize = shape.iter().product(); let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -634,6 +655,7 @@ pub fn call_affine_strided( let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -652,12 +674,14 @@ pub fn call_powf( let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (size, mul, input, output)); let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -679,6 +703,7 @@ pub fn call_powf_strided( let size: usize = shape.iter().product(); let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -696,6 +721,7 @@ pub fn call_powf_strided( let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -714,12 +740,14 @@ pub fn call_elu( let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (size, mul, input, output)); let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -741,6 +769,7 @@ pub fn call_elu_strided( let size: usize = shape.iter().product(); let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -758,6 +787,7 @@ pub fn call_elu_strided( let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -779,6 +809,7 @@ pub fn call_where_cond_strided( let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?; let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); let size: usize = shape.iter().product(); @@ -803,6 +834,7 @@ pub fn call_where_cond_strided( let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -829,6 +861,7 @@ pub fn call_index_select( let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -848,6 +881,7 @@ pub fn call_index_select( let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -1045,6 +1079,7 @@ pub fn call_gemm( let block_bytes = block_elements * bytes; let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); // println!("Threadgroup {block_bytes}"); encoder.set_threadgroup_memory_length(0, block_bytes.into()); @@ -1087,6 +1122,7 @@ pub fn call_gemm( }; // println!("grid size {grid_size:?} group size {group_size:?}"); encoder.dispatch_thread_groups(grid_size, group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) diff --git a/candle-metal-kernels/src/test.swift b/candle-metal-kernels/src/test.swift new file mode 100644 index 00000000..f9bb9f91 --- /dev/null +++ b/candle-metal-kernels/src/test.swift @@ -0,0 +1,209 @@ + +import Metal +import MetalPerformanceShadersGraph + + + +let type = MTLDataType.float; +let dataType = type; +var B = 2; +var M = 2; +var N = 2; +var K = 2; +var A_trans = false; +var B_trans = false; +var D_trans = false; +var alpha = Float(1.0); +var beta = Float(0.0); +var batched = B > 1; +var fused_activation = false; +var fused_bias = false; +let constants = MTLFunctionConstantValues() +constants.setConstantValue(&M, type: .uint, index: 0) +constants.setConstantValue(&N, type: .uint, index: 1) +constants.setConstantValue(&K, type: .uint, index: 2) +constants.setConstantValue(&A_trans, type: .bool, index: 10) +constants.setConstantValue(&B_trans, type: .bool, index: 11) +constants.setConstantValue(&D_trans, type: .bool, index: 13) +constants.setConstantValue(&alpha, type: .float, index: 20) +constants.setConstantValue(&beta, type: .float, index: 21) +constants.setConstantValue(&batched, type: .bool, index: 100) +constants.setConstantValue(&fused_activation, type: .bool, index: 101) +constants.setConstantValue(&fused_bias, type: .bool, index: 50001) + + +var M_simd = UInt16(16) +var N_simd = UInt16(16) +var K_simd = UInt16(32) +var M_splits = UInt16(2) +var N_splits = UInt16(2) +constants.setConstantValue(&M_simd, type: .ushort, index: 200) +constants.setConstantValue(&N_simd, type: .ushort, index: 201) +constants.setConstantValue(&K_simd, type: .ushort, index: 202) +constants.setConstantValue(&M_splits, type: .ushort, index: 210) +constants.setConstantValue(&N_splits, type: .ushort, index: 211) + +let M_group = M_simd * M_splits +let N_group = N_simd * N_splits + +// Satisfy Metal API validation. +#if DEBUG +do { + var garbage: SIMD4 = .zero + constants.setConstantValue(&garbage, type: .bool, index: 102) + constants.setConstantValue(&garbage, type: .bool, index: 103) + constants.setConstantValue(&garbage, type: .bool, index: 113) + constants.setConstantValue(&garbage, type: .bool, index: 50000) +} +#endif + +let device = MTLCopyAllDevices().first! +device.shouldMaximizeConcurrentCompilation = true + +var libraryURL = URL.init(string: "/Users/nicolas/src/candle/candle-metal-kernels/")!; +libraryURL.append(component: "src") +libraryURL.append(component: "libMetalFlashAttention.metallib") +let library = try! device.makeLibrary(URL: libraryURL) + +var name: String + switch dataType { + case .half: name = "hgemm" + case .float: name = "sgemm" + default: fatalError() + } +let function = try! library.makeFunction( + name: name, constantValues: constants) + +let A_block_length = M_group * K_simd +let B_block_length = K_simd * N_group + +var blockElements = A_block_length + B_block_length; +if (M % 8 != 0) && (N % 8 != 0) { + let C_block_length = M_group * N_group; + blockElements = max(C_block_length, blockElements) +} +if fused_bias { + if D_trans { + blockElements = max(blockElements, M_group) + } else { + blockElements = max(blockElements, N_group) + } +} +// let blockBytes = blockElements * UInt16(dataType.size) +let elementSize = 4 +let blockBytes = blockElements * UInt16(elementSize) + +func ceilDivide(target: Int, granularity: UInt16) -> Int { + (target + Int(granularity) - 1) / Int(granularity) +} +var gridSize = MTLSize( + width: ceilDivide(target: N, granularity: N_group), + height: ceilDivide(target: M, granularity: M_group), + depth: 1) +let groupSize = MTLSize( + width: Int(32 * M_splits * N_splits), + height: 1, + depth: 1) + +let commandQueue = device.makeCommandQueue()! + +let threadgroupMemoryLength = blockBytes; + +let rowsA = M; +let columnsA = K; +let rowsB = K; +let columnsB = N; +let rowsC = M; +let columnsC = N; +var arrayA = [Float](repeating: 0, count: B * rowsA * columnsA) + +var arrayB = [Float](repeating: 0, count: B * rowsB * columnsB) + +var arrayC = [Float](repeating: 0, count: B * rowsC * columnsC) +var arrayD = [Float](repeating: 0, count: B * rowsC * columnsC) +for i in 0...stride, options: [])! + +let bufferB = device.makeBuffer(bytes: arrayB, length: B * rowsB * columnsB * MemoryLayout.stride, options: [])! + +let bufferC = device.makeBuffer(length: B * rowsC * columnsC * MemoryLayout.stride, options: [])! +let bufferD = device.makeBuffer(length: B * rowsC * columnsC * MemoryLayout.stride, options: [])! + + +let pipeline = try device.makeComputePipelineState(function: function) + +func call(bufferA: MTLBuffer, bufferB: MTLBuffer, bufferC: MTLBuffer){ + let encoder = commandBuffer.makeComputeCommandEncoder(dispatchType: MTLDispatchType.serial)! + encoder.setComputePipelineState(pipeline) + encoder.setThreadgroupMemoryLength(Int(threadgroupMemoryLength), index: 0) + + encoder.setBuffer(bufferA, offset: 0, index: 0) + encoder.setBuffer(bufferB, offset: 0, index: 1) + encoder.setBuffer(bufferC, offset: 0, index: 2) + let gridZ: Int = B + if batched{ + func byteStride(shape: [Int]) -> Int { + let rank = shape.count + var output = elementSize * shape[rank - 2] * shape[rank - 1] + if shape.dropLast(2).reduce(1, *) == 1 { + output = 0 + } + return output + } + let byteStrideA = M*K*elementSize + let byteStrideB = N*K*elementSize + let byteStrideC = M*N*elementSize + + let byteStrideD = 0 + withUnsafeTemporaryAllocation( + of: SIMD4.self, capacity: gridZ + ) { buffer in + for i in 0..>.stride + assert(MemoryLayout>.stride == 8 * 4) + encoder.setBytes(buffer.baseAddress!, length: bufferLength, index: 10) + } + } + gridSize.depth = gridZ + + + encoder.dispatchThreadgroups( + gridSize, threadsPerThreadgroup: groupSize + ) + encoder.endEncoding() +} + +var commandBuffer = commandQueue.makeCommandBuffer()! +call(bufferA:bufferA, bufferB:bufferB, bufferC:bufferC) +commandBuffer.commit() +commandBuffer = commandQueue.makeCommandBuffer()! +commandBuffer.encodeWaitForEvent(event, value: 2) +call(bufferA:bufferA, bufferB:bufferC, bufferC:bufferD) +commandBuffer.commit() + +commandBuffer.waitUntilCompleted() +var contents = bufferC.contents(); +var count = B * rowsA * columnsB; +var typedPointer = contents.bindMemory(to: Float.self, capacity: count) +var bufferedPointer = UnsafeBufferPointer(start: typedPointer, count: count) +print("First matmul is OK", Array(bufferedPointer)) + +contents = bufferD.contents(); +count = B * rowsA * columnsB; +typedPointer = contents.bindMemory(to: Float.self, capacity: count) +bufferedPointer = UnsafeBufferPointer(start: typedPointer, count: count) +print("This should be filled", Array(bufferedPointer)) diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index 14dd10de..e002d931 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -238,8 +238,6 @@ impl candle::CustomOp1 for SoftmaxLastDim { &mut output, ) .unwrap(); - command_buffer.commit(); - output.did_modify_range(metal::NSRange::new(0, output.length())); let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype()); Ok((newstorage, layout.shape().clone())) } -- cgit v1.2.3 From ece4c69a681215837fd5a008e2ee652394daa8ed Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 15 Dec 2023 01:35:08 +0100 Subject: Fixing softmax. --- candle-core/src/metal_backend.rs | 10 ++++++---- candle-metal-kernels/src/reduce.metal | 11 +++++++---- candle-nn/src/ops.rs | 2 +- candle-transformers/src/models/mixformer.rs | 4 ---- 4 files changed, 14 insertions(+), 13 deletions(-) (limited to 'candle-nn') diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index d38796a1..b8b951f0 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -113,21 +113,23 @@ impl MetalDevice { self._new_buffer(size, MTLResourceOptions::StorageModePrivate, name) } - fn _new_buffer(&self, size: NSUInteger, option: MTLResourceOptions, name: &str) -> Arc { - // println!("Creating new buffer {name}"); + fn _new_buffer( + &self, + size: NSUInteger, + option: MTLResourceOptions, + _name: &str, + ) -> Arc { let mut buffers = self.buffers.try_write().unwrap(); let subbuffers = buffers.entry((size, option)).or_insert(vec![]); for sub in &mut *subbuffers { if Arc::strong_count(sub) == 1 { - // println!("Reusing tensor {size} {name}"); return sub.clone(); } } let new_buffer = self.device.new_buffer(size as NSUInteger, option); let new_buffer = Arc::new(new_buffer); subbuffers.push(new_buffer.clone()); - // println!("Created tensor {size} {name}"); for subbuffers in buffers.values_mut() { let newbuffers = subbuffers .iter() diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal index 53e4664a..3633fdcf 100644 --- a/candle-metal-kernels/src/reduce.metal +++ b/candle-metal-kernels/src/reduce.metal @@ -67,7 +67,6 @@ kernel void NAME( \ threadgroup_barrier(mem_flags::mem_none); \ } \ \ - threadgroup_barrier(mem_flags::mem_none); \ dst[dst_id] = shared_memory[0]; \ } \ @@ -94,11 +93,10 @@ kernel void NAME( size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \ size_t idx = start_idx + tid; \ \ - threadgroup_barrier(mem_flags::mem_threadgroup); \ \ - float tmp = 0; \ + float tmp = -INFINITY; \ while (idx < stop_idx) { \ - tmp = MAX(tmp, src[idx]); \ + tmp = MAX(tmp, float(src[idx])); \ idx += block_dim; \ } \ shared_memory[tid] = tmp; \ @@ -109,12 +107,15 @@ kernel void NAME( if (tid < s) { \ shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]); \ } \ + threadgroup_barrier(mem_flags::mem_threadgroup); \ } \ \ + /* wait for shared_memory[0] to be filled */ \ threadgroup_barrier(mem_flags::mem_threadgroup); \ \ float _max = shared_memory[0]; \ \ + /* prevent tid=0 from overwriting _max before other threads have written it */ \ threadgroup_barrier(mem_flags::mem_threadgroup); \ shared_memory[tid] = 0; \ \ @@ -125,10 +126,12 @@ kernel void NAME( shared_memory[tid] += val; \ idx += block_dim; \ } \ + threadgroup_barrier(mem_flags::mem_threadgroup); \ for (uint s = block_dim / 2; s > 0; s >>= 1) { \ if (tid < s) { \ shared_memory[tid] += shared_memory[tid + s]; \ } \ + threadgroup_barrier(mem_flags::mem_threadgroup); \ } \ \ const T inv_acc = T(1.0/shared_memory[0]); \ diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index e002d931..f00d8e2f 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -220,7 +220,7 @@ impl candle::CustomOp1 for SoftmaxLastDim { }; let n = layout.stride().len(); - if !(layout.stride()[n - 1] == 1 && layout.start_offset() == 0) { + if !(layout.is_contiguous() && layout.stride()[n - 1] == 1 && layout.start_offset() == 0) { candle::bail!("Non contiguous softmax-last-dim is not implemented"); } diff --git a/candle-transformers/src/models/mixformer.rs b/candle-transformers/src/models/mixformer.rs index 3f9aa47d..e4e4f619 100644 --- a/candle-transformers/src/models/mixformer.rs +++ b/candle-transformers/src/models/mixformer.rs @@ -272,10 +272,6 @@ impl MHA { } fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result { - // let view = xs.to_string(); - // if view.contains("NaN") { - // panic!("NaN"); - // } let _enter = self.span.enter(); let (b_size, seq_len, _n_embd) = xs.dims3()?; let qkv = self -- cgit v1.2.3 From 26540641c1f0a7b351f5e3d3c3c165221ae1d9ed Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 15 Dec 2023 11:24:47 +0100 Subject: Renamed all kernel names. --- candle-core/src/metal_backend.rs | 34 +++++++++++++++++----------------- candle-metal-kernels/src/affine.metal | 18 +++++++++--------- candle-metal-kernels/src/binary.metal | 6 +++--- candle-metal-kernels/src/lib.rs | 24 ++++++++++++------------ candle-metal-kernels/src/reduce.metal | 12 ++++++------ candle-metal-kernels/src/unary.metal | 12 ++++++------ candle-nn/src/ops.rs | 6 +++--- 7 files changed, 56 insertions(+), 56 deletions(-) (limited to 'candle-nn') diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index d8518b3e..b4a490cd 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -314,8 +314,8 @@ impl BackendStorage for MetalStorage { let command_buffer = self.device.command_buffer(); if layout.is_contiguous() && layout.start_offset() == 0 { let name = match self.dtype { - DType::F32 => "affine_float", - DType::F16 => "affine_half", + DType::F32 => "affine_f32", + DType::F16 => "affine_f16", dtype => crate::bail!("Affine {dtype:?}"), }; candle_metal_kernels::call_affine( @@ -332,8 +332,8 @@ impl BackendStorage for MetalStorage { .map_err(MetalError::from)?; } else { let name = match self.dtype { - DType::F32 => "affine_float_strided", - DType::F16 => "affine_half_strided", + DType::F32 => "affine_f32_strided", + DType::F16 => "affine_f16_strided", dtype => crate::bail!("Affine {dtype:?}"), }; candle_metal_kernels::call_affine_strided( @@ -365,8 +365,8 @@ impl BackendStorage for MetalStorage { let command_buffer = self.device.command_buffer(); if layout.is_contiguous() && layout.start_offset() == 0 { let name = match self.dtype { - DType::F32 => "powf_float", - DType::F16 => "powf_half", + DType::F32 => "powf_f32", + DType::F16 => "powf_f16", dtype => crate::bail!("Powf {dtype:?}"), }; candle_metal_kernels::call_powf( @@ -382,8 +382,8 @@ impl BackendStorage for MetalStorage { .map_err(MetalError::from)?; } else { let name = match self.dtype { - DType::F32 => "powf_float_strided", - DType::F16 => "powf_half_strided", + DType::F32 => "powf_f32_strided", + DType::F16 => "powf_f16_strided", dtype => crate::bail!("Powf {dtype:?}"), }; candle_metal_kernels::call_powf_strided( @@ -414,8 +414,8 @@ impl BackendStorage for MetalStorage { let command_buffer = self.device.command_buffer(); if layout.is_contiguous() && layout.start_offset() == 0 { let name = match self.dtype { - DType::F32 => "elu_float", - DType::F16 => "elu_half", + DType::F32 => "elu_f32", + DType::F16 => "elu_f16", dtype => crate::bail!("Powf {dtype:?}"), }; candle_metal_kernels::call_elu( @@ -431,8 +431,8 @@ impl BackendStorage for MetalStorage { .map_err(MetalError::from)?; } else { let name = match self.dtype { - DType::F32 => "elu_float_strided", - DType::F16 => "elu_half_strided", + DType::F32 => "elu_f32_strided", + DType::F16 => "elu_f16_strided", dtype => crate::bail!("Powf {dtype:?}"), }; candle_metal_kernels::call_elu_strided( @@ -483,11 +483,11 @@ impl BackendStorage for MetalStorage { // The reduction loop requires the shared array to be properly initialized and for // this we want the number of threads to be a power of two. let (name, check_empty, return_index) = match (op, self.dtype) { - (ReduceOp::Sum, DType::F32) => ("fast_sum_float", false, false), - (ReduceOp::Min, DType::F32) => ("fast_min_float", true, false), - (ReduceOp::Max, DType::F32) => ("fast_max_float", true, false), - (ReduceOp::ArgMin, DType::F32) => ("fast_argmin_float", true, true), - (ReduceOp::ArgMax, DType::F32) => ("fast_argmax_float", true, true), + (ReduceOp::Sum, DType::F32) => ("fast_sum_f32", false, false), + (ReduceOp::Min, DType::F32) => ("fast_min_f32", true, false), + (ReduceOp::Max, DType::F32) => ("fast_max_f32", true, false), + (ReduceOp::ArgMin, DType::F32) => ("fast_argmin_f32", true, true), + (ReduceOp::ArgMax, DType::F32) => ("fast_argmax_f32", true, true), _ => crate::bail!("Reduce op for non float"), }; if check_empty && layout.shape().elem_count() == 0 { diff --git a/candle-metal-kernels/src/affine.metal b/candle-metal-kernels/src/affine.metal index 18adb457..4166d811 100644 --- a/candle-metal-kernels/src/affine.metal +++ b/candle-metal-kernels/src/affine.metal @@ -109,16 +109,16 @@ kernel void FN_NAME##_strided( \ } \ -AFFINE(affine_float, float) -AFFINE(affine_half, half) -POWF(powf_float, float) -POWF(powf_half, half) -ELU(elu_float, float) -ELU(elu_half, half) +AFFINE(affine_f32, float) +AFFINE(affine_f16, half) +POWF(powf_f32, float) +POWF(powf_f16, half) +ELU(elu_f32, float) +ELU(elu_f16, half) #if __METAL_VERSION__ >= 310 -AFFINE(affine_bfloat, bfloat); -POWF(powf_bfloat, bfloat); -ELU(elu_bfloat, bfloat); +AFFINE(affine_bf16, bfloat); +POWF(powf_bf16, bfloat); +ELU(elu_bf16, bfloat); #endif diff --git a/candle-metal-kernels/src/binary.metal b/candle-metal-kernels/src/binary.metal index f18cdbb0..ea21bb34 100644 --- a/candle-metal-kernels/src/binary.metal +++ b/candle-metal-kernels/src/binary.metal @@ -52,11 +52,11 @@ kernel void FN_NAME_STRIDED( \ } #define BINARY_OP(FN, NAME) \ -BINARY(FN, float, float, NAME##_float, NAME##_float_strided); \ -BINARY(FN, half, half, NAME##_half, NAME##_half_strided); +BINARY(FN, float, float, NAME##_f32, NAME##_f32_strided); \ +BINARY(FN, half, half, NAME##_f16, NAME##_f16_strided); #define BFLOAT_BINARY_OP(FN, NAME) \ -BINARY(FN, bfloat, bfloat, NAME##_bfloat, NAME##_bfloat_strided); +BINARY(FN, bfloat, bfloat, NAME##_bf16, NAME##_bf16_strided); BINARY_OP(x + y, add) diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 514cf33e..a23aa47c 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -125,16 +125,16 @@ macro_rules! ops{ $( pub mod $name { use super::Kernel; - pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_float")); - pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_half")); - pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat")); + pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32")); + pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16")); + pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16")); } )+ pub mod copy { use super::Kernel; - pub const FLOAT: Kernel = Kernel("copy_float"); - pub const HALF: Kernel = Kernel("copy_half"); - pub const BFLOAT: Kernel = Kernel("copy_bfloat"); + pub const FLOAT: Kernel = Kernel("copy_f32"); + pub const HALF: Kernel = Kernel("copy_f16"); + pub const BFLOAT: Kernel = Kernel("copy_bf16"); pub const U32: Kernel = Kernel("copy_u32"); pub const U8: Kernel = Kernel("copy_u8"); } @@ -145,16 +145,16 @@ macro_rules! ops{ $( pub mod $name { use super::Kernel; - pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_float_strided")); - pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_half_strided")); - pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat_strided")); + pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32_strided")); + pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16_strided")); + pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16_strided")); } )+ pub mod copy { use super::Kernel; - pub const FLOAT: Kernel = Kernel("copy_float_strided"); - pub const HALF: Kernel = Kernel("copy_half_strided"); - pub const BFLOAT: Kernel = Kernel("copy_bfloat_strided"); + pub const FLOAT: Kernel = Kernel("copy_f32_strided"); + pub const HALF: Kernel = Kernel("copy_f16_strided"); + pub const BFLOAT: Kernel = Kernel("copy_bf16_strided"); pub const U32: Kernel = Kernel("copy_u32_strided"); pub const U8: Kernel = Kernel("copy_u8_strided"); } diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal index 3633fdcf..62443660 100644 --- a/candle-metal-kernels/src/reduce.metal +++ b/candle-metal-kernels/src/reduce.metal @@ -71,9 +71,9 @@ kernel void NAME( \ } \ -REDUCE(x + y, fast_sum_float, float) -REDUCE(x * y, fast_mul_float, float) -REDUCE(max(x, y), fast_max_float, float) +REDUCE(x + y, fast_sum_f32, float) +REDUCE(x * y, fast_mul_f32, float) +REDUCE(max(x, y), fast_max_f32, float) #define SOFTMAX(NAME, T) \ kernel void NAME( \ @@ -142,8 +142,8 @@ kernel void NAME( } \ } \ -SOFTMAX(softmax_float, float) -SOFTMAX(softmax_half, half) +SOFTMAX(softmax_f32, float) +SOFTMAX(softmax_f16, half) #if __METAL_VERSION__ >= 310 -SOFTMAX(softmax_bfloat, bfloat) +SOFTMAX(softmax_bf16, bfloat) #endif diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index 765b14a5..553bc506 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -87,11 +87,11 @@ kernel void FN_NAME_STRIDED( \ } #define UNARY_OP(NAME) \ -UNARY(NAME, float, NAME##_float, NAME##_float_strided); \ -UNARY(NAME, half, NAME##_half, NAME##_half_strided); +UNARY(NAME, float, NAME##_f32, NAME##_f32_strided); \ +UNARY(NAME, half, NAME##_f16, NAME##_f16_strided); #define BFLOAT_UNARY_OP(NAME) \ -UNARY(NAME, bfloat, NAME##_bfloat, NAME##_bfloat_strided); +UNARY(NAME, bfloat, NAME##_bf16, NAME##_bf16_strided); UNARY_OP(cos) @@ -108,8 +108,8 @@ UNARY_OP(round) UNARY_OP(gelu_erf) UNARY_OP(erf) UNARY_OP(tanh) -UNARY(id, float, copy_float, copy_float_strided) -UNARY(id, half, copy_half, copy_half_strided) +UNARY(id, float, copy_f32, copy_f32_strided) +UNARY(id, half, copy_f16, copy_f16_strided) UNARY(id, uint8_t, copy_u8, copy_u8_strided) UNARY(id, uint32_t, copy_u32, copy_u32_strided) @@ -129,5 +129,5 @@ BFLOAT_UNARY_OP(gelu_erf) BFLOAT_UNARY_OP(erf) BFLOAT_UNARY_OP(tanh) -UNARY(id, bfloat, copy_bfloat, copy_bfloat_strided) +UNARY(id, bfloat, copy_bf16, copy_bf16_strided) #endif diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index f00d8e2f..ca23f90e 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -213,9 +213,9 @@ impl candle::CustomOp1 for SoftmaxLastDim { let command_buffer = device.command_buffer(); let kernels = device.kernels(); let name = match storage.dtype() { - DType::F32 => "softmax_float", - DType::F16 => "softmax_half", - DType::BF16 => "softmax_bfloat", + DType::F32 => "softmax_f32", + DType::F16 => "softmax_f16", + DType::BF16 => "softmax_bf16", dtype => candle::bail!("softmax-last-dim is not implemented for {dtype:?}"), }; -- cgit v1.2.3 From aa040150985e78079bcc05df86266e447c23b4fc Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 15 Dec 2023 12:23:28 +0100 Subject: Remove `unwrap()`. --- candle-core/src/metal_backend.rs | 121 ++++++++++++++++++++++++--------------- candle-nn/src/ops.rs | 4 +- 2 files changed, 77 insertions(+), 48 deletions(-) (limited to 'candle-nn') diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index b4a490cd..f570d2c5 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -8,7 +8,26 @@ use metal; use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger}; use std::collections::HashMap; use std::path::Path; -use std::sync::{Arc, RwLock}; +use std::sync::{Arc, RwLock, TryLockError}; + +/// Simple way to catch lock error without +/// depending on T +#[derive(thiserror::Error, Debug)] +pub enum LockError { + #[error("{0}")] + Poisoned(String), + #[error("Would block")] + WouldBlock, +} + +impl From> for MetalError { + fn from(value: TryLockError) -> Self { + match value { + TryLockError::Poisoned(p) => MetalError::LockError(LockError::Poisoned(p.to_string())), + TryLockError::WouldBlock => MetalError::LockError(LockError::WouldBlock), + } + } +} /// Metal related errors #[derive(thiserror::Error, Debug)] @@ -24,6 +43,8 @@ pub enum MetalError { rhs_stride: Vec, mnk: (usize, usize, usize), }, + #[error("{0:?}")] + LockError(LockError), } impl From for MetalError { @@ -106,10 +127,13 @@ impl MetalDevice { &self.command_queue } - pub fn command_buffer(&self) -> CommandBuffer { - let mut command_buffer_lock = self.command_buffer.try_write().unwrap(); + pub fn command_buffer(&self) -> Result { + let mut command_buffer_lock = self.command_buffer.try_write().map_err(MetalError::from)?; let mut command_buffer = command_buffer_lock.to_owned(); - let mut index = self.command_buffer_index.try_write().unwrap(); + let mut index = self + .command_buffer_index + .try_write() + .map_err(MetalError::from)?; if *index > self.compute_per_buffer { command_buffer.commit(); command_buffer = self.command_queue.new_command_buffer().to_owned(); @@ -117,11 +141,11 @@ impl MetalDevice { *index = 0; } *index += 1; - command_buffer + Ok(command_buffer) } - pub fn wait_until_completed(&self) { - let mut command_buffer = self.command_buffer.try_write().unwrap(); + pub fn wait_until_completed(&self) -> Result<()> { + let mut command_buffer = self.command_buffer.try_write().map_err(MetalError::from)?; match command_buffer.status() { metal::MTLCommandBufferStatus::Committed | metal::MTLCommandBufferStatus::Scheduled @@ -133,6 +157,7 @@ impl MetalDevice { command_buffer.commit(); command_buffer.wait_until_completed(); *command_buffer = self.command_queue.new_command_buffer().to_owned(); + Ok(()) } pub fn kernels(&self) -> &Kernels { @@ -148,7 +173,12 @@ impl MetalDevice { /// This means the buffer data cannot be read on the CPU directly. /// /// [`name`] is only used to keep track of the resource origin in case of bugs - pub fn new_buffer(&self, element_count: usize, dtype: DType, name: &str) -> Arc { + pub fn new_buffer( + &self, + element_count: usize, + dtype: DType, + name: &str, + ) -> Result> { let size = (element_count * dtype.size_in_bytes()) as NSUInteger; self.allocate_buffer(size, MTLResourceOptions::StorageModePrivate, name) } @@ -158,7 +188,7 @@ impl MetalDevice { /// This means the buffer can be read on the CPU but will require manual /// synchronization when the CPU memory is modified /// Used as a bridge to gather data back from the GPU - pub fn new_buffer_managed(&self, size: NSUInteger) -> Arc { + pub fn new_buffer_managed(&self, size: NSUInteger) -> Result> { self.allocate_buffer(size, MTLResourceOptions::StorageModeManaged, "managed") } @@ -168,7 +198,7 @@ impl MetalDevice { /// This method will block the computation because of the /// lack of lifetime management through the GPU. /// Internal comment for technical details. - pub fn new_buffer_with_data(&self, data: &[T]) -> Arc { + pub fn new_buffer_with_data(&self, data: &[T]) -> Result> { let size = core::mem::size_of_val(data) as NSUInteger; let tmp = self.device.new_buffer_with_data( data.as_ptr() as *const core::ffi::c_void, @@ -179,8 +209,8 @@ impl MetalDevice { size, metal::MTLResourceOptions::StorageModePrivate, "with_data", - ); - let command_buffer = self.command_buffer(); + )?; + let command_buffer = self.command_buffer()?; command_buffer.set_label("with_data"); let blit = command_buffer.new_blit_command_encoder(); blit.wait_for_fence(&self.fence); @@ -196,8 +226,8 @@ impl MetalDevice { // Putting this wait forces the GPU buffer to be filled // with the actual data allowing the CPU storage todo // deallocate properly. - self.wait_until_completed(); - real + self.wait_until_completed()?; + Ok(real) } /// The critical allocator algorithm @@ -206,13 +236,13 @@ impl MetalDevice { size: NSUInteger, option: MTLResourceOptions, _name: &str, - ) -> Arc { - let mut buffers = self.buffers.try_write().unwrap(); + ) -> Result> { + let mut buffers = self.buffers.try_write().map_err(MetalError::from)?; let subbuffers = buffers.entry((size, option)).or_insert(vec![]); for sub in &mut *subbuffers { if Arc::strong_count(sub) == 1 { - return sub.clone(); + return Ok(sub.clone()); } } let new_buffer = self.device.new_buffer(size as NSUInteger, option); @@ -226,8 +256,7 @@ impl MetalDevice { .collect(); *subbuffers = newbuffers; } - - new_buffer + Ok(new_buffer) } /// Create a metal GPU capture trace on [`path`]. @@ -279,9 +308,9 @@ impl BackendStorage for MetalStorage { self.dtype ); } - let buffer = self.device.new_buffer_managed(self.buffer.length()); + let buffer = self.device.new_buffer_managed(self.buffer.length())?; { - let command_buffer = self.device.command_buffer(); + let command_buffer = self.device.command_buffer()?; command_buffer.set_label("to_cpu"); let blit = command_buffer.new_blit_command_encoder(); blit.set_label("blit_to_cpu"); @@ -290,7 +319,7 @@ impl BackendStorage for MetalStorage { blit.update_fence(&self.device.fence); blit.end_encoding(); } - self.device.wait_until_completed(); + self.device.wait_until_completed()?; match self.dtype { DType::U8 => Ok(CpuStorage::U8(read_to_vec(&buffer, length / size))), @@ -310,8 +339,8 @@ impl BackendStorage for MetalStorage { let el = shape.elem_count(); let dtype = self.dtype; - let buffer = device.new_buffer(el, self.dtype, "affine"); - let command_buffer = self.device.command_buffer(); + let buffer = device.new_buffer(el, self.dtype, "affine")?; + let command_buffer = self.device.command_buffer()?; if layout.is_contiguous() && layout.start_offset() == 0 { let name = match self.dtype { DType::F32 => "affine_f32", @@ -361,8 +390,8 @@ impl BackendStorage for MetalStorage { let el = shape.elem_count(); let dtype = self.dtype; - let buffer = device.new_buffer(el, self.dtype, "powf"); - let command_buffer = self.device.command_buffer(); + let buffer = device.new_buffer(el, self.dtype, "powf")?; + let command_buffer = self.device.command_buffer()?; if layout.is_contiguous() && layout.start_offset() == 0 { let name = match self.dtype { DType::F32 => "powf_f32", @@ -410,8 +439,8 @@ impl BackendStorage for MetalStorage { let el = shape.elem_count(); let dtype = self.dtype; - let buffer = device.new_buffer(el, self.dtype, "elu"); - let command_buffer = self.device.command_buffer(); + let buffer = device.new_buffer(el, self.dtype, "elu")?; + let command_buffer = self.device.command_buffer()?; if layout.is_contiguous() && layout.start_offset() == 0 { let name = match self.dtype { DType::F32 => "elu_f32", @@ -497,8 +526,8 @@ impl BackendStorage for MetalStorage { if dtype == DType::U32 { crate::bail!("Implement return index reduce op"); } - let buffer = device.new_buffer(dst_el, dtype, "reduce"); - let command_buffer = self.device.command_buffer(); + let buffer = device.new_buffer(dst_el, dtype, "reduce")?; + let command_buffer = self.device.command_buffer()?; candle_metal_kernels::call_reduce_contiguous( &device.device, &command_buffer, @@ -523,8 +552,8 @@ impl BackendStorage for MetalStorage { let device = self.device(); let shape = layout.shape(); let el_count = shape.elem_count(); - let buffer = device.new_buffer(el_count, dtype, "todtype"); - let command_buffer = device.command_buffer(); + let buffer = device.new_buffer(el_count, dtype, "todtype")?; + let command_buffer = device.command_buffer()?; if layout.is_contiguous() && layout.start_offset() == 0 { let kernel_name = match (self.dtype, dtype) { (DType::U32, DType::F32) => "cast_u32_f32", @@ -576,8 +605,8 @@ impl BackendStorage for MetalStorage { let dtype = self.dtype; let shape = layout.shape(); let el_count = shape.elem_count(); - let buffer = device.new_buffer(el_count, dtype, B::KERNEL); - let command_buffer = device.command_buffer(); + let buffer = device.new_buffer(el_count, dtype, B::KERNEL)?; + let command_buffer = device.command_buffer()?; command_buffer.set_label(B::KERNEL); if layout.is_contiguous() && layout.start_offset() == 0 { use candle_metal_kernels::unary::contiguous; @@ -681,8 +710,8 @@ impl BackendStorage for MetalStorage { let dtype = self.dtype; let shape = lhs_l.shape(); let el_count = shape.elem_count(); - let buffer = device.new_buffer(el_count, dtype, B::KERNEL); - let command_buffer = device.command_buffer(); + let buffer = device.new_buffer(el_count, dtype, B::KERNEL)?; + let command_buffer = device.command_buffer()?; if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0) && (rhs_l.is_contiguous() && rhs_l.start_offset() == 0) && &B::KERNEL[..1] != "b" @@ -758,8 +787,8 @@ impl BackendStorage for MetalStorage { let dims = shape.dims(); let el = shape.elem_count(); let dtype = t.dtype; - let buffer = self.device.new_buffer(el, dtype, "where"); - let command_buffer = self.device.command_buffer(); + let buffer = self.device.new_buffer(el, dtype, "where")?; + let command_buffer = self.device.command_buffer()?; if t.dtype() != f.dtype() { crate::bail!("Invalid ternary different dtypes for values"); } @@ -875,13 +904,13 @@ impl BackendStorage for MetalStorage { let dst_el = ids_el * left_size * right_size; let dtype = self.dtype; let device = self.device(); - let buffer = device.new_buffer(dst_el, dtype, "index_select"); + let buffer = device.new_buffer(dst_el, dtype, "index_select")?; let name = match (ids.dtype, self.dtype) { (DType::U32, DType::F32) => "is_u32_f32", (DType::U32, DType::F16) => "is_u32_f16", (left, right) => crate::bail!("index select metal {left:?} {right:?}"), }; - let command_buffer = self.device.command_buffer(); + let command_buffer = self.device.command_buffer()?; candle_metal_kernels::call_index_select( &device.device, &command_buffer, @@ -916,7 +945,7 @@ impl BackendStorage for MetalStorage { lhs_l: &Layout, rhs_l: &Layout, ) -> Result { - let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul"); + let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul")?; let name = match self.dtype { DType::F32 => "sgemm", DType::F16 => "hgemm", @@ -925,7 +954,7 @@ impl BackendStorage for MetalStorage { } }; - let command_buffer = self.device.command_buffer(); + let command_buffer = self.device.command_buffer()?; command_buffer.set_label("matmul"); candle_metal_kernels::call_gemm( &self.device.device, @@ -946,7 +975,7 @@ impl BackendStorage for MetalStorage { } fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> { - let command_buffer = self.device.command_buffer(); + let command_buffer = self.device.command_buffer()?; if src_l.is_contiguous() && self.dtype == dst.dtype() { command_buffer.set_label("copy_contiguous"); let blit = command_buffer.new_blit_command_encoder(); @@ -1047,8 +1076,8 @@ impl BackendDevice for MetalDevice { } fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result { - let buffer = self.new_buffer(shape.elem_count(), dtype, "zeros"); - let command_buffer = self.command_buffer(); + let buffer = self.new_buffer(shape.elem_count(), dtype, "zeros")?; + let command_buffer = self.command_buffer()?; command_buffer.set_label("zeros"); let blit = command_buffer.new_blit_command_encoder(); blit.wait_for_fence(&self.fence); @@ -1080,7 +1109,7 @@ impl BackendDevice for MetalDevice { CpuStorage::F16(storage) => self.new_buffer_with_data(storage), CpuStorage::F32(storage) => self.new_buffer_with_data(storage), CpuStorage::F64(storage) => self.new_buffer_with_data(storage), - }; + }?; Ok(Self::Storage::new( buffer.into(), self.clone(), diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index ca23f90e..94380f12 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -210,7 +210,7 @@ impl candle::CustomOp1 for SoftmaxLastDim { ) -> Result<(candle::MetalStorage, Shape)> { use candle::{backend::BackendStorage, DType}; let device = storage.device(); - let command_buffer = device.command_buffer(); + let command_buffer = device.command_buffer()?; let kernels = device.kernels(); let name = match storage.dtype() { DType::F32 => "softmax_f32", @@ -226,7 +226,7 @@ impl candle::CustomOp1 for SoftmaxLastDim { let last_dim = layout.dims()[layout.shape().rank() - 1]; let elem_count = layout.shape().elem_count(); - let mut output = device.new_buffer(elem_count, storage.dtype(), "softmax"); + let mut output = device.new_buffer(elem_count, storage.dtype(), "softmax")?; candle_metal_kernels::call_last_softmax( device.metal_device(), &command_buffer, -- cgit v1.2.3 From 6bc92e63cb4d1a3bb4910348861b9f7e80dfa4f0 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 15 Dec 2023 13:06:04 +0100 Subject: Addressing a lot of comments. --- candle-core/src/metal_backend.rs | 23 +++++++++++++++-------- candle-metal-kernels/src/lib.rs | 6 +++++- candle-metal-kernels/src/tests.rs | 21 +++++++++++---------- candle-nn/src/ops.rs | 3 ++- 4 files changed, 33 insertions(+), 20 deletions(-) (limited to 'candle-nn') diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index f570d2c5..424b29d9 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -482,11 +482,14 @@ impl BackendStorage for MetalStorage { } fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result { - if !(sum_dims.len() == 1 - && sum_dims[0] == layout.shape().rank() - 1 - && layout.stride()[sum_dims[0]] == 1) - { - crate::bail!("Non last dim reduce op not supported yet"); + if sum_dims.len() != 1 { + crate::bail!("reduce {op:?} over multiple dimensions is not implemented yet."); + } + if sum_dims[0] != layout.shape().rank() - 1 { + crate::bail!("Non last dim reduce op {op:?} not implemented yet"); + } + if layout.stride()[sum_dims[0]] != 1 { + crate::bail!("Non contiguous reduce op {op:?} not implemented yet"); } let device = self.device.clone(); @@ -524,7 +527,7 @@ impl BackendStorage for MetalStorage { } let dtype = if return_index { DType::U32 } else { self.dtype }; if dtype == DType::U32 { - crate::bail!("Implement return index reduce op"); + crate::bail!("reduce op {name} is not implemented yet."); } let buffer = device.new_buffer(dst_el, dtype, "reduce")?; let command_buffer = self.device.command_buffer()?; @@ -790,12 +793,16 @@ impl BackendStorage for MetalStorage { let buffer = self.device.new_buffer(el, dtype, "where")?; let command_buffer = self.device.command_buffer()?; if t.dtype() != f.dtype() { - crate::bail!("Invalid ternary different dtypes for values"); + crate::bail!( + "Invalid where: different dtypes for values {:?} != {:?}", + t.dtype(), + f.dtype() + ); } let name = match (self.dtype, t.dtype()) { (DType::U8, DType::F32) => "where_u8_f32", (DType::U8, DType::F16) => "where_u8_f16", - (left, right) => crate::bail!("Ternary {left:?} - {right:?} not implemented"), + (left, right) => crate::bail!("where {left:?} - {right:?} not implemented"), }; candle_metal_kernels::call_where_cond_strided( &device.device, diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index a23aa47c..f2db171e 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -597,6 +597,7 @@ pub fn call_last_softmax( length: usize, elements_to_sum: usize, input: &Buffer, + input_offset: usize, output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; @@ -604,7 +605,10 @@ pub fn call_last_softmax( encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); - set_params!(encoder, (length, elements_to_sum, input, output)); + set_params!( + encoder, + (length, elements_to_sum, (input, input_offset), output) + ); let out_length = length / elements_to_sum; diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 75c2f013..9c9475a2 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -312,7 +312,7 @@ fn run_affine(v: &[T], mul: f64, add: f64) -> Vec { &device, command_buffer, &kernels, - "affine_float", + "affine_f32", size, &input, &output, @@ -346,7 +346,7 @@ fn run_affine_strided( &device, command_buffer, &kernels, - "affine_float_strided", + "affine_f32_strided", shape, &input, strides, @@ -608,6 +608,7 @@ fn run_softmax(v: &[T], last_dim: usize, name: &'sta v.len(), last_dim, &input, + 0, &output, ) .unwrap(); @@ -622,7 +623,7 @@ fn reduce_sum() { let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; let out_length = 1; - let results = run_reduce(&v, out_length, "fast_sum_float"); + let results = run_reduce(&v, out_length, "fast_sum_f32"); assert_eq!(approx(results, 4), vec![21.0]); } @@ -631,7 +632,7 @@ fn reduce_sum2() { let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; let out_length = 2; - let results = run_reduce(&v, out_length, "fast_sum_float"); + let results = run_reduce(&v, out_length, "fast_sum_f32"); assert_eq!(approx(results, 4), vec![6.0, 15.0]); } @@ -639,7 +640,7 @@ fn reduce_sum2() { fn softmax() { let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; let last_dim = 6; - let results = run_softmax(&v, last_dim, "softmax_float"); + let results = run_softmax(&v, last_dim, "softmax_f32"); assert_eq!( approx(results, 4), vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337] @@ -651,7 +652,7 @@ fn softmax() { for i in 0..n { v[i * last_dim] = 20.0; } - let results = run_softmax(&v, last_dim, "softmax_float"); + let results = run_softmax(&v, last_dim, "softmax_f32"); let results = approx(results, 4); println!("{results:?}"); assert_eq!( @@ -665,7 +666,7 @@ fn softmax() { let v = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0]; let last_dim = 6; - let results = run_softmax(&v, last_dim, "softmax_float"); + let results = run_softmax(&v, last_dim, "softmax_f32"); assert_eq!( approx(results, 4), vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337] @@ -673,7 +674,7 @@ fn softmax() { let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; let last_dim = 3; - let results = run_softmax(&v, last_dim, "softmax_float"); + let results = run_softmax(&v, last_dim, "softmax_f32"); assert_eq!( approx(results, 4), vec![0.0900, 0.2447, 0.6652, 0.0900, 0.2447, 0.6652] @@ -684,7 +685,7 @@ fn softmax() { .map(|v| f16::from_f32(*v)) .collect::>(); let last_dim = 6; - let results = run_softmax(&v, last_dim, "softmax_half"); + let results = run_softmax(&v, last_dim, "softmax_f16"); assert_eq!( approx_f16(results, 4), vec![0.0043, 0.0116, 0.0316, 0.0858, 0.2332, 0.6338] @@ -695,7 +696,7 @@ fn softmax() { .map(|v| bf16::from_f32(*v)) .collect::>(); let last_dim = 6; - let results = run_softmax(&v, last_dim, "softmax_bfloat"); + let results = run_softmax(&v, last_dim, "softmax_bf16"); assert_eq!( approx_bf16(results, 4), vec![0.0043, 0.0116, 0.0315, 0.0859, 0.2324, 0.6328] diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index 94380f12..816eff42 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -220,7 +220,7 @@ impl candle::CustomOp1 for SoftmaxLastDim { }; let n = layout.stride().len(); - if !(layout.is_contiguous() && layout.stride()[n - 1] == 1 && layout.start_offset() == 0) { + if !(layout.is_contiguous() && layout.stride()[n - 1] == 1) { candle::bail!("Non contiguous softmax-last-dim is not implemented"); } @@ -235,6 +235,7 @@ impl candle::CustomOp1 for SoftmaxLastDim { elem_count, last_dim, storage.buffer(), + layout.start_offset() * storage.dtype().size_in_bytes(), &mut output, ) .unwrap(); -- cgit v1.2.3 From 03641293eeb1dd0ff3d5a93e85c7f9eb289704e4 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 18 Dec 2023 15:22:43 +0100 Subject: Clippy pass. --- candle-core/src/metal_backend.rs | 18 ++++++++---------- candle-metal-kernels/src/tests.rs | 1 - candle-nn/src/ops.rs | 6 +++--- 3 files changed, 11 insertions(+), 14 deletions(-) (limited to 'candle-nn') diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 0af11a3d..27b2824f 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -59,6 +59,8 @@ impl From for MetalError { } } +type AllocatedBuffers = Arc>>>>; + #[derive(Clone)] pub struct MetalDevice { /// Raw metal device: @@ -103,7 +105,7 @@ pub struct MetalDevice { /// /// Whenever we actually allocate a new buffer, we make a full sweep to cleanup unused buffers /// (strong_count = 1). - buffers: Arc>>>>, + buffers: AllocatedBuffers, } impl std::fmt::Debug for MetalDevice { @@ -258,7 +260,7 @@ impl MetalDevice { let newbuffers = subbuffers .iter() .filter(|s| Arc::strong_count(s) > 1) - .map(|s| Arc::clone(s)) + .map(Arc::clone) .collect(); *subbuffers = newbuffers; } @@ -270,7 +272,7 @@ impl MetalDevice { let capture = metal::CaptureManager::shared(); let descriptor = metal::CaptureDescriptor::new(); descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument); - descriptor.set_capture_device(&self); + descriptor.set_capture_device(self); descriptor.set_output_url(path); capture @@ -1021,10 +1023,10 @@ impl BackendStorage for MetalStorage { &self.device.kernels, name, (b, m, n, k), - &lhs_l.stride(), + lhs_l.stride(), lhs_l.start_offset() * self.dtype.size_in_bytes(), &self.buffer, - &rhs_l.stride(), + rhs_l.stride(), rhs_l.start_offset() * rhs.dtype.size_in_bytes(), &rhs.buffer, &buffer, @@ -1274,11 +1276,7 @@ impl BackendDevice for MetalDevice { CpuStorage::F32(storage) => self.new_buffer_with_data(storage), CpuStorage::F64(storage) => self.new_buffer_with_data(storage), }?; - Ok(Self::Storage::new( - buffer.into(), - self.clone(), - storage.dtype(), - )) + Ok(Self::Storage::new(buffer, self.clone(), storage.dtype())) } fn rand_uniform( diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 8d5a2624..1b3153b1 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -574,7 +574,6 @@ fn run_reduce(v: &[T], out_length: usize, name: &'static str) -> Vec()) as u64, options); - let num_dims = 1; let dims = vec![v.len()]; let strides = vec![1]; call_reduce_strided( diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index 816eff42..abe33350 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -226,17 +226,17 @@ impl candle::CustomOp1 for SoftmaxLastDim { let last_dim = layout.dims()[layout.shape().rank() - 1]; let elem_count = layout.shape().elem_count(); - let mut output = device.new_buffer(elem_count, storage.dtype(), "softmax")?; + let output = device.new_buffer(elem_count, storage.dtype(), "softmax")?; candle_metal_kernels::call_last_softmax( device.metal_device(), &command_buffer, - &kernels, + kernels, name, elem_count, last_dim, storage.buffer(), layout.start_offset() * storage.dtype().size_in_bytes(), - &mut output, + &output, ) .unwrap(); let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype()); -- cgit v1.2.3