summary | refs | log | tree | commit | diff
path: root/candle-nn
diff options
context:
space:
mode:
authorEric Buehler <65165915+EricLBuehler@users.noreply.github.com>2025-01-16 05:30:10 -0500
committerGitHub <noreply@github.com>2025-01-16 11:30:10 +0100
commit17cbbe4286f25934197db79a244fd0694259c899 (patch)
treeb670f534eef86ca047f68f6c3c6a1e1386b197b6 /candle-nn
parent6fd2f63a15353ceaac674165d13d2241589382e0 (diff)
downloadcandle-17cbbe4286f25934197db79a244fd0694259c899.tar.gz
candle-17cbbe4286f25934197db79a244fd0694259c899.tar.bz2
candle-17cbbe4286f25934197db79a244fd0694259c899.zip
Sync upstream MLX sdpa vector kernels with mask (#2718) (HEAD, main)
* Sync upstream mlx sdpa vector kernels with mask
* Dispatch to the 2pass kernel
* Format
Diffstat (limited to 'candle-nn')
-rw-r--r-- candle-nn/src/ops.rs | 95
1 file changed, 74 insertions(+), 21 deletions(-)
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index c84e297b..d7f88a0b 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -1074,27 +1074,80 @@ impl candle::CustomOp3 for Sdpa {
let command_buffer = q.device().command_buffer()?;
if supports_sdpa_vector {
- command_buffer.set_label("vector_attention");
- candle_metal_kernels::call_sdpa_vector(
- q.device().device(),
- &command_buffer,
- q.device().kernels(),
- q_l.start_offset(),
- q_l.dims(),
- q.buffer(),
- k_l.start_offset(),
- k_l.dims(),
- k_l.stride(),
- k.buffer(),
- v_l.start_offset(),
- v_l.stride(),
- v.buffer(),
- &output,
- self.scale,
- self.softcapping,
- itype,
- )
- .map_err(candle::Error::wrap)?;
+ // Route to the 2 pass fused attention if the k seqlen is large.
+ // https://github.com/ml-explore/mlx/pull/1597
+ const TWO_PASS_K_THRESHOLD: usize = 1024;
+ if k_l.dim(2)? >= TWO_PASS_K_THRESHOLD {
+ let mut intermediate_shape = [
+ &out_dims[0..out_dims.len() - 2],
+ &[candle_metal_kernels::SDPA_2PASS_BLOCKS],
+ &[out_dims[out_dims.len() - 1]],
+ ]
+ .concat();
+ let intermediate = device.new_buffer(
+ intermediate_shape.iter().product::<usize>(),
+ DType::F32,
+ "sdpa_2pass_intermediate",
+ )?;
+ let _ = intermediate_shape.pop().unwrap();
+ let sums = device.new_buffer(
+ intermediate_shape.iter().product::<usize>(),
+ DType::F32,
+ "sdpa_2pass_sums",
+ )?;
+ let maxs = device.new_buffer(
+ intermediate_shape.iter().product::<usize>(),
+ DType::F32,
+ "sdpa_2pass_maxs",
+ )?;
+
+ command_buffer.set_label("vector_attention");
+ candle_metal_kernels::call_sdpa_vector_2pass(
+ q.device().device(),
+ &command_buffer,
+ q.device().kernels(),
+ q_l.start_offset(),
+ q_l.dims(),
+ q.buffer(),
+ k_l.start_offset(),
+ k_l.dims(),
+ k_l.stride(),
+ k.buffer(),
+ v_l.start_offset(),
+ v_l.stride(),
+ v.buffer(),
+ &output,
+ &intermediate,
+ &sums,
+ &maxs,
+ self.scale,
+ self.softcapping,
+ itype,
+ )
+ .map_err(candle::Error::wrap)?;
+ } else {
+ command_buffer.set_label("vector_attention");
+ candle_metal_kernels::call_sdpa_vector(
+ q.device().device(),
+ &command_buffer,
+ q.device().kernels(),
+ q_l.start_offset(),
+ q_l.dims(),
+ q.buffer(),
+ k_l.start_offset(),
+ k_l.dims(),
+ k_l.stride(),
+ k.buffer(),
+ v_l.start_offset(),
+ v_l.stride(),
+ v.buffer(),
+ &output,
+ self.scale,
+ self.softcapping,
+ itype,
+ )
+ .map_err(candle::Error::wrap)?;
+ }
} else if supports_sdpa_full {
if q_l.dim(2)? != k_l.dim(2)? {
candle::bail!(