From 17cbbe4286f25934197db79a244fd0694259c899 Mon Sep 17 00:00:00 2001 From: Eric Buehler <65165915+EricLBuehler@users.noreply.github.com> Date: Thu, 16 Jan 2025 05:30:10 -0500 Subject: Sync upstream MLX sdpa vector kernels with mask (#2718) * Sync upstream mlx sdpa vector kernels with mask * Dispatch to the 2pass kernel * Format --- candle-nn/src/ops.rs | 95 ++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 74 insertions(+), 21 deletions(-) (limited to 'candle-nn') diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index c84e297b..d7f88a0b 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -1074,27 +1074,80 @@ impl candle::CustomOp3 for Sdpa { let command_buffer = q.device().command_buffer()?; if supports_sdpa_vector { - command_buffer.set_label("vector_attention"); - candle_metal_kernels::call_sdpa_vector( - q.device().device(), - &command_buffer, - q.device().kernels(), - q_l.start_offset(), - q_l.dims(), - q.buffer(), - k_l.start_offset(), - k_l.dims(), - k_l.stride(), - k.buffer(), - v_l.start_offset(), - v_l.stride(), - v.buffer(), - &output, - self.scale, - self.softcapping, - itype, - ) - .map_err(candle::Error::wrap)?; + // Route to the 2 pass fused attention if the k seqlen is large. + // https://github.com/ml-explore/mlx/pull/1597 + const TWO_PASS_K_THRESHOLD: usize = 1024; + if k_l.dim(2)? 
>= TWO_PASS_K_THRESHOLD { + let mut intermediate_shape = [ + &out_dims[0..out_dims.len() - 2], + &[candle_metal_kernels::SDPA_2PASS_BLOCKS], + &[out_dims[out_dims.len() - 1]], + ] + .concat(); + let intermediate = device.new_buffer( + intermediate_shape.iter().product::<usize>(), + DType::F32, + "sdpa_2pass_intermediate", + )?; + let _ = intermediate_shape.pop().unwrap(); + let sums = device.new_buffer( + intermediate_shape.iter().product::<usize>(), + DType::F32, + "sdpa_2pass_sums", + )?; + let maxs = device.new_buffer( + intermediate_shape.iter().product::<usize>(), + DType::F32, + "sdpa_2pass_maxs", + )?; + + command_buffer.set_label("vector_attention"); + candle_metal_kernels::call_sdpa_vector_2pass( + q.device().device(), + &command_buffer, + q.device().kernels(), + q_l.start_offset(), + q_l.dims(), + q.buffer(), + k_l.start_offset(), + k_l.dims(), + k_l.stride(), + k.buffer(), + v_l.start_offset(), + v_l.stride(), + v.buffer(), + &output, + &intermediate, + &sums, + &maxs, + self.scale, + self.softcapping, + itype, + ) + .map_err(candle::Error::wrap)?; + } else { + command_buffer.set_label("vector_attention"); + candle_metal_kernels::call_sdpa_vector( + q.device().device(), + &command_buffer, + q.device().kernels(), + q_l.start_offset(), + q_l.dims(), + q.buffer(), + k_l.start_offset(), + k_l.dims(), + k_l.stride(), + k.buffer(), + v_l.start_offset(), + v_l.stride(), + v.buffer(), + &output, + self.scale, + self.softcapping, + itype, + ) + .map_err(candle::Error::wrap)?; + } } else if supports_sdpa_full { if q_l.dim(2)? != k_l.dim(2)? { candle::bail!( -- cgit v1.2.3