diff options
author | FL33TW00D <chris@fleetwood.dev> | 2024-01-18 14:30:14 +0000 |
---|---|---|
committer | FL33TW00D <chris@fleetwood.dev> | 2024-01-18 14:30:14 +0000 |
commit | 4f79f5df8aef3ee7e6e4757ffe39af4adab6a84d (patch) | |
tree | cedbc84d17373ba44925c7ae1880865c39149d2c /candle-metal-kernels | |
parent | 1cf34368b7d10600ca2fac197cd49c8bad2f6ad1 (diff) | |
download | candle-4f79f5df8aef3ee7e6e4757ffe39af4adab6a84d.tar.gz candle-4f79f5df8aef3ee7e6e4757ffe39af4adab6a84d.tar.bz2 candle-4f79f5df8aef3ee7e6e4757ffe39af4adab6a84d.zip |
fix: larger batches
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r-- | candle-metal-kernels/src/lib.rs | 13 | ||||
-rw-r--r-- | candle-metal-kernels/src/libMetalFlashAttention.metallib | bin | 102760 -> 116216 bytes |
2 files changed, 6 insertions, 7 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 201af97e..2773ca6a 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1359,13 +1359,12 @@ pub fn call_gemm( // TODO byte_stride_d let byte_stride_d = 0; - let mut buffer: Vec<u64> = Vec::with_capacity(b * 4); - for i in 0..b { - buffer.push((i * byte_stride_a) as u64); - buffer.push((i * byte_stride_b) as u64); - buffer.push((i * byte_stride_c) as u64); - buffer.push((i * byte_stride_d) as u64); - } + let buffer: Vec<u64> = vec![ + byte_stride_a as _, + byte_stride_b as _, + byte_stride_c as _, + byte_stride_d as _, + ]; encoder.set_bytes( 10, (buffer.len() * core::mem::size_of::<u64>()) as NSUInteger, diff --git a/candle-metal-kernels/src/libMetalFlashAttention.metallib b/candle-metal-kernels/src/libMetalFlashAttention.metallib Binary files differindex f5116ca6..c28d2b03 100644 --- a/candle-metal-kernels/src/libMetalFlashAttention.metallib +++ b/candle-metal-kernels/src/libMetalFlashAttention.metallib |