summaryrefslogtreecommitdiff
path: root/candle-metal-kernels
diff options
context:
space:
mode:
authorChristopher Fleetwood <45471420+FL33TW00D@users.noreply.github.com>2024-01-29 15:31:10 +0000
committerGitHub <noreply@github.com>2024-01-29 15:31:10 +0000
commit6d83d42efb1c8126c4fc34faee3f5a139b09dec6 (patch)
tree5c9f35bc5aeb2f78aa8b6473932b479ef98cf39b /candle-metal-kernels
parentfd7c8565646039e35925b8730d27ddad195d7e73 (diff)
parentb6afb4660113b0633af372e7e10310377b677afd (diff)
downloadcandle-6d83d42efb1c8126c4fc34faee3f5a139b09dec6.tar.gz
candle-6d83d42efb1c8126c4fc34faee3f5a139b09dec6.tar.bz2
candle-6d83d42efb1c8126c4fc34faee3f5a139b09dec6.zip
Merge pull request #1606 from FL33TW00D/feature/larger-batches
fix: larger batches
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r--candle-metal-kernels/src/lib.rs13
-rw-r--r--candle-metal-kernels/src/libMetalFlashAttention.metallibbin102760 -> 116184 bytes
2 files changed, 6 insertions, 7 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index fe969372..2d27d230 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -1364,13 +1364,12 @@ pub fn call_gemm(
// TODO byte_stride_d
let byte_stride_d = 0;
- let mut buffer: Vec<u64> = Vec::with_capacity(b * 4);
- for i in 0..b {
- buffer.push((i * byte_stride_a) as u64);
- buffer.push((i * byte_stride_b) as u64);
- buffer.push((i * byte_stride_c) as u64);
- buffer.push((i * byte_stride_d) as u64);
- }
+ let buffer: Vec<u64> = vec![
+ byte_stride_a as _,
+ byte_stride_b as _,
+ byte_stride_c as _,
+ byte_stride_d as _,
+ ];
encoder.set_bytes(
10,
(buffer.len() * core::mem::size_of::<u64>()) as NSUInteger,
diff --git a/candle-metal-kernels/src/libMetalFlashAttention.metallib b/candle-metal-kernels/src/libMetalFlashAttention.metallib
index f5116ca6..1e2d1acf 100644
--- a/candle-metal-kernels/src/libMetalFlashAttention.metallib
+++ b/candle-metal-kernels/src/libMetalFlashAttention.metallib
Binary files differ