diff options
author | FL33TW00D <chris@fleetwood.dev> | 2024-01-19 08:57:49 +0000 |
---|---|---|
committer | FL33TW00D <chris@fleetwood.dev> | 2024-01-19 08:57:49 +0000 |
commit | b1879f17f6b9d13e101a4d3ff5b6b4ff2e1a7a24 (patch) | |
tree | 1daf6614ee1d95845dbe8408cddce289933d12f1 /candle-metal-kernels | |
parent | 4f79f5df8aef3ee7e6e4757ffe39af4adab6a84d (diff) | |
download | candle-b1879f17f6b9d13e101a4d3ff5b6b4ff2e1a7a24.tar.gz candle-b1879f17f6b9d13e101a4d3ff5b6b4ff2e1a7a24.tar.bz2 candle-b1879f17f6b9d13e101a4d3ff5b6b4ff2e1a7a24.zip |
chore: switch to buffer
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r-- | candle-metal-kernels/src/lib.rs | 24 | ||||
-rw-r--r-- | candle-metal-kernels/src/libMetalFlashAttention.metallib | bin | 116216 -> 102760 bytes |
2 files changed, 14 insertions, 10 deletions
```diff
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 2773ca6a..8cb3c16a 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -1,6 +1,7 @@
 use metal::{
     Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
-    Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
+    Device, Function, FunctionConstantValues, Library, MTLDataType, MTLResourceOptions, MTLSize,
+    NSUInteger,
 };
 use std::collections::HashMap;
 use std::ffi::c_void;
@@ -1359,17 +1360,20 @@ pub fn call_gemm(
         // TODO byte_stride_d
         let byte_stride_d = 0;
 
-        let buffer: Vec<u64> = vec![
-            byte_stride_a as _,
-            byte_stride_b as _,
-            byte_stride_c as _,
-            byte_stride_d as _,
-        ];
-        encoder.set_bytes(
-            10,
-            (buffer.len() * core::mem::size_of::<u64>()) as NSUInteger,
+        let mut buffer: Vec<u64> = Vec::with_capacity(b * 4);
+        for i in 0..b {
+            buffer.push((i * byte_stride_a) as u64);
+            buffer.push((i * byte_stride_b) as u64);
+            buffer.push((i * byte_stride_c) as u64);
+            buffer.push((i * byte_stride_d) as u64);
+        }
+
+        let matrix_offsets = device.new_buffer_with_data(
             buffer.as_ptr() as *const NSUInteger as *const c_void,
+            (buffer.len() * core::mem::size_of::<u64>()) as NSUInteger,
+            MTLResourceOptions::StorageModePrivate,
         );
+        encoder.set_buffer(10, Some(&matrix_offsets), 0);
     }
 
     let grid_size = MTLSize {
diff --git a/candle-metal-kernels/src/libMetalFlashAttention.metallib b/candle-metal-kernels/src/libMetalFlashAttention.metallib
Binary files differ
index c28d2b03..f5116ca6 100644
--- a/candle-metal-kernels/src/libMetalFlashAttention.metallib
+++ b/candle-metal-kernels/src/libMetalFlashAttention.metallib
```