diff options
author | Eric Buehler <65165915+EricLBuehler@users.noreply.github.com> | 2025-01-16 05:30:10 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-01-16 11:30:10 +0100 |
commit | 17cbbe4286f25934197db79a244fd0694259c899 (patch) | |
tree | b670f534eef86ca047f68f6c3c6a1e1386b197b6 /candle-nn | |
parent | 6fd2f63a15353ceaac674165d13d2241589382e0 (diff) | |
download | candle-17cbbe4286f25934197db79a244fd0694259c899.tar.gz candle-17cbbe4286f25934197db79a244fd0694259c899.tar.bz2 candle-17cbbe4286f25934197db79a244fd0694259c899.zip |
* Sync upstream mlx sdpa vector kernels with mask
* Dispatch to the 2pass kernel
* Format
Diffstat (limited to 'candle-nn')
-rw-r--r-- | candle-nn/src/ops.rs | 95 |
1 file changed, 74 insertions, 21 deletions
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index c84e297b..d7f88a0b 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -1074,27 +1074,80 @@ impl candle::CustomOp3 for Sdpa { let command_buffer = q.device().command_buffer()?; if supports_sdpa_vector { - command_buffer.set_label("vector_attention"); - candle_metal_kernels::call_sdpa_vector( - q.device().device(), - &command_buffer, - q.device().kernels(), - q_l.start_offset(), - q_l.dims(), - q.buffer(), - k_l.start_offset(), - k_l.dims(), - k_l.stride(), - k.buffer(), - v_l.start_offset(), - v_l.stride(), - v.buffer(), - &output, - self.scale, - self.softcapping, - itype, - ) - .map_err(candle::Error::wrap)?; + // Route to the 2 pass fused attention if the k seqlen is large. + // https://github.com/ml-explore/mlx/pull/1597 + const TWO_PASS_K_THRESHOLD: usize = 1024; + if k_l.dim(2)? >= TWO_PASS_K_THRESHOLD { + let mut intermediate_shape = [ + &out_dims[0..out_dims.len() - 2], + &[candle_metal_kernels::SDPA_2PASS_BLOCKS], + &[out_dims[out_dims.len() - 1]], + ] + .concat(); + let intermediate = device.new_buffer( + intermediate_shape.iter().product::<usize>(), + DType::F32, + "sdpa_2pass_intermediate", + )?; + let _ = intermediate_shape.pop().unwrap(); + let sums = device.new_buffer( + intermediate_shape.iter().product::<usize>(), + DType::F32, + "sdpa_2pass_sums", + )?; + let maxs = device.new_buffer( + intermediate_shape.iter().product::<usize>(), + DType::F32, + "sdpa_2pass_maxs", + )?; + + command_buffer.set_label("vector_attention"); + candle_metal_kernels::call_sdpa_vector_2pass( + q.device().device(), + &command_buffer, + q.device().kernels(), + q_l.start_offset(), + q_l.dims(), + q.buffer(), + k_l.start_offset(), + k_l.dims(), + k_l.stride(), + k.buffer(), + v_l.start_offset(), + v_l.stride(), + v.buffer(), + &output, + &intermediate, + &sums, + &maxs, + self.scale, + self.softcapping, + itype, + ) + .map_err(candle::Error::wrap)?; + } else { + 
command_buffer.set_label("vector_attention"); + candle_metal_kernels::call_sdpa_vector( + q.device().device(), + &command_buffer, + q.device().kernels(), + q_l.start_offset(), + q_l.dims(), + q.buffer(), + k_l.start_offset(), + k_l.dims(), + k_l.stride(), + k.buffer(), + v_l.start_offset(), + v_l.stride(), + v.buffer(), + &output, + self.scale, + self.softcapping, + itype, + ) + .map_err(candle::Error::wrap)?; + } } else if supports_sdpa_full { if q_l.dim(2)? != k_l.dim(2)? { candle::bail!( |