diff options
-rw-r--r-- | candle-metal-kernels/src/lib.rs | 13 | ||||
-rw-r--r-- | candle-metal-kernels/src/libMetalFlashAttention.metallib | bin | 102760 -> 116216 bytes |
2 files changed, 6 insertions, 7 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 201af97e..2773ca6a 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1359,13 +1359,12 @@ pub fn call_gemm( // TODO byte_stride_d let byte_stride_d = 0; - let mut buffer: Vec<u64> = Vec::with_capacity(b * 4); - for i in 0..b { - buffer.push((i * byte_stride_a) as u64); - buffer.push((i * byte_stride_b) as u64); - buffer.push((i * byte_stride_c) as u64); - buffer.push((i * byte_stride_d) as u64); - } + let buffer: Vec<u64> = vec![ + byte_stride_a as _, + byte_stride_b as _, + byte_stride_c as _, + byte_stride_d as _, + ]; encoder.set_bytes( 10, (buffer.len() * core::mem::size_of::<u64>()) as NSUInteger, diff --git a/candle-metal-kernels/src/libMetalFlashAttention.metallib b/candle-metal-kernels/src/libMetalFlashAttention.metallib Binary files differindex f5116ca6..c28d2b03 100644 --- a/candle-metal-kernels/src/libMetalFlashAttention.metallib +++ b/candle-metal-kernels/src/libMetalFlashAttention.metallib |