diff options
author | nicolas <nicolas@nicolass-MacBook-Pro.local> | 2023-12-12 17:41:56 +0100 |
---|---|---|
committer | nicolas <nicolas@nicolass-MacBook-Pro.local> | 2023-12-12 17:41:56 +0100 |
commit | 87dc559817db11f8d8c409cda959528e57e1db31 (patch) | |
tree | 3f7ec04a0facab3378158ae3ba84416d56fd37a7 | |
parent | da0af3cb3e58d38476a20f4465744093a3b75dd4 (diff) | |
download | candle-87dc559817db11f8d8c409cda959528e57e1db31.tar.gz candle-87dc559817db11f8d8c409cda959528e57e1db31.tar.bz2 candle-87dc559817db11f8d8c409cda959528e57e1db31.zip |
Lots of updates including some stack of command buffers.
-rw-r--r-- | candle-core/src/metal_backend.rs | 389 | ||||
-rw-r--r-- | candle-core/src/tensor.rs | 2 | ||||
-rw-r--r-- | candle-metal-kernels/src/affine.metal | 75 | ||||
-rw-r--r-- | candle-metal-kernels/src/lib.rs | 126 | ||||
-rw-r--r-- | candle-metal-kernels/src/reduce.metal | 2 | ||||
-rw-r--r-- | candle-metal-kernels/src/unary.metal | 6 | ||||
-rw-r--r-- | candle-nn/Cargo.toml | 3 | ||||
-rw-r--r-- | candle-nn/src/ops.rs | 4 | ||||
-rw-r--r-- | candle-transformers/Cargo.toml | 1 | ||||
-rw-r--r-- | candle-transformers/src/models/mixformer.rs | 46 |
10 files changed, 537 insertions, 117 deletions
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 12f56d50..4354422c 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -38,7 +38,8 @@ impl From<String> for MetalError { pub struct MetalDevice { device: metal::Device, command_queue: metal::CommandQueue, - command_buffer: Arc<RwLock<metal::CommandBuffer>>, + command_buffers: Arc<RwLock<Vec<metal::CommandBuffer>>>, + command_buffer_index: Arc<RwLock<usize>>, kernels: Arc<candle_metal_kernels::Kernels>, buffers: Arc<RwLock<HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>>>, } @@ -70,38 +71,69 @@ impl MetalDevice { &self.command_queue } - pub fn command_buffer(&self) -> std::sync::RwLockReadGuard<CommandBuffer> { - self.command_buffer.try_read().unwrap() - } - - pub fn commit(&self) { - let mut old = self.command_buffer.try_write().unwrap(); - match old.status() { - metal::MTLCommandBufferStatus::NotEnqueued - | metal::MTLCommandBufferStatus::Enqueued => { - old.commit(); - let command_buffer = self.command_queue.new_command_buffer().to_owned(); - *old = command_buffer; + pub fn command_buffer(&self) -> CommandBuffer { + let mut command_buffers = self.command_buffers.try_write().unwrap(); + let mut index = self.command_buffer_index.try_write().unwrap(); + let n = command_buffers.len(); + if *index == n { + // todo!("Cycle buffers"); + for i in 0..n { + let command_buffer = &command_buffers[i]; + match command_buffer.status() { + metal::MTLCommandBufferStatus::Committed + | metal::MTLCommandBufferStatus::Scheduled => { + // println!("Wait during cycling {i}"); + // println!("Command {i} / {n}: {:?}", command_buffer.status()); + command_buffer.wait_until_completed(); + } + metal::MTLCommandBufferStatus::Completed => {} + _ => { + panic!("Command buffer {i} not committed during cycling"); + } + } } - _ => {} + let new_buffers = (0..n) + .map(|i| { + // println!("Creating command buffer {i}"); + let command_buffer = self.command_queue.new_command_buffer().to_owned(); + command_buffer.enqueue(); + command_buffer + }) + .collect(); + *command_buffers = new_buffers; + *index = 0; + // println!("Reset"); } + // println!("Giving buffer {} / {n}", *index); + let out = &command_buffers[*index]; + assert_eq!(out.status(), metal::MTLCommandBufferStatus::Enqueued); + *index += 1; + out.to_owned() } pub fn wait_until_completed(&self) { - let mut old = self.command_buffer.try_write().unwrap(); - match old.status() { - metal::MTLCommandBufferStatus::NotEnqueued - | metal::MTLCommandBufferStatus::Enqueued => { - old.commit(); - old.wait_until_completed(); + let command_buffers = self.command_buffers.try_write().unwrap(); + let index = self.command_buffer_index.try_write().unwrap(); + let n = command_buffers.len(); + // for i in 0..*index { + // let command_buffer = &command_buffers[i]; + // println!("Command {i} / {n}: {:?}", command_buffer.status()); + // } + for i in 0..*index { + let command_buffer = &command_buffers[i]; + match command_buffer.status() { + metal::MTLCommandBufferStatus::Committed + | metal::MTLCommandBufferStatus::Scheduled => {} + metal::MTLCommandBufferStatus::Completed => {} + _ => { + panic!("Command buffer not committed"); + } } - metal::MTLCommandBufferStatus::Committed | metal::MTLCommandBufferStatus::Scheduled => { - old.wait_until_completed(); - } - _ => {} + // println!("Wait {i}"); + command_buffer.wait_until_completed(); + // println!("Ok {i}"); + // command_buffer.wait_until_completed(); } - let command_buffer = self.command_queue.new_command_buffer().to_owned(); - *old = command_buffer; } pub fn kernels(&self) -> &Kernels { @@ -112,28 +144,40 @@ impl MetalDevice { &self.device } - pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Arc<Buffer> { + pub fn new_buffer(&self, element_count: usize, dtype: DType, name: &str) -> Arc<Buffer> { let size = (element_count * dtype.size_in_bytes()) as NSUInteger; - self._new_buffer(size, MTLResourceOptions::StorageModePrivate) + self._new_buffer(size, MTLResourceOptions::StorageModePrivate, name) } - fn _new_buffer(&self, size: NSUInteger, option: MTLResourceOptions) -> Arc<Buffer> { + fn _new_buffer(&self, size: NSUInteger, option: MTLResourceOptions, name: &str) -> Arc<Buffer> { + // println!("Creating new buffer {name}"); let mut buffers = self.buffers.try_write().unwrap(); let subbuffers = buffers.entry((size, option)).or_insert(vec![]); for sub in &mut *subbuffers { if Arc::strong_count(sub) == 1 { - return sub.clone(); + // println!("Reusing tensor {size} {name}"); + // return sub.clone(); } } let new_buffer = self.device.new_buffer(size as NSUInteger, option); let new_buffer = Arc::new(new_buffer); - subbuffers.push(new_buffer.clone()); + // subbuffers.push(new_buffer.clone()); + // println!("Created tensor {size} {name}"); + for subbuffers in buffers.values_mut() { + let newbuffers = subbuffers + .iter() + .filter(|s| Arc::strong_count(s) > 1) + .map(|s| Arc::clone(s)) + .collect(); + *subbuffers = newbuffers; + } + new_buffer } pub fn new_buffer_managed(&self, size: NSUInteger) -> Arc<Buffer> { - self._new_buffer(size, MTLResourceOptions::StorageModeManaged) + self._new_buffer(size, MTLResourceOptions::StorageModeManaged, "managed") } pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Arc<Buffer> { @@ -143,13 +187,20 @@ impl MetalDevice { size, metal::MTLResourceOptions::StorageModeManaged, ); - let real = self._new_buffer(size, metal::MTLResourceOptions::StorageModePrivate); - { - let command = self.command_buffer(); - let blit = command.new_blit_command_encoder(); - blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length()); - blit.end_encoding(); - } + let real = self._new_buffer( + size, + metal::MTLResourceOptions::StorageModePrivate, + "with_data", + ); + let command = self.command_buffer(); + let blit = command.new_blit_command_encoder(); + blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length()); + blit.end_encoding(); + command.commit(); + real.did_modify_range(metal::NSRange::new(0, real.length())); + // println!("Command {:?}", command.status()); + + // self.commit(); // This is necessary, for mmaped safetensors // Because of the unsafe slice cast we're doing. // The slice might not live long enough for metal @@ -169,7 +220,7 @@ impl MetalDevice { dtype: DType, ) -> Result<(Matrix, Arc<Buffer>)> { let elem_count = (b * m * n) as usize; - let out_buffer = self.new_buffer(elem_count, dtype); + let out_buffer = self.new_buffer(elem_count, dtype, "matrix"); let result_descriptor = MatrixDescriptor::init_multiple(m, n, b, n * size, m * n * size, type_id); @@ -241,13 +292,18 @@ impl BackendStorage for MetalStorage { self.dtype ); } - + self.device.wait_until_completed(); + self.buffer + .did_modify_range(metal::NSRange::new(0, self.buffer.length())); let buffer = self.device.new_buffer_managed(self.buffer.length()); - let command_buffer = self.device.command_buffer(); - let blit = command_buffer.new_blit_command_encoder(); - blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length()); - blit.end_encoding(); - drop(command_buffer); + { + let command_buffer = self.device.command_buffer(); + let blit = command_buffer.new_blit_command_encoder(); + blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length()); + blit.end_encoding(); + command_buffer.commit(); + buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); + } self.device.wait_until_completed(); match self.dtype { @@ -256,7 +312,11 @@ impl BackendStorage for MetalStorage { DType::I64 => Ok(CpuStorage::I64(buffer.read_to_vec(length / size))), DType::F16 => Ok(CpuStorage::F16(buffer.read_to_vec(length / size))), DType::BF16 => Ok(CpuStorage::BF16(buffer.read_to_vec(length / size))), - DType::F32 => Ok(CpuStorage::F32(buffer.read_to_vec(length / size))), + DType::F32 => { + let vec = buffer.read_to_vec(length / size); + // println!("Got back {:?}", &vec[..1]); + Ok(CpuStorage::F32(vec)) + } DType::F64 => Ok(CpuStorage::F64(buffer.read_to_vec(length / size))), } } @@ -268,7 +328,7 @@ impl BackendStorage for MetalStorage { let el = shape.elem_count(); let dtype = self.dtype; - let buffer = device.new_buffer(el, self.dtype); + let buffer = device.new_buffer(el, self.dtype, "affine"); let command_buffer = self.device.command_buffer(); if layout.is_contiguous() && layout.start_offset() == 0 { let name = match self.dtype { @@ -309,15 +369,111 @@ impl BackendStorage for MetalStorage { ) .map_err(MetalError::from)?; } + command_buffer.commit(); + buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(Self::new(buffer, device.clone(), dtype)) } - fn powf(&self, _: &Layout, _: f64) -> Result<Self> { - crate::bail!("powf metal") + fn powf(&self, layout: &Layout, pow: f64) -> Result<Self> { + let device = self.device().clone(); + + let shape = layout.shape(); + let el = shape.elem_count(); + let dtype = self.dtype; + + let buffer = device.new_buffer(el, self.dtype, "powf"); + let command_buffer = self.device.command_buffer(); + if layout.is_contiguous() && layout.start_offset() == 0 { + let name = match self.dtype { + DType::F32 => "powf_float", + DType::F16 => "powf_half", + dtype => crate::bail!("Powf {dtype:?}"), + }; + candle_metal_kernels::call_powf( + &device.device, + &command_buffer, + &device.kernels, + name, + el, + &self.buffer, + &buffer, + pow as f32, + ) + .map_err(MetalError::from)?; + } else { + let name = match self.dtype { + DType::F32 => "powf_float_strided", + DType::F16 => "powf_half_strided", + dtype => crate::bail!("Powf {dtype:?}"), + }; + candle_metal_kernels::call_powf_strided( + &device.device, + &command_buffer, + &device.kernels, + name, + layout.dims(), + &self.buffer, + layout.stride(), + layout.start_offset() * dtype.size_in_bytes(), + &buffer, + pow as f32, + ) + .map_err(MetalError::from)?; + } + command_buffer.commit(); + buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); + Ok(Self::new(buffer, device.clone(), dtype)) } - fn elu(&self, _: &Layout, _: f64) -> Result<Self> { - crate::bail!("elu metal") + fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> { + let device = self.device().clone(); + + let shape = layout.shape(); + let el = shape.elem_count(); + let dtype = self.dtype; + + let buffer = device.new_buffer(el, self.dtype, "elu"); + let command_buffer = self.device.command_buffer(); + if layout.is_contiguous() && layout.start_offset() == 0 { + let name = match self.dtype { + DType::F32 => "elu_float", + DType::F16 => "elu_half", + dtype => crate::bail!("Powf {dtype:?}"), + }; + candle_metal_kernels::call_elu( + &device.device, + &command_buffer, + &device.kernels, + name, + el, + &self.buffer, + &buffer, + alpha as f32, + ) + .map_err(MetalError::from)?; + } else { + let name = match self.dtype { + DType::F32 => "elu_float_strided", + DType::F16 => "elu_half_strided", + dtype => crate::bail!("Powf {dtype:?}"), + }; + candle_metal_kernels::call_elu_strided( + &device.device, + &command_buffer, + &device.kernels, + name, + layout.dims(), + &self.buffer, + layout.stride(), + layout.start_offset() * dtype.size_in_bytes(), + &buffer, + alpha as f32, + ) + .map_err(MetalError::from)?; + } + command_buffer.commit(); + buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); + Ok(Self::new(buffer, device.clone(), dtype)) } fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> { @@ -365,7 +521,7 @@ impl BackendStorage for MetalStorage { if dtype == DType::U32 { crate::bail!("Implement return index reduce op"); } - let buffer = device.new_buffer(dst_el, dtype); + let buffer = device.new_buffer(dst_el, dtype, "reduce"); let command_buffer = self.device.command_buffer(); candle_metal_kernels::call_reduce_contiguous( &device.device, @@ -379,6 +535,8 @@ impl BackendStorage for MetalStorage { &buffer, ) .map_err(MetalError::from)?; + command_buffer.commit(); + buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(Self::new(buffer, device, dtype)) } @@ -391,7 +549,7 @@ impl BackendStorage for MetalStorage { let device = self.device(); let shape = layout.shape(); let el_count = shape.elem_count(); - let buffer = device.new_buffer(el_count, dtype); + let buffer = device.new_buffer(el_count, dtype, "todtype"); let command_buffer = device.command_buffer(); if layout.is_contiguous() { let kernel_name = match (self.dtype, dtype) { @@ -435,6 +593,8 @@ impl BackendStorage for MetalStorage { ) .map_err(MetalError::from)?; } + command_buffer.commit(); + buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(Self::new(buffer, device.clone(), dtype)) } @@ -444,7 +604,7 @@ impl BackendStorage for MetalStorage { let dtype = self.dtype; let shape = layout.shape(); let el_count = shape.elem_count(); - let buffer = device.new_buffer(el_count, dtype); + let buffer = device.new_buffer(el_count, dtype, B::KERNEL); let command_buffer = device.command_buffer(); if layout.is_contiguous() && layout.start_offset() == 0 { use candle_metal_kernels::unary::contiguous; @@ -463,6 +623,7 @@ impl BackendStorage for MetalStorage { ("uceil", DType::F32) => contiguous::ceil::FLOAT, ("ufloor", DType::F32) => contiguous::floor::FLOAT, ("uround", DType::F32) => contiguous::round::FLOAT, + ("utanh", DType::F32) => contiguous::tanh::FLOAT, ("ucos", DType::F16) => contiguous::cos::HALF, ("usin", DType::F16) => contiguous::sin::HALF, ("usqr", DType::F16) => contiguous::sqr::HALF, @@ -476,6 +637,7 @@ impl BackendStorage for MetalStorage { ("uceil", DType::F16) => contiguous::ceil::HALF, ("ufloor", DType::F16) => contiguous::floor::HALF, ("uround", DType::F16) => contiguous::round::HALF, + ("utanh", DType::F16) => contiguous::tanh::HALF, (name, dtype) => crate::bail!("Match {name} - {dtype:?}"), }; candle_metal_kernels::call_unary_contiguous( @@ -534,8 +696,8 @@ impl BackendStorage for MetalStorage { .map_err(MetalError::from)?; } command_buffer.set_label("unary"); - drop(command_buffer); - self.device.commit(); + command_buffer.commit(); + buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(Self::new(buffer, device.clone(), dtype)) } @@ -549,30 +711,31 @@ impl BackendStorage for MetalStorage { let dtype = self.dtype; let shape = lhs_l.shape(); let el_count = shape.elem_count(); - let buffer = device.new_buffer(el_count, dtype); + let buffer = device.new_buffer(el_count, dtype, B::KERNEL); let command_buffer = device.command_buffer(); if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0) && (rhs_l.is_contiguous() && rhs_l.start_offset() == 0) + && &B::KERNEL[..1] != "b" { use candle_metal_kernels::binary::contiguous; let kernel_name = match (B::KERNEL, dtype) { ("add", DType::F32) => contiguous::add::FLOAT, - ("badd", DType::F32) => contiguous::add::FLOAT, + // ("badd", DType::F32) => contiguous::add::FLOAT, ("sub", DType::F32) => contiguous::sub::FLOAT, - ("bsub", DType::F32) => contiguous::sub::FLOAT, + //("bsub", DType::F32) => contiguous::sub::FLOAT, ("mul", DType::F32) => contiguous::mul::FLOAT, - ("bmul", DType::F32) => contiguous::mul::FLOAT, + // ("bmul", DType::F32) => contiguous::mul::FLOAT, ("div", DType::F32) => contiguous::div::FLOAT, - ("bdiv", DType::F32) => contiguous::div::FLOAT, + // ("bdiv", DType::F32) => contiguous::div::FLOAT, ("add", DType::F16) => contiguous::add::HALF, - ("badd", DType::F16) => contiguous::add::HALF, + // ("badd", DType::F16) => contiguous::add::HALF, ("sub", DType::F16) => contiguous::sub::HALF, - ("bsub", DType::F16) => contiguous::sub::HALF, + // ("bsub", DType::F16) => contiguous::sub::HALF, ("mul", DType::F16) => contiguous::mul::HALF, - ("bmul", DType::F16) => contiguous::mul::HALF, + // ("bmul", DType::F16) => contiguous::mul::HALF, ("div", DType::F16) => contiguous::div::HALF, - ("bdiv", DType::F16) => contiguous::div::HALF, + // ("bdiv", DType::F16) => contiguous::div::HALF, (name, dtype) => crate::bail!("Match {name} - {dtype:?}"), }; candle_metal_kernels::call_binary_contiguous( @@ -617,8 +780,8 @@ impl BackendStorage for MetalStorage { .map_err(MetalError::from)?; } command_buffer.set_label("binary"); - drop(command_buffer); - self.device.commit(); + command_buffer.commit(); + buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(Self::new(buffer, device.clone(), dtype)) } @@ -635,7 +798,7 @@ impl BackendStorage for MetalStorage { let dims = shape.dims(); let el = shape.elem_count(); let dtype = t.dtype; - let buffer = self.device.new_buffer(el, dtype); + let buffer = self.device.new_buffer(el, dtype, "where"); let command_buffer = self.device.command_buffer(); if t.dtype() != f.dtype() { crate::bail!("Invalid ternary different dtypes for values"); @@ -663,6 +826,8 @@ impl BackendStorage for MetalStorage { &buffer, ) .map_err(MetalError::from)?; + command_buffer.commit(); + buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(Self::new(buffer, device, dtype)) } @@ -752,7 +917,7 @@ impl BackendStorage for MetalStorage { let dst_el = ids_el * left_size * right_size; let dtype = self.dtype; let device = self.device(); - let buffer = device.new_buffer(dst_el, dtype); + let buffer = device.new_buffer(dst_el, dtype, "index_select"); let name = match (ids.dtype, self.dtype) { (DType::U32, DType::F32) => "is_u32_f32", (DType::U32, DType::F16) => "is_u32_f16", @@ -772,6 +937,8 @@ impl BackendStorage for MetalStorage { &buffer, ) .map_err(MetalError::from)?; + command_buffer.commit(); + buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(Self::new(buffer, device.clone(), dtype)) } @@ -887,9 +1054,9 @@ impl BackendStorage for MetalStorage { &result_matrix, ); command_buffer.set_label("matmul"); - drop(command_buffer); - self.device.commit(); - + command_buffer.commit(); + out_buffer.did_modify_range(metal::NSRange::new(0, out_buffer.length())); + // println!("========= MATMUL {:?}", Arc::strong_count(&out_buffer)); Ok(Self::new(out_buffer, self.device.clone(), self.dtype())) } @@ -899,14 +1066,9 @@ impl BackendStorage for MetalStorage { command_buffer.set_label("copy_contiguous"); let blit = command_buffer.new_blit_command_encoder(); let src_offset = (src_l.start_offset() * self.dtype.size_in_bytes()) as NSUInteger; + let length = (src_l.shape().elem_count() * self.dtype.size_in_bytes()) as NSUInteger; let dst_offset = (dst_offset * dst.dtype().size_in_bytes()) as NSUInteger; - blit.copy_from_buffer( - &self.buffer, - src_offset, - dst.buffer(), - dst_offset, - self.buffer.length() - src_offset, - ); + blit.copy_from_buffer(&self.buffer, src_offset, dst.buffer(), dst_offset, length); blit.end_encoding(); } else { let src_shape = src_l.shape(); @@ -937,8 +1099,9 @@ impl BackendStorage for MetalStorage { .map_err(MetalError::from)?; command_buffer.set_label("copy_strided"); } - drop(command_buffer); - self.device.commit(); + command_buffer.commit(); + dst.buffer + .did_modify_range(metal::NSRange::new(0, dst.buffer.length())); Ok(()) } } @@ -968,22 +1131,22 @@ impl MetalStorage { ) -> Result<Matrix> { let key = (b, m, n, transpose, size, offset, type_id); - let mut matrices = self.matrices.try_write().unwrap(); - if let Some(matrix) = matrices.get(&key) { - Ok(matrix.clone()) + // let mut matrices = self.matrices.try_write().unwrap(); + // if let Some(matrix) = matrices.get(&key) { + // Ok(matrix.clone()) + // } else { + let descriptor = if transpose { + MatrixDescriptor::init_multiple(n, m, b, m * size, m * n * size, type_id) } else { - let descriptor = if transpose { - MatrixDescriptor::init_multiple(n, m, b, m * size, m * n * size, type_id) - } else { - MatrixDescriptor::init_multiple(m, n, b, n * size, m * n * size, type_id) - }; - let matrix = Matrix::init_with_buffer_descriptor(&self.buffer, offset, &descriptor) - .ok_or_else(|| { - MetalError::from("Failed to create matrix multiplication kernel".to_string()) - })?; - matrices.insert(key, matrix.clone()); - Ok(matrix) - } + MatrixDescriptor::init_multiple(m, n, b, n * size, m * n * size, type_id) + }; + let matrix = Matrix::init_with_buffer_descriptor(&self.buffer, offset, &descriptor) + .ok_or_else(|| { + MetalError::from("Failed to create matrix multiplication kernel".to_string()) + })?; + // matrices.insert(key, matrix.clone()); + Ok(matrix) + // } } } @@ -991,16 +1154,28 @@ impl BackendDevice for MetalDevice { type Storage = MetalStorage; fn new(ordinal: usize) -> Result<Self> { + // println!("CREATING DEVICE"); let device = metal::Device::all().swap_remove(ordinal); + let n = 50; let command_queue = device.new_command_queue(); - let command_buffer = Arc::new(RwLock::new(command_queue.new_command_buffer().to_owned())); + + let command_buffers = (0..n) + .map(|_| { + let command_buffer = command_queue.new_command_buffer().to_owned(); + command_buffer.enqueue(); + command_buffer + }) + .collect(); + let command_buffers = Arc::new(RwLock::new(command_buffers)); + let command_buffer_index = Arc::new(RwLock::new(0)); let kernels = Arc::new(Kernels::new()); let buffers = Arc::new(RwLock::new(HashMap::new())); Ok(Self { device, command_queue, - command_buffer, + command_buffers, + command_buffer_index, buffers, kernels, }) @@ -1021,7 +1196,21 @@ impl BackendDevice for MetalDevice { } fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> { - let buffer = self.new_buffer(shape.elem_count(), dtype); + let buffer = self.new_buffer(shape.elem_count(), dtype, "zeros"); + let command_buffer = self.command_buffer(); + let blit = command_buffer.new_blit_command_encoder(); + blit.fill_buffer( + &buffer, + metal::NSRange { + location: 0, + length: buffer.length(), + }, + 0, + ); + blit.end_encoding(); + + command_buffer.commit(); + buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(MetalStorage::new(buffer, self.clone(), dtype)) } diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 87323a84..73a0cc7a 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -1864,7 +1864,7 @@ impl Tensor { } (Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?), (Storage::Metal(storage), Device::Cpu) => { - println!("{storage:?} - {:?}", storage.to_cpu_storage()?); + // println!("{storage:?} - {:?}", storage.to_cpu_storage()?); Storage::Cpu(storage.to_cpu_storage()?) } (Storage::Cuda(storage), Device::Cuda(cuda)) => { diff --git a/candle-metal-kernels/src/affine.metal b/candle-metal-kernels/src/affine.metal index a08bfbc0..18adb457 100644 --- a/candle-metal-kernels/src/affine.metal +++ b/candle-metal-kernels/src/affine.metal @@ -29,9 +29,7 @@ kernel void FN_NAME( \ if (id >= dim) { \ return; \ } \ - const TYPENAME m = TYPENAME(mul); \ - const TYPENAME a = TYPENAME(add); \ - output[id] = input[id] * m + a; \ + output[id] = TYPENAME(float(input[id]) * mul + add); \ } \ kernel void FN_NAME##_strided( \ constant size_t &dim, \ @@ -47,15 +45,80 @@ kernel void FN_NAME##_strided( \ if (id >= dim) { \ return; \ } \ - const TYPENAME m = TYPENAME(mul); \ - const TYPENAME a = TYPENAME(add); \ - output[id] = input[get_strided_index(id, num_dims, dims, strides)] * m + a; \ + output[id] = TYPENAME(float(input[get_strided_index(id, num_dims, dims, strides)]) * mul + add); \ +} + +#define POWF(FN_NAME, TYPENAME) \ +kernel void FN_NAME( \ + constant size_t &dim, \ + constant float &mul, \ + device const TYPENAME *input, \ + device TYPENAME *output, \ + uint id [[ thread_position_in_grid ]] \ +) { \ + if (id >= dim) { \ + return; \ + } \ + output[id] = TYPENAME(pow(input[id], TYPENAME(mul))); \ } \ +kernel void FN_NAME##_strided( \ + constant size_t &dim, \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *strides, \ + constant float &mul, \ + device const TYPENAME *input, \ + device TYPENAME *output, \ + uint id [[ thread_position_in_grid ]] \ +) { \ + if (id >= dim) { \ + return; \ + } \ + output[id] = TYPENAME(pow(input[get_strided_index(id, num_dims, dims, strides)], TYPENAME(mul))); \ +} + +#define ELU(FN_NAME, TYPENAME) \ +kernel void FN_NAME( \ + constant size_t &dim, \ + constant float &mul, \ + device const TYPENAME *input, \ + device TYPENAME *output, \ + uint id [[ thread_position_in_grid ]] \ +) { \ + if (id >= dim) { \ + return; \ + } \ + const TYPENAME x = input[id]; \ + output[id] = TYPENAME((x > 0)?x: mul * exp(x - 1)); \ +} \ +kernel void FN_NAME##_strided( \ + constant size_t &dim, \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *strides, \ + constant float &mul, \ + device const TYPENAME *input, \ + device TYPENAME *output, \ + uint id [[ thread_position_in_grid ]] \ +) { \ + if (id >= dim) { \ + return; \ + } \ + const TYPENAME x = input[get_strided_index(id, num_dims, dims, strides)]; \ + output[id] = TYPENAME((x > 0)?x: mul * exp(x - 1)); \ +} \ + AFFINE(affine_float, float) AFFINE(affine_half, half) +POWF(powf_float, float) +POWF(powf_half, half) +ELU(elu_float, float) +ELU(elu_half, half) #if __METAL_VERSION__ >= 310 AFFINE(affine_bfloat, bfloat); +POWF(powf_bfloat, bfloat); +ELU(elu_bfloat, bfloat); #endif diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index a0b852a4..237bd858 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -153,7 +153,7 @@ macro_rules! ops{ } pub mod unary { - ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf); + ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf, tanh); } pub mod binary { ops!(add, sub, mul, div); @@ -616,6 +616,130 @@ pub fn call_affine_strided( Ok(()) } +#[allow(clippy::too_many_arguments)] +pub fn call_powf( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + size: usize, + input: &Buffer, + output: &Buffer, + mul: f32, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; + + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!(encoder, (size, mul, input, output)); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_powf_strided( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + shape: &[usize], + input: &Buffer, + input_stride: &[usize], + input_offset: usize, + output: &Buffer, + mul: f32, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; + let size: usize = shape.iter().product(); + + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!( + encoder, + ( + size, + shape.len(), + shape, + input_stride, + mul, + (input, input_offset), + output + ) + ); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_elu( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + size: usize, + input: &Buffer, + output: &Buffer, + mul: f32, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; + + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!(encoder, (size, mul, input, output)); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_elu_strided( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + shape: &[usize], + input: &Buffer, + input_stride: &[usize], + input_offset: usize, + output: &Buffer, + mul: f32, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; + let size: usize = shape.iter().product(); + + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!( + encoder, + ( + size, + shape.len(), + shape, + input_stride, + mul, + (input, input_offset), + output + ) + ); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + pub fn call_where_cond_strided( device: &Device, command_buffer: &CommandBufferRef, diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal index 867877fb..3a402427 100644 --- a/candle-metal-kernels/src/reduce.metal +++ b/candle-metal-kernels/src/reduce.metal @@ -18,7 +18,7 @@ METAL_FUNC uint get_strided_index( return strided_i; } -constant int THREADGROUP_SIZE = 1024; +constant int THREADGROUP_SIZE = 2048; # define REDUCE(FN, NAME, T) \ kernel void NAME( \ diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index 529162bd..765b14a5 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -69,7 +69,7 @@ kernel void FN_NAME( \ if (thread_position_in_grid >= dim) { \ return; \ } \ - output[thread_position_in_grid] = TYPENAME(FN(input[thread_position_in_grid])); \ + output[thread_position_in_grid] = TYPENAME(FN(float(input[thread_position_in_grid]))); \ }\ kernel void FN_NAME_STRIDED( \ constant size_t &dim, \ @@ -83,7 +83,7 @@ kernel void FN_NAME_STRIDED( \ if (thread_position_in_grid >= dim) { \ return; \ } \ - output[thread_position_in_grid] = TYPENAME(FN(input[get_strided_index(thread_position_in_grid, num_dims, dims, strides)])); \ + output[thread_position_in_grid] = TYPENAME(FN(float(input[get_strided_index(thread_position_in_grid, num_dims, dims, strides)]))); \ } #define UNARY_OP(NAME) \ @@ -107,6 +107,7 @@ UNARY_OP(floor) UNARY_OP(round) UNARY_OP(gelu_erf) UNARY_OP(erf) +UNARY_OP(tanh) UNARY(id, float, copy_float, copy_float_strided) UNARY(id, half, copy_half, copy_half_strided) UNARY(id, uint8_t, copy_u8, copy_u8_strided) @@ -126,6 +127,7 @@ BFLOAT_UNARY_OP(floor) BFLOAT_UNARY_OP(round) BFLOAT_UNARY_OP(gelu_erf) BFLOAT_UNARY_OP(erf) +BFLOAT_UNARY_OP(tanh) UNARY(id, bfloat, copy_bfloat, copy_bfloat_strided) #endif diff --git a/candle-nn/Cargo.toml b/candle-nn/Cargo.toml index 45298907..03622752 100644 --- a/candle-nn/Cargo.toml +++ b/candle-nn/Cargo.toml @@ -19,6 +19,7 @@ num-traits = { workspace = true } rayon = { workspace = true } safetensors = { workspace = true } serde = { workspace = true } +metal = { workspace = true, optional = true } candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true } [dev-dependencies] @@ -30,4 +31,4 @@ default = [] accelerate = ["dep:accelerate-src", "candle/accelerate"] cuda = ["candle/cuda"] mkl = ["dep:intel-mkl-src", "candle/mkl"] -metal = ["candle/metal", "dep:candle-metal-kernels"] +metal = ["candle/metal", "dep:candle-metal-kernels", "dep:metal"] diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index 350bc663..14dd10de 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -226,7 +226,7 @@ impl candle::CustomOp1 for SoftmaxLastDim { let last_dim = layout.dims()[layout.shape().rank() - 1]; let elem_count = layout.shape().elem_count(); - let mut output = device.new_buffer(elem_count, storage.dtype()); + let mut output = device.new_buffer(elem_count, storage.dtype(), "softmax"); candle_metal_kernels::call_last_softmax( device.metal_device(), &command_buffer, @@ -238,6 +238,8 @@ impl candle::CustomOp1 for SoftmaxLastDim { &mut output, ) .unwrap(); + command_buffer.commit(); + output.did_modify_range(metal::NSRange::new(0, output.length())); let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype()); Ok((newstorage, layout.shape().clone())) } diff --git a/candle-transformers/Cargo.toml b/candle-transformers/Cargo.toml index af4e04b7..e72cab69 100644 --- a/candle-transformers/Cargo.toml +++ b/candle-transformers/Cargo.toml @@ -31,3 +31,4 @@ accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate"] cuda = ["candle/cuda", "candle-nn/cuda"] flash-attn = ["cuda", "dep:candle-flash-attn"] mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl"] +metal = ["candle/metal", "candle-nn/metal"] diff --git a/candle-transformers/src/models/mixformer.rs b/candle-transformers/src/models/mixformer.rs index e822ca14..c8dae511 100644 --- a/candle-transformers/src/models/mixformer.rs +++ b/candle-transformers/src/models/mixformer.rs @@ -142,10 +142,9 @@ impl RotaryEmbedding { .to_dtype(DType::F32)? .reshape((max_seq_len, 1))?; let freqs = t.matmul(&inv_freq)?; - Ok(Self { - sin: freqs.sin()?, - cos: freqs.cos()?, - }) + let sin = freqs.sin()?; + let cos = freqs.cos()?; + Ok(Self { sin, cos }) } fn apply_rotary_emb_qkv( @@ -273,6 +272,10 @@ impl MHA { } fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> { + let view = xs.to_string(); + if view.contains("NaN") { + panic!("NaN"); + } let _enter = self.span.enter(); let (b_size, seq_len, _n_embd) = xs.dims3()?; let qkv = self @@ -408,3 +411,38 @@ impl MixFormerSequentialForCausalLM { self.blocks.iter_mut().for_each(|b| b.clear_kv_cache()) } } + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn test_rotary() { + let dev = Device::new_metal(0).unwrap(); + for i in 0..10000 { + let dim = 8; + let max_seq_len = 12; + let inv_freq: Vec<_> = (0..dim) + .step_by(2) + .map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32)) + .collect(); + let inv_freq_len = inv_freq.len(); + let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), &dev).unwrap(); + let t = Tensor::arange(0u32, max_seq_len as u32, &dev) + .unwrap() + .to_dtype(DType::F32) + .unwrap() + .reshape((max_seq_len, 1)) + .unwrap(); + let x: f32 = t.i((1, 0)).unwrap().to_scalar().unwrap(); + assert_eq!(x, 1.0); + let x: f32 = inv_freq.i((0, 1)).unwrap().to_scalar().unwrap(); + assert_eq!(x, 0.1); + let freqs = t.matmul(&inv_freq).unwrap(); + let x: f32 = freqs.i((1, 1)).unwrap().to_scalar().unwrap(); + assert_eq!(x, 0.1); + let sin = freqs.sin().unwrap().contiguous().unwrap(); + let x: f32 = sin.i((1, 1)).unwrap().to_scalar().unwrap(); + assert_eq!(x, 0.099833414); + } + } +} |