diff options
author | FL33TW00D <chris@fleetwood.dev> | 2024-01-22 15:15:19 +0000 |
---|---|---|
committer | FL33TW00D <chris@fleetwood.dev> | 2024-01-22 15:15:19 +0000 |
commit | b6afb4660113b0633af372e7e10310377b677afd (patch) | |
tree | 41289bd709888d39e3bfb1fe42a664bc2c207310 /candle-metal-kernels | |
parent | 73d79e609226cbe5def96726f8a1896cf4b3dc5d (diff) | |
download | candle-b6afb4660113b0633af372e7e10310377b677afd.tar.gz candle-b6afb4660113b0633af372e7e10310377b677afd.tar.bz2 candle-b6afb4660113b0633af372e7e10310377b677afd.zip |
chore: final
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r-- | candle-metal-kernels/src/lib.rs | 25 | ||||
-rw-r--r-- | candle-metal-kernels/src/libMetalFlashAttention.metallib | bin | 116168 -> 116184 bytes |
2 files changed, 10 insertions, 15 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 4c0f9223..2773ca6a 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1,7 +1,6 @@ use metal::{ Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState, - Device, Function, FunctionConstantValues, Library, MTLDataType, MTLResourceOptions, MTLSize, - NSUInteger, + Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger, }; use std::collections::HashMap; use std::ffi::c_void; @@ -1360,21 +1359,17 @@ pub fn call_gemm( // TODO byte_stride_d let byte_stride_d = 0; - let mut buffer: Vec<u64> = Vec::with_capacity(b * 4); - for i in 0..b { - buffer.push((i * byte_stride_a) as u64); - buffer.push((i * byte_stride_b) as u64); - buffer.push((i * byte_stride_c) as u64); - buffer.push((i * byte_stride_d) as u64); - } - - let matrix_offsets = device.new_buffer_with_data( - buffer.as_ptr() as *const c_void, + let buffer: Vec<u64> = vec![ + byte_stride_a as _, + byte_stride_b as _, + byte_stride_c as _, + byte_stride_d as _, + ]; + encoder.set_bytes( + 10, (buffer.len() * core::mem::size_of::<u64>()) as NSUInteger, - MTLResourceOptions::StorageModeManaged, + buffer.as_ptr() as *const NSUInteger as *const c_void, ); - encoder.set_buffer(10, Some(&matrix_offsets), 0); - encoder.use_resource(&matrix_offsets, metal::MTLResourceUsage::Read); } let grid_size = MTLSize { diff --git a/candle-metal-kernels/src/libMetalFlashAttention.metallib b/candle-metal-kernels/src/libMetalFlashAttention.metallib Binary files differindex 5ed9d033..1e2d1acf 100644 --- a/candle-metal-kernels/src/libMetalFlashAttention.metallib +++ b/candle-metal-kernels/src/libMetalFlashAttention.metallib |