summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicolas Patry <patry.nicolas@protonmail.com>2023-12-14 16:05:33 +0100
committerNicolas Patry <patry.nicolas@protonmail.com>2023-12-14 16:05:33 +0100
commit361f2ad2af52ccf1750e274f1649fb8c90f80a86 (patch)
tree6e919f0df7076abd021bd22e595b811f404bd8d3
parent931432ed55918886680e37a280c3ff25d7ee9839 (diff)
downloadcandle-361f2ad2af52ccf1750e274f1649fb8c90f80a86.tar.gz
candle-361f2ad2af52ccf1750e274f1649fb8c90f80a86.tar.bz2
candle-361f2ad2af52ccf1750e274f1649fb8c90f80a86.zip
Working with merging encoders and using fences.
-rw-r--r--candle-core/src/metal_backend.rs120
-rw-r--r--candle-core/tests/tensor_tests.rs2
-rw-r--r--candle-metal-kernels/src/lib.rs40
-rw-r--r--candle-metal-kernels/src/test.swift209
-rw-r--r--candle-nn/src/ops.rs2
5 files changed, 279 insertions, 94 deletions
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index 9866f1ca..4bc80823 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -38,6 +38,7 @@ pub struct MetalDevice {
command_queue: metal::CommandQueue,
command_buffers: Arc<RwLock<Vec<metal::CommandBuffer>>>,
command_buffer_index: Arc<RwLock<usize>>,
+ fence: metal::Fence,
kernels: Arc<candle_metal_kernels::Kernels>,
buffers: Arc<RwLock<HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>>>,
}
@@ -71,68 +72,32 @@ impl MetalDevice {
pub fn command_buffer(&self) -> CommandBuffer {
let mut command_buffers = self.command_buffers.try_write().unwrap();
+ let mut command_buffer = command_buffers[0].to_owned();
let mut index = self.command_buffer_index.try_write().unwrap();
- let n = command_buffers.len();
- if *index == n {
- // todo!("Cycle buffers");
- for i in 0..n {
- let command_buffer = &command_buffers[i];
- match command_buffer.status() {
- metal::MTLCommandBufferStatus::Committed
- | metal::MTLCommandBufferStatus::Scheduled => {
- // println!("Wait during cycling {i}");
- // println!("Command {i} / {n}: {:?}", command_buffer.status());
- command_buffer.wait_until_completed();
- }
- metal::MTLCommandBufferStatus::Completed => {}
- _ => {
- panic!("Command buffer {i} not committed during cycling");
- }
- }
- }
- let new_buffers = (0..n)
- .map(|i| {
- // println!("Creating command buffer {i}");
- let command_buffer = self.command_queue.new_command_buffer().to_owned();
- command_buffer.set_label(&format!("num {i}"));
- command_buffer.enqueue();
- command_buffer
- })
- .collect();
- *command_buffers = new_buffers;
+ if *index > 20 {
+ command_buffer.commit();
+ command_buffer = self.command_queue.new_command_buffer().to_owned();
+ *command_buffers = vec![command_buffer.clone()];
*index = 0;
- // println!("Reset");
}
- // println!("Giving buffer {} / {n}", *index);
- let out = &command_buffers[*index];
- assert_eq!(out.status(), metal::MTLCommandBufferStatus::Enqueued);
*index += 1;
- out.to_owned()
+ command_buffer
}
pub fn wait_until_completed(&self) {
- let command_buffers = self.command_buffers.try_write().unwrap();
- let index = self.command_buffer_index.try_write().unwrap();
- // let n = command_buffers.len();
- // for i in 0..*index {
- // let command_buffer = &command_buffers[i];
- // println!("Command {i} / {n}: {:?}", command_buffer.status());
- // }
- for i in 0..*index {
- let command_buffer = &command_buffers[i];
- match command_buffer.status() {
- metal::MTLCommandBufferStatus::Committed
- | metal::MTLCommandBufferStatus::Scheduled => {}
- metal::MTLCommandBufferStatus::Completed => {}
- _ => {
- panic!("Command buffer not committed");
- }
+ let mut command_buffers = self.command_buffers.try_write().unwrap();
+ let command_buffer = &command_buffers[0];
+ match command_buffer.status() {
+ metal::MTLCommandBufferStatus::Committed
+ | metal::MTLCommandBufferStatus::Scheduled
+ | metal::MTLCommandBufferStatus::Completed => {
+ panic!("Alredy committed");
}
- // println!("Wait {i}");
- command_buffer.wait_until_completed();
- // println!("Ok {i}");
- // command_buffer.wait_until_completed();
+ _ => {}
}
+ command_buffer.commit();
+ command_buffer.wait_until_completed();
+ *command_buffers = vec![self.command_queue.new_command_buffer().to_owned()];
}
pub fn kernels(&self) -> &Kernels {
@@ -176,7 +141,7 @@ impl MetalDevice {
}
pub fn new_buffer_managed(&self, size: NSUInteger) -> Arc<Buffer> {
- self._new_buffer(size, MTLResourceOptions::StorageModeShared, "managed")
+ self._new_buffer(size, MTLResourceOptions::StorageModeManaged, "managed")
}
pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Arc<Buffer> {
@@ -184,7 +149,7 @@ impl MetalDevice {
let tmp = self.device.new_buffer_with_data(
data.as_ptr() as *const core::ffi::c_void,
size,
- metal::MTLResourceOptions::StorageModeShared,
+ metal::MTLResourceOptions::StorageModeManaged,
);
let real = self._new_buffer(
size,
@@ -194,15 +159,15 @@ impl MetalDevice {
let command_buffer = self.command_buffer();
command_buffer.set_label("with_data");
let blit = command_buffer.new_blit_command_encoder();
+ blit.wait_for_fence(&self.fence);
blit.set_label("with_data_blit");
blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length());
+ blit.update_fence(&self.fence);
blit.end_encoding();
- command_buffer.commit();
- drop(command_buffer);
+ // drop(command_buffer);
// real.did_modify_range(metal::NSRange::new(0, real.length()));
// println!("Command {:?}", command.status());
- // self.commit();
// This is necessary, for mmaped safetensors
// Because of the unsafe slice cast we're doing.
// The slice might not live long enough for metal
@@ -259,19 +224,16 @@ impl BackendStorage for MetalStorage {
self.dtype
);
}
- self.device.wait_until_completed();
- self.buffer
- .did_modify_range(metal::NSRange::new(0, self.buffer.length()));
let buffer = self.device.new_buffer_managed(self.buffer.length());
{
let command_buffer = self.device.command_buffer();
command_buffer.set_label("to_cpu");
let blit = command_buffer.new_blit_command_encoder();
blit.set_label("blit_to_cpu");
+ blit.wait_for_fence(&self.device.fence);
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
+ blit.update_fence(&self.device.fence);
blit.end_encoding();
-
- command_buffer.commit();
}
self.device.wait_until_completed();
@@ -338,8 +300,7 @@ impl BackendStorage for MetalStorage {
)
.map_err(MetalError::from)?;
}
- command_buffer.commit();
- buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
+ // buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
Ok(Self::new(buffer, device.clone(), dtype))
}
@@ -389,8 +350,6 @@ impl BackendStorage for MetalStorage {
)
.map_err(MetalError::from)?;
}
- command_buffer.commit();
- buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
Ok(Self::new(buffer, device.clone(), dtype))
}
@@ -440,7 +399,6 @@ impl BackendStorage for MetalStorage {
)
.map_err(MetalError::from)?;
}
- command_buffer.commit();
buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
Ok(Self::new(buffer, device.clone(), dtype))
}
@@ -504,8 +462,6 @@ impl BackendStorage for MetalStorage {
&buffer,
)
.map_err(MetalError::from)?;
- command_buffer.commit();
- buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
Ok(Self::new(buffer, device, dtype))
}
@@ -519,7 +475,6 @@ impl BackendStorage for MetalStorage {
let shape = layout.shape();
let el_count = shape.elem_count();
let buffer = device.new_buffer(el_count, dtype, "todtype");
- device.wait_until_completed();
let command_buffer = device.command_buffer();
if layout.is_contiguous() && layout.start_offset() == 0 {
let kernel_name = match (self.dtype, dtype) {
@@ -564,10 +519,6 @@ impl BackendStorage for MetalStorage {
.map_err(MetalError::from)?;
}
command_buffer.set_label("to_dtype");
- command_buffer.commit();
- buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
- device.wait_until_completed();
-
Ok(Self::new(buffer, device.clone(), dtype))
}
@@ -668,8 +619,6 @@ impl BackendStorage for MetalStorage {
)
.map_err(MetalError::from)?;
}
- command_buffer.commit();
- buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
Ok(Self::new(buffer, device.clone(), dtype))
}
@@ -752,8 +701,6 @@ impl BackendStorage for MetalStorage {
.map_err(MetalError::from)?;
}
command_buffer.set_label("binary");
- command_buffer.commit();
- buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
Ok(Self::new(buffer, device.clone(), dtype))
}
@@ -798,8 +745,6 @@ impl BackendStorage for MetalStorage {
&buffer,
)
.map_err(MetalError::from)?;
- command_buffer.commit();
- buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
Ok(Self::new(buffer, device, dtype))
}
@@ -909,8 +854,6 @@ impl BackendStorage for MetalStorage {
&buffer,
)
.map_err(MetalError::from)?;
- command_buffer.commit();
- buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
Ok(Self::new(buffer, device.clone(), dtype))
}
@@ -963,8 +906,6 @@ impl BackendStorage for MetalStorage {
)
.map_err(MetalError::from)?;
// Create kernel
- command_buffer.commit();
- self.device.wait_until_completed();
Ok(Self::new(buffer, self.device.clone(), self.dtype()))
}
@@ -1010,7 +951,6 @@ impl BackendStorage for MetalStorage {
.map_err(MetalError::from)?;
command_buffer.set_label("copy_strided");
}
- command_buffer.commit();
Ok(())
}
}
@@ -1036,7 +976,7 @@ impl BackendDevice for MetalDevice {
// println!("CREATING DEVICE");
let device = metal::Device::all().swap_remove(ordinal);
- let n = 64;
+ let n = 1;
let command_queue = device.new_command_queue();
let command_buffers = (0..n)
@@ -1049,10 +989,12 @@ impl BackendDevice for MetalDevice {
.collect();
let command_buffers = Arc::new(RwLock::new(command_buffers));
let command_buffer_index = Arc::new(RwLock::new(0));
- let kernels = Arc::new(Kernels::new());
+ let fence = device.new_fence();
+ let kernels = Arc::new(Kernels::new(fence.clone()));
let buffers = Arc::new(RwLock::new(HashMap::new()));
Ok(Self {
device,
+ fence,
command_queue,
command_buffers,
command_buffer_index,
@@ -1089,8 +1031,6 @@ impl BackendDevice for MetalDevice {
0,
);
blit.end_encoding();
- command_buffer.commit();
- buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
Ok(MetalStorage::new(buffer, self.clone(), dtype))
}
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index c871dc96..a77f9c3a 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -900,7 +900,9 @@ fn matmul(device: &Device) -> Result<()> {
let b = Tensor::from_slice(&data, (2, 2), device)?;
let c = a.matmul(&b)?;
+ let d = a.matmul(&c)?;
assert_eq!(c.to_vec2::<f32>()?, &[[7.0f32, 10.0], [15.0, 22.0]]);
+ assert_eq!(d.to_vec2::<f32>()?, &[[37.0, 54.0], [81.0, 118.0]]);
let data = vec![1.0f32, 2.0];
let a = Tensor::from_slice(&data, (2, 1), device)?;
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index b80dcb79..01432ccb 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -184,19 +184,21 @@ impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
type Libraries = HashMap<Source, Library>;
type Pipelines = HashMap<(&'static str, Option<ConstantValues>), ComputePipelineState>;
-#[derive(Debug, Default)]
+#[derive(Debug)]
pub struct Kernels {
libraries: RwLock<Libraries>,
pipelines: RwLock<Pipelines>,
+ fence: metal::Fence,
}
impl Kernels {
- pub fn new() -> Self {
+ pub fn new(fence: metal::Fence) -> Self {
let libraries = RwLock::new(Libraries::new());
let pipelines = RwLock::new(Pipelines::new());
Self {
libraries,
pipelines,
+ fence,
}
}
@@ -304,12 +306,14 @@ pub fn call_unary_contiguous(
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
let encoder = command_buffer.new_compute_command_encoder();
+ encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, input, output));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -331,6 +335,7 @@ pub fn call_unary_strided(
let num_dims: usize = shape.len();
let encoder = command_buffer.new_compute_command_encoder();
+ encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
let length: usize = shape.iter().product();
@@ -350,6 +355,7 @@ pub fn call_unary_strided(
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -368,6 +374,7 @@ pub fn call_binary_contiguous(
let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
let encoder = command_buffer.new_compute_command_encoder();
+ encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, left, right, output));
@@ -375,6 +382,7 @@ pub fn call_binary_contiguous(
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -399,6 +407,7 @@ pub fn call_binary_strided(
let num_dims: usize = shape.len();
let encoder = command_buffer.new_compute_command_encoder();
let width: usize = shape.iter().product();
+ encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
let length: usize = shape.iter().product();
@@ -420,6 +429,7 @@ pub fn call_binary_strided(
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -438,12 +448,14 @@ pub fn call_cast_contiguous(
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder();
+ encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, (input, input_offset), output));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -463,6 +475,7 @@ pub fn call_cast_strided(
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder();
+ encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
let length: usize = shape.iter().product();
@@ -482,6 +495,7 @@ pub fn call_cast_strided(
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -501,6 +515,7 @@ pub fn call_reduce_contiguous(
let elements_to_sum = length / out_length;
let encoder = command_buffer.new_compute_command_encoder();
+ encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@@ -527,6 +542,7 @@ pub fn call_reduce_contiguous(
};
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -544,6 +560,7 @@ pub fn call_last_softmax(
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder();
+ encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, elements_to_sum, input, output));
@@ -569,6 +586,7 @@ pub fn call_last_softmax(
};
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -588,12 +606,14 @@ pub fn call_affine(
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
let encoder = command_buffer.new_compute_command_encoder();
+ encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (size, mul, add, input, output));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -616,6 +636,7 @@ pub fn call_affine_strided(
let size: usize = shape.iter().product();
let encoder = command_buffer.new_compute_command_encoder();
+ encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@@ -634,6 +655,7 @@ pub fn call_affine_strided(
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -652,12 +674,14 @@ pub fn call_powf(
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
let encoder = command_buffer.new_compute_command_encoder();
+ encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (size, mul, input, output));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -679,6 +703,7 @@ pub fn call_powf_strided(
let size: usize = shape.iter().product();
let encoder = command_buffer.new_compute_command_encoder();
+ encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@@ -696,6 +721,7 @@ pub fn call_powf_strided(
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -714,12 +740,14 @@ pub fn call_elu(
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
let encoder = command_buffer.new_compute_command_encoder();
+ encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (size, mul, input, output));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -741,6 +769,7 @@ pub fn call_elu_strided(
let size: usize = shape.iter().product();
let encoder = command_buffer.new_compute_command_encoder();
+ encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@@ -758,6 +787,7 @@ pub fn call_elu_strided(
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -779,6 +809,7 @@ pub fn call_where_cond_strided(
let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
let encoder = command_buffer.new_compute_command_encoder();
+ encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
let size: usize = shape.iter().product();
@@ -803,6 +834,7 @@ pub fn call_where_cond_strided(
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -829,6 +861,7 @@ pub fn call_index_select(
let encoder = command_buffer.new_compute_command_encoder();
+ encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@@ -848,6 +881,7 @@ pub fn call_index_select(
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -1045,6 +1079,7 @@ pub fn call_gemm(
let block_bytes = block_elements * bytes;
let encoder = command_buffer.new_compute_command_encoder();
+ encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
// println!("Threadgroup {block_bytes}");
encoder.set_threadgroup_memory_length(0, block_bytes.into());
@@ -1087,6 +1122,7 @@ pub fn call_gemm(
};
// println!("grid size {grid_size:?} group size {group_size:?}");
encoder.dispatch_thread_groups(grid_size, group_size);
+ encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
diff --git a/candle-metal-kernels/src/test.swift b/candle-metal-kernels/src/test.swift
new file mode 100644
index 00000000..f9bb9f91
--- /dev/null
+++ b/candle-metal-kernels/src/test.swift
@@ -0,0 +1,209 @@
+
+import Metal
+import MetalPerformanceShadersGraph
+
+
+
+let type = MTLDataType.float;
+let dataType = type;
+var B = 2;
+var M = 2;
+var N = 2;
+var K = 2;
+var A_trans = false;
+var B_trans = false;
+var D_trans = false;
+var alpha = Float(1.0);
+var beta = Float(0.0);
+var batched = B > 1;
+var fused_activation = false;
+var fused_bias = false;
+let constants = MTLFunctionConstantValues()
+constants.setConstantValue(&M, type: .uint, index: 0)
+constants.setConstantValue(&N, type: .uint, index: 1)
+constants.setConstantValue(&K, type: .uint, index: 2)
+constants.setConstantValue(&A_trans, type: .bool, index: 10)
+constants.setConstantValue(&B_trans, type: .bool, index: 11)
+constants.setConstantValue(&D_trans, type: .bool, index: 13)
+constants.setConstantValue(&alpha, type: .float, index: 20)
+constants.setConstantValue(&beta, type: .float, index: 21)
+constants.setConstantValue(&batched, type: .bool, index: 100)
+constants.setConstantValue(&fused_activation, type: .bool, index: 101)
+constants.setConstantValue(&fused_bias, type: .bool, index: 50001)
+
+
+var M_simd = UInt16(16)
+var N_simd = UInt16(16)
+var K_simd = UInt16(32)
+var M_splits = UInt16(2)
+var N_splits = UInt16(2)
+constants.setConstantValue(&M_simd, type: .ushort, index: 200)
+constants.setConstantValue(&N_simd, type: .ushort, index: 201)
+constants.setConstantValue(&K_simd, type: .ushort, index: 202)
+constants.setConstantValue(&M_splits, type: .ushort, index: 210)
+constants.setConstantValue(&N_splits, type: .ushort, index: 211)
+
+let M_group = M_simd * M_splits
+let N_group = N_simd * N_splits
+
+// Satisfy Metal API validation.
+#if DEBUG
+do {
+ var garbage: SIMD4<UInt64> = .zero
+ constants.setConstantValue(&garbage, type: .bool, index: 102)
+ constants.setConstantValue(&garbage, type: .bool, index: 103)
+ constants.setConstantValue(&garbage, type: .bool, index: 113)
+ constants.setConstantValue(&garbage, type: .bool, index: 50000)
+}
+#endif
+
+let device = MTLCopyAllDevices().first!
+device.shouldMaximizeConcurrentCompilation = true
+
+var libraryURL = URL.init(string: "/Users/nicolas/src/candle/candle-metal-kernels/")!;
+libraryURL.append(component: "src")
+libraryURL.append(component: "libMetalFlashAttention.metallib")
+let library = try! device.makeLibrary(URL: libraryURL)
+
+var name: String
+ switch dataType {
+ case .half: name = "hgemm"
+ case .float: name = "sgemm"
+ default: fatalError()
+ }
+let function = try! library.makeFunction(
+ name: name, constantValues: constants)
+
+let A_block_length = M_group * K_simd
+let B_block_length = K_simd * N_group
+
+var blockElements = A_block_length + B_block_length;
+if (M % 8 != 0) && (N % 8 != 0) {
+ let C_block_length = M_group * N_group;
+ blockElements = max(C_block_length, blockElements)
+}
+if fused_bias {
+ if D_trans {
+ blockElements = max(blockElements, M_group)
+ } else {
+ blockElements = max(blockElements, N_group)
+ }
+}
+// let blockBytes = blockElements * UInt16(dataType.size)
+let elementSize = 4
+let blockBytes = blockElements * UInt16(elementSize)
+
+func ceilDivide(target: Int, granularity: UInt16) -> Int {
+ (target + Int(granularity) - 1) / Int(granularity)
+}
+var gridSize = MTLSize(
+ width: ceilDivide(target: N, granularity: N_group),
+ height: ceilDivide(target: M, granularity: M_group),
+ depth: 1)
+let groupSize = MTLSize(
+ width: Int(32 * M_splits * N_splits),
+ height: 1,
+ depth: 1)
+
+let commandQueue = device.makeCommandQueue()!
+
+let threadgroupMemoryLength = blockBytes;
+
+let rowsA = M;
+let columnsA = K;
+let rowsB = K;
+let columnsB = N;
+let rowsC = M;
+let columnsC = N;
+var arrayA = [Float](repeating: 0, count: B * rowsA * columnsA)
+
+var arrayB = [Float](repeating: 0, count: B * rowsB * columnsB)
+
+var arrayC = [Float](repeating: 0, count: B * rowsC * columnsC)
+var arrayD = [Float](repeating: 0, count: B * rowsC * columnsC)
+for i in 0..<arrayA.count {
+ arrayA[i] = Float(i)
+}
+
+for i in 0..<arrayB.count {
+ arrayB[i] = Float(i)
+}
+
+let bufferA = device.makeBuffer(bytes: arrayA, length: B * rowsA * columnsA * MemoryLayout<Float>.stride, options: [])!
+
+let bufferB = device.makeBuffer(bytes: arrayB, length: B * rowsB * columnsB * MemoryLayout<Float>.stride, options: [])!
+
+let bufferC = device.makeBuffer(length: B * rowsC * columnsC * MemoryLayout<Float>.stride, options: [])!
+let bufferD = device.makeBuffer(length: B * rowsC * columnsC * MemoryLayout<Float>.stride, options: [])!
+
+
+let pipeline = try device.makeComputePipelineState(function: function)
+
+func call(bufferA: MTLBuffer, bufferB: MTLBuffer, bufferC: MTLBuffer){
+ let encoder = commandBuffer.makeComputeCommandEncoder(dispatchType: MTLDispatchType.serial)!
+ encoder.setComputePipelineState(pipeline)
+ encoder.setThreadgroupMemoryLength(Int(threadgroupMemoryLength), index: 0)
+
+ encoder.setBuffer(bufferA, offset: 0, index: 0)
+ encoder.setBuffer(bufferB, offset: 0, index: 1)
+ encoder.setBuffer(bufferC, offset: 0, index: 2)
+ let gridZ: Int = B
+ if batched{
+ func byteStride(shape: [Int]) -> Int {
+ let rank = shape.count
+ var output = elementSize * shape[rank - 2] * shape[rank - 1]
+ if shape.dropLast(2).reduce(1, *) == 1 {
+ output = 0
+ }
+ return output
+ }
+ let byteStrideA = M*K*elementSize
+ let byteStrideB = N*K*elementSize
+ let byteStrideC = M*N*elementSize
+
+ let byteStrideD = 0
+ withUnsafeTemporaryAllocation(
+ of: SIMD4<UInt64>.self, capacity: gridZ
+ ) { buffer in
+ for i in 0..<buffer.count {
+ buffer[i] = SIMD4(
+ UInt64(truncatingIfNeeded: i * byteStrideA),
+ UInt64(truncatingIfNeeded: i * byteStrideB),
+ UInt64(truncatingIfNeeded: i * byteStrideC),
+ UInt64(truncatingIfNeeded: i * byteStrideD))
+ }
+
+ let bufferLength = buffer.count * MemoryLayout<SIMD4<UInt64>>.stride
+ assert(MemoryLayout<SIMD4<UInt64>>.stride == 8 * 4)
+ encoder.setBytes(buffer.baseAddress!, length: bufferLength, index: 10)
+ }
+ }
+ gridSize.depth = gridZ
+
+
+ encoder.dispatchThreadgroups(
+ gridSize, threadsPerThreadgroup: groupSize
+ )
+ encoder.endEncoding()
+}
+
+var commandBuffer = commandQueue.makeCommandBuffer()!
+call(bufferA:bufferA, bufferB:bufferB, bufferC:bufferC)
+commandBuffer.commit()
+commandBuffer = commandQueue.makeCommandBuffer()!
+commandBuffer.encodeWaitForEvent(event, value: 2)
+call(bufferA:bufferA, bufferB:bufferC, bufferC:bufferD)
+commandBuffer.commit()
+
+commandBuffer.waitUntilCompleted()
+var contents = bufferC.contents();
+var count = B * rowsA * columnsB;
+var typedPointer = contents.bindMemory(to: Float.self, capacity: count)
+var bufferedPointer = UnsafeBufferPointer(start: typedPointer, count: count)
+print("First matmul is OK", Array(bufferedPointer))
+
+contents = bufferD.contents();
+count = B * rowsA * columnsB;
+typedPointer = contents.bindMemory(to: Float.self, capacity: count)
+bufferedPointer = UnsafeBufferPointer(start: typedPointer, count: count)
+print("This should be filled", Array(bufferedPointer))
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index 14dd10de..e002d931 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -238,8 +238,6 @@ impl candle::CustomOp1 for SoftmaxLastDim {
&mut output,
)
.unwrap();
- command_buffer.commit();
- output.did_modify_range(metal::NSRange::new(0, output.length()));
let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype());
Ok((newstorage, layout.shape().clone()))
}