summaryrefslogtreecommitdiff
path: root/candle-metal-kernels
diff options
context:
space:
mode:
authorFL33TW00D <chris@fleetwood.dev>2024-01-22 15:15:19 +0000
committerFL33TW00D <chris@fleetwood.dev>2024-01-22 15:15:19 +0000
commitb6afb4660113b0633af372e7e10310377b677afd (patch)
tree41289bd709888d39e3bfb1fe42a664bc2c207310 /candle-metal-kernels
parent73d79e609226cbe5def96726f8a1896cf4b3dc5d (diff)
downloadcandle-b6afb4660113b0633af372e7e10310377b677afd.tar.gz
candle-b6afb4660113b0633af372e7e10310377b677afd.tar.bz2
candle-b6afb4660113b0633af372e7e10310377b677afd.zip
chore: final
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r--candle-metal-kernels/src/lib.rs25
-rw-r--r--candle-metal-kernels/src/libMetalFlashAttention.metallibbin116168 -> 116184 bytes
2 files changed, 10 insertions, 15 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 4c0f9223..2773ca6a 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -1,7 +1,6 @@
use metal::{
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
- Device, Function, FunctionConstantValues, Library, MTLDataType, MTLResourceOptions, MTLSize,
- NSUInteger,
+ Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
};
use std::collections::HashMap;
use std::ffi::c_void;
@@ -1360,21 +1359,17 @@ pub fn call_gemm(
// TODO byte_stride_d
let byte_stride_d = 0;
- let mut buffer: Vec<u64> = Vec::with_capacity(b * 4);
- for i in 0..b {
- buffer.push((i * byte_stride_a) as u64);
- buffer.push((i * byte_stride_b) as u64);
- buffer.push((i * byte_stride_c) as u64);
- buffer.push((i * byte_stride_d) as u64);
- }
-
- let matrix_offsets = device.new_buffer_with_data(
- buffer.as_ptr() as *const c_void,
+ let buffer: Vec<u64> = vec![
+ byte_stride_a as _,
+ byte_stride_b as _,
+ byte_stride_c as _,
+ byte_stride_d as _,
+ ];
+ encoder.set_bytes(
+ 10,
(buffer.len() * core::mem::size_of::<u64>()) as NSUInteger,
- MTLResourceOptions::StorageModeManaged,
+ buffer.as_ptr() as *const NSUInteger as *const c_void,
);
- encoder.set_buffer(10, Some(&matrix_offsets), 0);
- encoder.use_resource(&matrix_offsets, metal::MTLResourceUsage::Read);
}
let grid_size = MTLSize {
diff --git a/candle-metal-kernels/src/libMetalFlashAttention.metallib b/candle-metal-kernels/src/libMetalFlashAttention.metallib
index 5ed9d033..1e2d1acf 100644
--- a/candle-metal-kernels/src/libMetalFlashAttention.metallib
+++ b/candle-metal-kernels/src/libMetalFlashAttention.metallib
Binary files differ