summary | refs | log | tree | commit | diff
path: root/candle-metal-kernels
diff options
context:
space:
mode:
author: FL33TW00D <chris@fleetwood.dev> 2024-01-18 14:30:14 +0000
committer: FL33TW00D <chris@fleetwood.dev> 2024-01-18 14:30:14 +0000
commit: 4f79f5df8aef3ee7e6e4757ffe39af4adab6a84d (patch)
tree: cedbc84d17373ba44925c7ae1880865c39149d2c /candle-metal-kernels
parent: 1cf34368b7d10600ca2fac197cd49c8bad2f6ad1 (diff)
download: candle-4f79f5df8aef3ee7e6e4757ffe39af4adab6a84d.tar.gz
candle-4f79f5df8aef3ee7e6e4757ffe39af4adab6a84d.tar.bz2
candle-4f79f5df8aef3ee7e6e4757ffe39af4adab6a84d.zip
fix: larger batches
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r-- candle-metal-kernels/src/lib.rs 13
-rw-r--r-- candle-metal-kernels/src/libMetalFlashAttention.metallib bin 102760 -> 116216 bytes
2 files changed, 6 insertions, 7 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 201af97e..2773ca6a 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -1359,13 +1359,12 @@ pub fn call_gemm(
// TODO byte_stride_d
let byte_stride_d = 0;
- let mut buffer: Vec<u64> = Vec::with_capacity(b * 4);
- for i in 0..b {
- buffer.push((i * byte_stride_a) as u64);
- buffer.push((i * byte_stride_b) as u64);
- buffer.push((i * byte_stride_c) as u64);
- buffer.push((i * byte_stride_d) as u64);
- }
+ let buffer: Vec<u64> = vec![
+ byte_stride_a as _,
+ byte_stride_b as _,
+ byte_stride_c as _,
+ byte_stride_d as _,
+ ];
encoder.set_bytes(
10,
(buffer.len() * core::mem::size_of::<u64>()) as NSUInteger,
diff --git a/candle-metal-kernels/src/libMetalFlashAttention.metallib b/candle-metal-kernels/src/libMetalFlashAttention.metallib
index f5116ca6..c28d2b03 100644
--- a/candle-metal-kernels/src/libMetalFlashAttention.metallib
+++ b/candle-metal-kernels/src/libMetalFlashAttention.metallib
Binary files differ