| field | value | date |
|---|---|---|
| author | Nicolas Patry <patry.nicolas@protonmail.com> | 2023-12-14 16:05:33 +0100 |
| committer | Nicolas Patry <patry.nicolas@protonmail.com> | 2023-12-14 16:05:33 +0100 |
| commit | 361f2ad2af52ccf1750e274f1649fb8c90f80a86 (patch) | |
| tree | 6e919f0df7076abd021bd22e595b811f404bd8d3 | |
| parent | 931432ed55918886680e37a280c3ff25d7ee9839 (diff) | |
| download | candle-361f2ad2af52ccf1750e274f1649fb8c90f80a86 (.tar.gz, .tar.bz2, .zip) | |
Working with merging encoders and using fences.
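The subject line refers to the new synchronization scheme: instead of committing one command buffer per operation, kernels are encoded back to back and ordered with a shared `metal::Fence`. Below is a minimal, illustrative sketch of the pattern every kernel launch follows after this change; the helper name and buffer bindings are invented for the example, the real dispatch helpers live in `candle-metal-kernels` and take many more parameters.

```rust
use metal::{BufferRef, CommandBufferRef, ComputePipelineState, Fence, MTLSize};

/// Sketch only: encode one kernel between a fence wait and a fence update so
/// that encoders merged onto the same queue observe each other's writes
/// without a `commit()` per operation.
fn encode_kernel(
    command_buffer: &CommandBufferRef,
    pipeline: &ComputePipelineState,
    fence: &Fence,
    input: &BufferRef,
    output: &BufferRef,
    thread_group_count: MTLSize,
    thread_group_size: MTLSize,
) {
    let encoder = command_buffer.new_compute_command_encoder();
    // Wait for whatever the previous encoder wrote before reading `input`.
    encoder.wait_for_fence(fence);
    encoder.set_compute_pipeline_state(pipeline);
    encoder.set_buffer(0, Some(input), 0);
    encoder.set_buffer(1, Some(output), 0);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    // Signal that `output` is ready for the next encoder in the chain.
    encoder.update_fence(fence);
    encoder.end_encoding();
}
```

The blit encoders used for host/device copies follow the same wait/update pattern around `copy_from_buffer`, which is what allows the backend to drop the per-operation `commit()` and `did_modify_range` calls seen in the removed lines of the diff.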
| file | lines changed |
|---|---|
| candle-core/src/metal_backend.rs | 120 |
| candle-core/tests/tensor_tests.rs | 2 |
| candle-metal-kernels/src/lib.rs | 40 |
| candle-metal-kernels/src/test.swift | 209 |
| candle-nn/src/ops.rs | 2 |
5 files changed, 279 insertions, 94 deletions
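Beyond the per-kernel fences, the `metal_backend.rs` hunks below replace the old pool of 64 pre-enqueued command buffers with a single shared buffer that is only committed once enough work has accumulated. A simplified sketch of that scheme follows; the struct is trimmed to the relevant fields and renamed, the real type is `MetalDevice`.

```rust
use std::sync::{Arc, RwLock};
use metal::{CommandBuffer, CommandQueue};

/// Sketch of the new command-buffer handling: one buffer is handed out to
/// successive ops and periodically committed and replaced.
struct Queue {
    command_queue: CommandQueue,
    command_buffers: Arc<RwLock<Vec<CommandBuffer>>>,
    command_buffer_index: Arc<RwLock<usize>>,
}

impl Queue {
    fn command_buffer(&self) -> CommandBuffer {
        let mut command_buffers = self.command_buffers.try_write().unwrap();
        let mut command_buffer = command_buffers[0].to_owned();
        let mut index = self.command_buffer_index.try_write().unwrap();
        // After ~20 encoders, commit the batch and start a fresh buffer so a
        // single command buffer never grows without bound.
        if *index > 20 {
            command_buffer.commit();
            command_buffer = self.command_queue.new_command_buffer().to_owned();
            *command_buffers = vec![command_buffer.clone()];
            *index = 0;
        }
        *index += 1;
        command_buffer
    }
}
```

`wait_until_completed` now commits this shared buffer, waits on it, and swaps in a fresh one, which is why the explicit `commit()` calls scattered through the individual ops could be removed.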
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 9866f1ca..4bc80823 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -38,6 +38,7 @@ pub struct MetalDevice { command_queue: metal::CommandQueue, command_buffers: Arc<RwLock<Vec<metal::CommandBuffer>>>, command_buffer_index: Arc<RwLock<usize>>, + fence: metal::Fence, kernels: Arc<candle_metal_kernels::Kernels>, buffers: Arc<RwLock<HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>>>, } @@ -71,68 +72,32 @@ impl MetalDevice { pub fn command_buffer(&self) -> CommandBuffer { let mut command_buffers = self.command_buffers.try_write().unwrap(); + let mut command_buffer = command_buffers[0].to_owned(); let mut index = self.command_buffer_index.try_write().unwrap(); - let n = command_buffers.len(); - if *index == n { - // todo!("Cycle buffers"); - for i in 0..n { - let command_buffer = &command_buffers[i]; - match command_buffer.status() { - metal::MTLCommandBufferStatus::Committed - | metal::MTLCommandBufferStatus::Scheduled => { - // println!("Wait during cycling {i}"); - // println!("Command {i} / {n}: {:?}", command_buffer.status()); - command_buffer.wait_until_completed(); - } - metal::MTLCommandBufferStatus::Completed => {} - _ => { - panic!("Command buffer {i} not committed during cycling"); - } - } - } - let new_buffers = (0..n) - .map(|i| { - // println!("Creating command buffer {i}"); - let command_buffer = self.command_queue.new_command_buffer().to_owned(); - command_buffer.set_label(&format!("num {i}")); - command_buffer.enqueue(); - command_buffer - }) - .collect(); - *command_buffers = new_buffers; + if *index > 20 { + command_buffer.commit(); + command_buffer = self.command_queue.new_command_buffer().to_owned(); + *command_buffers = vec![command_buffer.clone()]; *index = 0; - // println!("Reset"); } - // println!("Giving buffer {} / {n}", *index); - let out = &command_buffers[*index]; - assert_eq!(out.status(), metal::MTLCommandBufferStatus::Enqueued); *index += 1; - out.to_owned() + command_buffer } pub fn wait_until_completed(&self) { - let command_buffers = self.command_buffers.try_write().unwrap(); - let index = self.command_buffer_index.try_write().unwrap(); - // let n = command_buffers.len(); - // for i in 0..*index { - // let command_buffer = &command_buffers[i]; - // println!("Command {i} / {n}: {:?}", command_buffer.status()); - // } - for i in 0..*index { - let command_buffer = &command_buffers[i]; - match command_buffer.status() { - metal::MTLCommandBufferStatus::Committed - | metal::MTLCommandBufferStatus::Scheduled => {} - metal::MTLCommandBufferStatus::Completed => {} - _ => { - panic!("Command buffer not committed"); - } + let mut command_buffers = self.command_buffers.try_write().unwrap(); + let command_buffer = &command_buffers[0]; + match command_buffer.status() { + metal::MTLCommandBufferStatus::Committed + | metal::MTLCommandBufferStatus::Scheduled + | metal::MTLCommandBufferStatus::Completed => { + panic!("Alredy committed"); } - // println!("Wait {i}"); - command_buffer.wait_until_completed(); - // println!("Ok {i}"); - // command_buffer.wait_until_completed(); + _ => {} } + command_buffer.commit(); + command_buffer.wait_until_completed(); + *command_buffers = vec![self.command_queue.new_command_buffer().to_owned()]; } pub fn kernels(&self) -> &Kernels { @@ -176,7 +141,7 @@ impl MetalDevice { } pub fn new_buffer_managed(&self, size: NSUInteger) -> Arc<Buffer> { - self._new_buffer(size, MTLResourceOptions::StorageModeShared, 
"managed") + self._new_buffer(size, MTLResourceOptions::StorageModeManaged, "managed") } pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Arc<Buffer> { @@ -184,7 +149,7 @@ impl MetalDevice { let tmp = self.device.new_buffer_with_data( data.as_ptr() as *const core::ffi::c_void, size, - metal::MTLResourceOptions::StorageModeShared, + metal::MTLResourceOptions::StorageModeManaged, ); let real = self._new_buffer( size, @@ -194,15 +159,15 @@ impl MetalDevice { let command_buffer = self.command_buffer(); command_buffer.set_label("with_data"); let blit = command_buffer.new_blit_command_encoder(); + blit.wait_for_fence(&self.fence); blit.set_label("with_data_blit"); blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length()); + blit.update_fence(&self.fence); blit.end_encoding(); - command_buffer.commit(); - drop(command_buffer); + // drop(command_buffer); // real.did_modify_range(metal::NSRange::new(0, real.length())); // println!("Command {:?}", command.status()); - // self.commit(); // This is necessary, for mmaped safetensors // Because of the unsafe slice cast we're doing. // The slice might not live long enough for metal @@ -259,19 +224,16 @@ impl BackendStorage for MetalStorage { self.dtype ); } - self.device.wait_until_completed(); - self.buffer - .did_modify_range(metal::NSRange::new(0, self.buffer.length())); let buffer = self.device.new_buffer_managed(self.buffer.length()); { let command_buffer = self.device.command_buffer(); command_buffer.set_label("to_cpu"); let blit = command_buffer.new_blit_command_encoder(); blit.set_label("blit_to_cpu"); + blit.wait_for_fence(&self.device.fence); blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length()); + blit.update_fence(&self.device.fence); blit.end_encoding(); - - command_buffer.commit(); } self.device.wait_until_completed(); @@ -338,8 +300,7 @@ impl BackendStorage for MetalStorage { ) .map_err(MetalError::from)?; } - command_buffer.commit(); - buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); + // buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(Self::new(buffer, device.clone(), dtype)) } @@ -389,8 +350,6 @@ impl BackendStorage for MetalStorage { ) .map_err(MetalError::from)?; } - command_buffer.commit(); - buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(Self::new(buffer, device.clone(), dtype)) } @@ -440,7 +399,6 @@ impl BackendStorage for MetalStorage { ) .map_err(MetalError::from)?; } - command_buffer.commit(); buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(Self::new(buffer, device.clone(), dtype)) } @@ -504,8 +462,6 @@ impl BackendStorage for MetalStorage { &buffer, ) .map_err(MetalError::from)?; - command_buffer.commit(); - buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(Self::new(buffer, device, dtype)) } @@ -519,7 +475,6 @@ impl BackendStorage for MetalStorage { let shape = layout.shape(); let el_count = shape.elem_count(); let buffer = device.new_buffer(el_count, dtype, "todtype"); - device.wait_until_completed(); let command_buffer = device.command_buffer(); if layout.is_contiguous() && layout.start_offset() == 0 { let kernel_name = match (self.dtype, dtype) { @@ -564,10 +519,6 @@ impl BackendStorage for MetalStorage { .map_err(MetalError::from)?; } command_buffer.set_label("to_dtype"); - command_buffer.commit(); - buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); - device.wait_until_completed(); - Ok(Self::new(buffer, device.clone(), dtype)) } @@ -668,8 +619,6 @@ impl BackendStorage for MetalStorage { ) 
.map_err(MetalError::from)?; } - command_buffer.commit(); - buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(Self::new(buffer, device.clone(), dtype)) } @@ -752,8 +701,6 @@ impl BackendStorage for MetalStorage { .map_err(MetalError::from)?; } command_buffer.set_label("binary"); - command_buffer.commit(); - buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(Self::new(buffer, device.clone(), dtype)) } @@ -798,8 +745,6 @@ impl BackendStorage for MetalStorage { &buffer, ) .map_err(MetalError::from)?; - command_buffer.commit(); - buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(Self::new(buffer, device, dtype)) } @@ -909,8 +854,6 @@ impl BackendStorage for MetalStorage { &buffer, ) .map_err(MetalError::from)?; - command_buffer.commit(); - buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(Self::new(buffer, device.clone(), dtype)) } @@ -963,8 +906,6 @@ impl BackendStorage for MetalStorage { ) .map_err(MetalError::from)?; // Create kernel - command_buffer.commit(); - self.device.wait_until_completed(); Ok(Self::new(buffer, self.device.clone(), self.dtype())) } @@ -1010,7 +951,6 @@ impl BackendStorage for MetalStorage { .map_err(MetalError::from)?; command_buffer.set_label("copy_strided"); } - command_buffer.commit(); Ok(()) } } @@ -1036,7 +976,7 @@ impl BackendDevice for MetalDevice { // println!("CREATING DEVICE"); let device = metal::Device::all().swap_remove(ordinal); - let n = 64; + let n = 1; let command_queue = device.new_command_queue(); let command_buffers = (0..n) @@ -1049,10 +989,12 @@ impl BackendDevice for MetalDevice { .collect(); let command_buffers = Arc::new(RwLock::new(command_buffers)); let command_buffer_index = Arc::new(RwLock::new(0)); - let kernels = Arc::new(Kernels::new()); + let fence = device.new_fence(); + let kernels = Arc::new(Kernels::new(fence.clone())); let buffers = Arc::new(RwLock::new(HashMap::new())); Ok(Self { device, + fence, command_queue, command_buffers, command_buffer_index, @@ -1089,8 +1031,6 @@ impl BackendDevice for MetalDevice { 0, ); blit.end_encoding(); - command_buffer.commit(); - buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(MetalStorage::new(buffer, self.clone(), dtype)) } diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index c871dc96..a77f9c3a 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -900,7 +900,9 @@ fn matmul(device: &Device) -> Result<()> { let b = Tensor::from_slice(&data, (2, 2), device)?; let c = a.matmul(&b)?; + let d = a.matmul(&c)?; assert_eq!(c.to_vec2::<f32>()?, &[[7.0f32, 10.0], [15.0, 22.0]]); + assert_eq!(d.to_vec2::<f32>()?, &[[37.0, 54.0], [81.0, 118.0]]); let data = vec![1.0f32, 2.0]; let a = Tensor::from_slice(&data, (2, 1), device)?; diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index b80dcb79..01432ccb 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -184,19 +184,21 @@ impl<T> From<std::sync::PoisonError<T>> for MetalKernelError { type Libraries = HashMap<Source, Library>; type Pipelines = HashMap<(&'static str, Option<ConstantValues>), ComputePipelineState>; -#[derive(Debug, Default)] +#[derive(Debug)] pub struct Kernels { libraries: RwLock<Libraries>, pipelines: RwLock<Pipelines>, + fence: metal::Fence, } impl Kernels { - pub fn new() -> Self { + pub fn new(fence: metal::Fence) -> Self { let libraries = RwLock::new(Libraries::new()); let pipelines = 
RwLock::new(Pipelines::new()); Self { libraries, pipelines, + fence, } } @@ -304,12 +306,14 @@ pub fn call_unary_contiguous( ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?; let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (length, input, output)); let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -331,6 +335,7 @@ pub fn call_unary_strided( let num_dims: usize = shape.len(); let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); let length: usize = shape.iter().product(); @@ -350,6 +355,7 @@ pub fn call_unary_strided( let (thread_group_count, thread_group_size) = linear_split(&pipeline, width); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -368,6 +374,7 @@ pub fn call_binary_contiguous( let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?; let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (length, left, right, output)); @@ -375,6 +382,7 @@ pub fn call_binary_contiguous( let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -399,6 +407,7 @@ pub fn call_binary_strided( let num_dims: usize = shape.len(); let encoder = command_buffer.new_compute_command_encoder(); let width: usize = shape.iter().product(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); let length: usize = shape.iter().product(); @@ -420,6 +429,7 @@ pub fn call_binary_strided( let (thread_group_count, thread_group_size) = linear_split(&pipeline, width); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -438,12 +448,14 @@ pub fn call_cast_contiguous( let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?; let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (length, (input, input_offset), output)); let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -463,6 +475,7 @@ pub fn call_cast_strided( let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?; let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); let length: usize = shape.iter().product(); @@ -482,6 +495,7 @@ pub fn call_cast_strided( let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -501,6 +515,7 
@@ pub fn call_reduce_contiguous( let elements_to_sum = length / out_length; let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -527,6 +542,7 @@ pub fn call_reduce_contiguous( }; encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -544,6 +560,7 @@ pub fn call_last_softmax( ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (length, elements_to_sum, input, output)); @@ -569,6 +586,7 @@ pub fn call_last_softmax( }; encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -588,12 +606,14 @@ pub fn call_affine( let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (size, mul, add, input, output)); let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -616,6 +636,7 @@ pub fn call_affine_strided( let size: usize = shape.iter().product(); let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -634,6 +655,7 @@ pub fn call_affine_strided( let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -652,12 +674,14 @@ pub fn call_powf( let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (size, mul, input, output)); let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -679,6 +703,7 @@ pub fn call_powf_strided( let size: usize = shape.iter().product(); let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -696,6 +721,7 @@ pub fn call_powf_strided( let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -714,12 +740,14 @@ pub fn call_elu( let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (size, mul, input, output)); let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); encoder.dispatch_thread_groups(thread_group_count, 
thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -741,6 +769,7 @@ pub fn call_elu_strided( let size: usize = shape.iter().product(); let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -758,6 +787,7 @@ pub fn call_elu_strided( let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -779,6 +809,7 @@ pub fn call_where_cond_strided( let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?; let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); let size: usize = shape.iter().product(); @@ -803,6 +834,7 @@ pub fn call_where_cond_strided( let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -829,6 +861,7 @@ pub fn call_index_select( let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -848,6 +881,7 @@ pub fn call_index_select( let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -1045,6 +1079,7 @@ pub fn call_gemm( let block_bytes = block_elements * bytes; let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); // println!("Threadgroup {block_bytes}"); encoder.set_threadgroup_memory_length(0, block_bytes.into()); @@ -1087,6 +1122,7 @@ pub fn call_gemm( }; // println!("grid size {grid_size:?} group size {group_size:?}"); encoder.dispatch_thread_groups(grid_size, group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) diff --git a/candle-metal-kernels/src/test.swift b/candle-metal-kernels/src/test.swift new file mode 100644 index 00000000..f9bb9f91 --- /dev/null +++ b/candle-metal-kernels/src/test.swift @@ -0,0 +1,209 @@ + +import Metal +import MetalPerformanceShadersGraph + + + +let type = MTLDataType.float; +let dataType = type; +var B = 2; +var M = 2; +var N = 2; +var K = 2; +var A_trans = false; +var B_trans = false; +var D_trans = false; +var alpha = Float(1.0); +var beta = Float(0.0); +var batched = B > 1; +var fused_activation = false; +var fused_bias = false; +let constants = MTLFunctionConstantValues() +constants.setConstantValue(&M, type: .uint, index: 0) +constants.setConstantValue(&N, type: .uint, index: 1) +constants.setConstantValue(&K, type: .uint, index: 2) +constants.setConstantValue(&A_trans, type: .bool, index: 10) +constants.setConstantValue(&B_trans, type: .bool, index: 11) +constants.setConstantValue(&D_trans, type: .bool, index: 13) +constants.setConstantValue(&alpha, type: .float, index: 20) +constants.setConstantValue(&beta, type: .float, index: 21) +constants.setConstantValue(&batched, type: .bool, index: 100) +constants.setConstantValue(&fused_activation, type: .bool, index: 101) +constants.setConstantValue(&fused_bias, type: .bool, index: 50001) + + +var M_simd = UInt16(16) +var 
N_simd = UInt16(16) +var K_simd = UInt16(32) +var M_splits = UInt16(2) +var N_splits = UInt16(2) +constants.setConstantValue(&M_simd, type: .ushort, index: 200) +constants.setConstantValue(&N_simd, type: .ushort, index: 201) +constants.setConstantValue(&K_simd, type: .ushort, index: 202) +constants.setConstantValue(&M_splits, type: .ushort, index: 210) +constants.setConstantValue(&N_splits, type: .ushort, index: 211) + +let M_group = M_simd * M_splits +let N_group = N_simd * N_splits + +// Satisfy Metal API validation. +#if DEBUG +do { + var garbage: SIMD4<UInt64> = .zero + constants.setConstantValue(&garbage, type: .bool, index: 102) + constants.setConstantValue(&garbage, type: .bool, index: 103) + constants.setConstantValue(&garbage, type: .bool, index: 113) + constants.setConstantValue(&garbage, type: .bool, index: 50000) +} +#endif + +let device = MTLCopyAllDevices().first! +device.shouldMaximizeConcurrentCompilation = true + +var libraryURL = URL.init(string: "/Users/nicolas/src/candle/candle-metal-kernels/")!; +libraryURL.append(component: "src") +libraryURL.append(component: "libMetalFlashAttention.metallib") +let library = try! device.makeLibrary(URL: libraryURL) + +var name: String + switch dataType { + case .half: name = "hgemm" + case .float: name = "sgemm" + default: fatalError() + } +let function = try! library.makeFunction( + name: name, constantValues: constants) + +let A_block_length = M_group * K_simd +let B_block_length = K_simd * N_group + +var blockElements = A_block_length + B_block_length; +if (M % 8 != 0) && (N % 8 != 0) { + let C_block_length = M_group * N_group; + blockElements = max(C_block_length, blockElements) +} +if fused_bias { + if D_trans { + blockElements = max(blockElements, M_group) + } else { + blockElements = max(blockElements, N_group) + } +} +// let blockBytes = blockElements * UInt16(dataType.size) +let elementSize = 4 +let blockBytes = blockElements * UInt16(elementSize) + +func ceilDivide(target: Int, granularity: UInt16) -> Int { + (target + Int(granularity) - 1) / Int(granularity) +} +var gridSize = MTLSize( + width: ceilDivide(target: N, granularity: N_group), + height: ceilDivide(target: M, granularity: M_group), + depth: 1) +let groupSize = MTLSize( + width: Int(32 * M_splits * N_splits), + height: 1, + depth: 1) + +let commandQueue = device.makeCommandQueue()! + +let threadgroupMemoryLength = blockBytes; + +let rowsA = M; +let columnsA = K; +let rowsB = K; +let columnsB = N; +let rowsC = M; +let columnsC = N; +var arrayA = [Float](repeating: 0, count: B * rowsA * columnsA) + +var arrayB = [Float](repeating: 0, count: B * rowsB * columnsB) + +var arrayC = [Float](repeating: 0, count: B * rowsC * columnsC) +var arrayD = [Float](repeating: 0, count: B * rowsC * columnsC) +for i in 0..<arrayA.count { + arrayA[i] = Float(i) +} + +for i in 0..<arrayB.count { + arrayB[i] = Float(i) +} + +let bufferA = device.makeBuffer(bytes: arrayA, length: B * rowsA * columnsA * MemoryLayout<Float>.stride, options: [])! + +let bufferB = device.makeBuffer(bytes: arrayB, length: B * rowsB * columnsB * MemoryLayout<Float>.stride, options: [])! + +let bufferC = device.makeBuffer(length: B * rowsC * columnsC * MemoryLayout<Float>.stride, options: [])! +let bufferD = device.makeBuffer(length: B * rowsC * columnsC * MemoryLayout<Float>.stride, options: [])! 
+ + +let pipeline = try device.makeComputePipelineState(function: function) + +func call(bufferA: MTLBuffer, bufferB: MTLBuffer, bufferC: MTLBuffer){ + let encoder = commandBuffer.makeComputeCommandEncoder(dispatchType: MTLDispatchType.serial)! + encoder.setComputePipelineState(pipeline) + encoder.setThreadgroupMemoryLength(Int(threadgroupMemoryLength), index: 0) + + encoder.setBuffer(bufferA, offset: 0, index: 0) + encoder.setBuffer(bufferB, offset: 0, index: 1) + encoder.setBuffer(bufferC, offset: 0, index: 2) + let gridZ: Int = B + if batched{ + func byteStride(shape: [Int]) -> Int { + let rank = shape.count + var output = elementSize * shape[rank - 2] * shape[rank - 1] + if shape.dropLast(2).reduce(1, *) == 1 { + output = 0 + } + return output + } + let byteStrideA = M*K*elementSize + let byteStrideB = N*K*elementSize + let byteStrideC = M*N*elementSize + + let byteStrideD = 0 + withUnsafeTemporaryAllocation( + of: SIMD4<UInt64>.self, capacity: gridZ + ) { buffer in + for i in 0..<buffer.count { + buffer[i] = SIMD4( + UInt64(truncatingIfNeeded: i * byteStrideA), + UInt64(truncatingIfNeeded: i * byteStrideB), + UInt64(truncatingIfNeeded: i * byteStrideC), + UInt64(truncatingIfNeeded: i * byteStrideD)) + } + + let bufferLength = buffer.count * MemoryLayout<SIMD4<UInt64>>.stride + assert(MemoryLayout<SIMD4<UInt64>>.stride == 8 * 4) + encoder.setBytes(buffer.baseAddress!, length: bufferLength, index: 10) + } + } + gridSize.depth = gridZ + + + encoder.dispatchThreadgroups( + gridSize, threadsPerThreadgroup: groupSize + ) + encoder.endEncoding() +} + +var commandBuffer = commandQueue.makeCommandBuffer()! +call(bufferA:bufferA, bufferB:bufferB, bufferC:bufferC) +commandBuffer.commit() +commandBuffer = commandQueue.makeCommandBuffer()! +commandBuffer.encodeWaitForEvent(event, value: 2) +call(bufferA:bufferA, bufferB:bufferC, bufferC:bufferD) +commandBuffer.commit() + +commandBuffer.waitUntilCompleted() +var contents = bufferC.contents(); +var count = B * rowsA * columnsB; +var typedPointer = contents.bindMemory(to: Float.self, capacity: count) +var bufferedPointer = UnsafeBufferPointer(start: typedPointer, count: count) +print("First matmul is OK", Array(bufferedPointer)) + +contents = bufferD.contents(); +count = B * rowsA * columnsB; +typedPointer = contents.bindMemory(to: Float.self, capacity: count) +bufferedPointer = UnsafeBufferPointer(start: typedPointer, count: count) +print("This should be filled", Array(bufferedPointer)) diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index 14dd10de..e002d931 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -238,8 +238,6 @@ impl candle::CustomOp1 for SoftmaxLastDim { &mut output, ) .unwrap(); - command_buffer.commit(); - output.did_modify_range(metal::NSRange::new(0, output.length())); let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype()); Ok((newstorage, layout.shape().clone())) } |
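For reference, the new assertion in `tensor_tests.rs` can be reproduced outside the test harness. The sketch below assumes `Device::new_metal(0)` is how the Metal device is constructed and that the crate is used under its external name `candle_core`; inside the workspace the test receives the device as a parameter instead.

```rust
use candle_core::{Device, Result, Tensor};

// Chained matmuls: `d` depends on `c`, so both kernels must be ordered
// correctly on the GPU even though nothing is committed between them, which
// is exactly what the fence/encoder merging in this commit enables.
fn chained_matmul() -> Result<()> {
    let device = Device::new_metal(0)?; // assumed constructor for the Metal device
    let data = vec![1.0f32, 2.0, 3.0, 4.0];
    let a = Tensor::from_slice(&data, (2, 2), &device)?;
    let b = Tensor::from_slice(&data, (2, 2), &device)?;
    let c = a.matmul(&b)?;
    let d = a.matmul(&c)?;
    assert_eq!(c.to_vec2::<f32>()?, &[[7.0f32, 10.0], [15.0, 22.0]]);
    assert_eq!(d.to_vec2::<f32>()?, &[[37.0, 54.0], [81.0, 118.0]]);
    Ok(())
}
```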