summary | refs | log | tree | commit | diff
path: root/candle-metal-kernels
diff options
context:
space:
mode:
author FL33TW00D <chris@fleetwood.dev> 2024-01-19 08:57:49 +0000
committer FL33TW00D <chris@fleetwood.dev> 2024-01-19 08:57:49 +0000
commit b1879f17f6b9d13e101a4d3ff5b6b4ff2e1a7a24 (patch)
tree 1daf6614ee1d95845dbe8408cddce289933d12f1 /candle-metal-kernels
parent 4f79f5df8aef3ee7e6e4757ffe39af4adab6a84d (diff)
download candle-b1879f17f6b9d13e101a4d3ff5b6b4ff2e1a7a24.tar.gz
candle-b1879f17f6b9d13e101a4d3ff5b6b4ff2e1a7a24.tar.bz2
candle-b1879f17f6b9d13e101a4d3ff5b6b4ff2e1a7a24.zip
chore: switch to buffer
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r-- candle-metal-kernels/src/lib.rs 24
-rw-r--r-- candle-metal-kernels/src/libMetalFlashAttention.metallib bin 116216 -> 102760 bytes
2 files changed, 14 insertions, 10 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 2773ca6a..8cb3c16a 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -1,6 +1,7 @@
use metal::{
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
- Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
+ Device, Function, FunctionConstantValues, Library, MTLDataType, MTLResourceOptions, MTLSize,
+ NSUInteger,
};
use std::collections::HashMap;
use std::ffi::c_void;
@@ -1359,17 +1360,20 @@ pub fn call_gemm(
// TODO byte_stride_d
let byte_stride_d = 0;
- let buffer: Vec<u64> = vec![
- byte_stride_a as _,
- byte_stride_b as _,
- byte_stride_c as _,
- byte_stride_d as _,
- ];
- encoder.set_bytes(
- 10,
- (buffer.len() * core::mem::size_of::<u64>()) as NSUInteger,
+ let mut buffer: Vec<u64> = Vec::with_capacity(b * 4);
+ for i in 0..b {
+ buffer.push((i * byte_stride_a) as u64);
+ buffer.push((i * byte_stride_b) as u64);
+ buffer.push((i * byte_stride_c) as u64);
+ buffer.push((i * byte_stride_d) as u64);
+ }
+
+ let matrix_offsets = device.new_buffer_with_data(
buffer.as_ptr() as *const NSUInteger as *const c_void,
+ (buffer.len() * core::mem::size_of::<u64>()) as NSUInteger,
+ MTLResourceOptions::StorageModePrivate,
);
+ encoder.set_buffer(10, Some(&matrix_offsets), 0);
}
let grid_size = MTLSize {
diff --git a/candle-metal-kernels/src/libMetalFlashAttention.metallib b/candle-metal-kernels/src/libMetalFlashAttention.metallib
index c28d2b03..f5116ca6 100644
--- a/candle-metal-kernels/src/libMetalFlashAttention.metallib
+++ b/candle-metal-kernels/src/libMetalFlashAttention.metallib
Binary files differ