-rw-r--r-- | candle-metal-kernels/src/lib.rs | 188
-rw-r--r-- | candle-metal-kernels/src/scaled_dot_product_attention.metal | 252
-rw-r--r-- | candle-nn/src/ops.rs | 95
3 files changed, 486 insertions, 49 deletions
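This commit ports MLX's two-pass sdpa_vector decode kernel (https://github.com/ml-explore/mlx/pull/1597). Pass 1 splits the KV sequence across SDPA_2PASS_BLOCKS = 32 blocks, each producing a running max, a running sum of exponentials, and an unnormalized partial output row; pass 2 merges the 32 partials with a single rescale to the global max followed by one normalization. The merge is the standard online-softmax identity; the CPU sketch below models it, assuming each pass-1 block reports a (max, sum_exp, partial) triple (merge_blocks is an illustrative helper, not code from the patch):

// Merge per-block softmax partials the way sdpa_vector_2pass_2 does:
// rescale each block's statistics to the global max, accumulate, and
// normalize exactly once at the end.
fn merge_blocks(blocks: &[(f32, f32, Vec<f32>)]) -> Vec<f32> {
    let global_max = blocks
        .iter()
        .map(|(m, _, _)| *m)
        .fold(f32::NEG_INFINITY, f32::max);
    let mut sum_exp = 0.0f32;
    let mut out = vec![0.0f32; blocks[0].2.len()];
    for (block_max, block_sum, partial) in blocks {
        // exp(block_max - global_max) re-bases this block's statistics.
        let factor = (block_max - global_max).exp();
        sum_exp += block_sum * factor;
        for (o, p) in out.iter_mut().zip(partial) {
            *o += p * factor;
        }
    }
    for o in &mut out {
        *o /= sum_exp;
    }
    out
}

Because the partials stay unnormalized until this point, the rescale is the only correction needed; no block ever revisits another block's keys.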
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 5f948cbf..818e4a02 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -1906,7 +1906,12 @@ pub fn call_sdpa_vector(
         alpha
     };
 
-    let pipeline = kernels.load_pipeline(device, Source::Sdpa, name)?;
+    let constants = Some(ConstantValues::new(vec![(
+        20,
+        Value::Bool(/* sdpa_vector_has_mask */ false),
+    )]));
+
+    let pipeline = kernels.load_pipeline_with_constants(device, Source::Sdpa, name, constants)?;
     let encoder = ep.encoder();
     let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
     encoder.set_compute_pipeline_state(&pipeline);
@@ -1948,6 +1953,187 @@ pub fn call_sdpa_vector(
     Ok(())
 }
 
+pub const SDPA_2PASS_BLOCKS: usize = 32;
+
+/// SDPA vector 2pass is supported when:
+/// - q head dim == 32, 64, 96, 128, 256
+/// - no mask
+/// - q,k,v are contiguous
+#[allow(clippy::too_many_arguments)]
+pub fn call_sdpa_vector_2pass(
+    device: &Device,
+    ep: impl EncoderProvider,
+    kernels: &Kernels,
+    q_offset: usize,
+    q_shape: &[usize],
+    q_buffer: &Buffer,
+    k_offset: usize,
+    k_shape: &[usize],
+    k_stride: &[usize],
+    k_buffer: &Buffer,
+    v_offset: usize,
+    v_stride: &[usize],
+    v_buffer: &Buffer,
+    output: &Buffer,
+    intermediate: &Buffer,
+    sums: &Buffer,
+    maxs: &Buffer,
+    alpha: f32,
+    softcapping: f32,
+    itype: SdpaDType,
+) -> Result<(), MetalKernelError> {
+    let bk = q_shape.last().unwrap();
+
+    // First pass
+    {
+        let name_pass1 = match (bk, itype) {
+            (32, SdpaDType::F16) => "sdpa_vector_2pass_1_float16_t_32",
+            (64, SdpaDType::F16) => "sdpa_vector_2pass_1_float16_t_64",
+            (96, SdpaDType::F16) => "sdpa_vector_2pass_1_float16_t_96",
+            (128, SdpaDType::F16) => "sdpa_vector_2pass_1_float16_t_128",
+            (256, SdpaDType::F16) => "sdpa_vector_2pass_1_float16_t_256",
+            (32, SdpaDType::BF16) => "sdpa_vector_2pass_1_bfloat16_t_32",
+            (64, SdpaDType::BF16) => "sdpa_vector_2pass_1_bfloat16_t_64",
+            (96, SdpaDType::BF16) => "sdpa_vector_2pass_1_bfloat16_t_96",
+            (128, SdpaDType::BF16) => "sdpa_vector_2pass_1_bfloat16_t_128",
+            (256, SdpaDType::BF16) => "sdpa_vector_2pass_1_bfloat16_t_256",
+            (32, SdpaDType::F32) => "sdpa_vector_2pass_1_float_32",
+            (64, SdpaDType::F32) => "sdpa_vector_2pass_1_float_64",
+            (96, SdpaDType::F32) => "sdpa_vector_2pass_1_float_96",
+            (128, SdpaDType::F32) => "sdpa_vector_2pass_1_float_128",
+            (256, SdpaDType::F32) => "sdpa_vector_2pass_1_float_256",
+            (other, _) => {
+                return Err(MetalKernelError::SdpaHeadSizeMismatch {
+                    variation: "vector_2pass_1",
+                    got: *other,
+                    expected: vec![32, 64, 96, 128, 256],
+                })
+            }
+        };
+
+        let gqa_factor = (q_shape[1] / k_shape[1]) as i32;
+        let n = k_shape[2] as i32;
+        let b = (q_shape[0] * q_shape[1]) as i32;
+        let kstride = k_stride[1];
+        let vstride = v_stride[1];
+
+        let alpha = if softcapping != 1. {
+            alpha / softcapping
+        } else {
+            alpha
+        };
+
+        let constants = Some(ConstantValues::new(vec![(
+            20,
+            Value::Bool(/* sdpa_vector_has_mask */ false),
+        )]));
+
+        let pipeline =
+            kernels.load_pipeline_with_constants(device, Source::Sdpa, &name_pass1, constants)?;
+        let encoder = ep.encoder();
+        let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
+        encoder.set_compute_pipeline_state(&pipeline);
+
+        // q = (bs, qhead, seq, hidden)
+        // k/v = (bs, kv_head, kv_seq, hidden)
+
+        set_params!(
+            encoder,
+            (
+                (q_buffer, q_offset),
+                (k_buffer, k_offset),
+                (v_buffer, v_offset),
+                intermediate,
+                sums,
+                maxs,
+                gqa_factor,
+                n,
+                kstride,
+                vstride,
+                alpha,
+                softcapping
+            )
+        );
+
+        let grid_dims = MTLSize {
+            width: 1,
+            height: b as u64,
+            depth: SDPA_2PASS_BLOCKS as u64,
+        };
+        let group_dims = MTLSize {
+            width: 8 * 32,
+            height: 1,
+            depth: 1,
+        };
+        encoder.use_resource(q_buffer, metal::MTLResourceUsage::Read);
+        encoder.use_resource(k_buffer, metal::MTLResourceUsage::Read);
+        encoder.use_resource(v_buffer, metal::MTLResourceUsage::Read);
+        encoder.use_resource(intermediate, metal::MTLResourceUsage::Write);
+        encoder.use_resource(sums, metal::MTLResourceUsage::Write);
+        encoder.use_resource(maxs, metal::MTLResourceUsage::Write);
+
+        encoder.dispatch_thread_groups(grid_dims, group_dims);
+    }
+
+    // Final pass
+    {
+        let name_pass2 = match (bk, itype) {
+            (32, SdpaDType::F16) => "sdpa_vector_2pass_2_float16_t_32",
+            (64, SdpaDType::F16) => "sdpa_vector_2pass_2_float16_t_64",
+            (96, SdpaDType::F16) => "sdpa_vector_2pass_2_float16_t_96",
+            (128, SdpaDType::F16) => "sdpa_vector_2pass_2_float16_t_128",
+            (256, SdpaDType::F16) => "sdpa_vector_2pass_2_float16_t_256",
+            (32, SdpaDType::BF16) => "sdpa_vector_2pass_2_bfloat16_t_32",
+            (64, SdpaDType::BF16) => "sdpa_vector_2pass_2_bfloat16_t_64",
+            (96, SdpaDType::BF16) => "sdpa_vector_2pass_2_bfloat16_t_96",
+            (128, SdpaDType::BF16) => "sdpa_vector_2pass_2_bfloat16_t_128",
+            (256, SdpaDType::BF16) => "sdpa_vector_2pass_2_bfloat16_t_256",
+            (32, SdpaDType::F32) => "sdpa_vector_2pass_2_float_32",
+            (64, SdpaDType::F32) => "sdpa_vector_2pass_2_float_64",
+            (96, SdpaDType::F32) => "sdpa_vector_2pass_2_float_96",
+            (128, SdpaDType::F32) => "sdpa_vector_2pass_2_float_128",
+            (256, SdpaDType::F32) => "sdpa_vector_2pass_2_float_256",
+            (other, _) => {
+                return Err(MetalKernelError::SdpaHeadSizeMismatch {
+                    variation: "vector_2pass_2",
+                    got: *other,
+                    expected: vec![32, 64, 96, 128, 256],
+                })
+            }
+        };
+
+        let b = (q_shape[0] * q_shape[1]) as i32;
+
+        let pipeline = kernels.load_pipeline(device, Source::Sdpa, &name_pass2)?;
+        let encoder = ep.encoder();
+        let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
+        encoder.set_compute_pipeline_state(&pipeline);
+
+        // q = (bs, qhead, seq, hidden)
+        // k/v = (bs, kv_head, kv_seq, hidden)
+
+        set_params!(encoder, (intermediate, sums, maxs, output));
+
+        let grid_dims = MTLSize {
+            width: 1,
+            height: b as u64,
+            depth: 1,
+        };
+        let group_dims = MTLSize {
+            width: 1024,
+            height: 1,
+            depth: 1,
+        };
+        encoder.use_resource(intermediate, metal::MTLResourceUsage::Write);
+        encoder.use_resource(sums, metal::MTLResourceUsage::Write);
+        encoder.use_resource(maxs, metal::MTLResourceUsage::Write);
+        encoder.use_resource(output, metal::MTLResourceUsage::Write);
+
+        encoder.dispatch_thread_groups(grid_dims, group_dims);
+    }
+    Ok(())
+}
+
 #[allow(clippy::too_many_arguments)]
 pub fn call_im2col1d_strided(
     device: &Device,
diff --git a/candle-metal-kernels/src/scaled_dot_product_attention.metal b/candle-metal-kernels/src/scaled_dot_product_attention.metal
index 1abb9f08..0453e0d1 100644
--- a/candle-metal-kernels/src/scaled_dot_product_attention.metal
+++ b/candle-metal-kernels/src/scaled_dot_product_attention.metal
@@ -47,6 +47,8 @@ struct MLXScaledDotProductAttentionParams {
 
 // ============ "mlx/backend/metal/kernels/scaled_dot_product_attention_params.sdpa_vector"
 
+constant bool sdpa_vector_has_mask [[function_constant(20)]];
+
 template <typename T, int D>
 [[kernel]] void sdpa_vector(
     const device T* queries [[buffer(0)]],
@@ -59,14 +61,16 @@ template <typename T, int D>
     const constant size_t& v_stride,
     const constant float& scale,
     const constant float& softcapping,
+    const device bool* mask [[function_constant(sdpa_vector_has_mask)]],
+    const constant int& mask_seq_stride [[function_constant(sdpa_vector_has_mask)]],
+    const constant int& mask_head_stride [[function_constant(sdpa_vector_has_mask)]],
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
   constexpr int BN = 32;
   constexpr int BD = 32;
   constexpr int elem_per_thread = D / BD;
-
-  const int stride = BN * D;
+  constexpr int stride = BN * D;
 
   typedef float U;
 
@@ -84,6 +88,9 @@ template <typename T, int D>
   queries += head_idx * D + simd_lid * elem_per_thread;
   keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
   values += kv_head_idx * v_stride + simd_gid * D + simd_lid * elem_per_thread;
+  if (sdpa_vector_has_mask) {
+    mask += head_idx * mask_head_stride + simd_gid * mask_seq_stride;
+  }
   out += head_idx * D + simd_gid * elem_per_thread;
 
   // Read the query and 0 the output accumulator
@@ -99,40 +106,41 @@ template <typename T, int D>
 
   // For each key
   for (int i = simd_gid; i < N; i += BN) {
-    // Read the key
-    for (int i = 0; i < elem_per_thread; i++) {
-      k[i] = keys[i];
-    }
+    if (!sdpa_vector_has_mask || mask[0]) {
+      // Read the key
+      for (int j = 0; j < elem_per_thread; j++) {
+        k[j] = keys[j];
+      }
 
-    // Compute the i-th score
-    U score = 0;
-    for (int i = 0; i < elem_per_thread; i++) {
-      score += q[i] * k[i];
-    }
-    score = simd_sum(score);
-    if (softcapping != 1.) {
-      score = precise::tanh(score);
-      score = score * softcapping;
-    }
+      // Compute the i-th score
+      U score = 0;
+      for (int j = 0; j < elem_per_thread; j++) {
+        score += q[j] * k[j];
+      }
+      score = simd_sum(score);
+      if (softcapping != 1.) {
+        score = precise::tanh(score);
+        score = score * softcapping;
+      }
 
-    // Update the accumulators
-    U new_max = max(max_score, score);
-    U factor = fast::exp(max_score - new_max);
-    U exp_score = fast::exp(score - new_max);
+      // Update the accumulators
+      U new_max = max(max_score, score);
+      U factor = fast::exp(max_score - new_max);
+      U exp_score = fast::exp(score - new_max);
 
-    max_score = new_max;
-    sum_exp_score = sum_exp_score * factor + exp_score;
+      max_score = new_max;
+      sum_exp_score = sum_exp_score * factor + exp_score;
 
-    // Update the output accumulator
-    for (int i = 0; i < elem_per_thread; i++) {
-      o[i] = o[i] * factor + exp_score * values[i];
+      // Update the output accumulator
+      for (int j = 0; j < elem_per_thread; j++) {
+        o[j] = o[j] * factor + exp_score * values[j];
+      }
     }
 
     // Move the pointers to the next kv
     keys += stride;
     values += stride;
   }
-
   threadgroup_barrier(mem_flags::mem_threadgroup);
 
   // Each thread has a partial part of the output so we need to combine them.
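The gating above also exposes the streaming half of the same identity in the "Update the accumulators" block: a new score may raise the running max, in which case the existing sum and output are rescaled by exp(old_max - new_max) before the new term is folded in. A scalar CPU model of one step, with an illustrative name and signature (not from the patch):

// One online-softmax step, mirroring sdpa_vector's accumulator update:
// rescale the old state to the new max, then fold in the new key's
// exponentiated score and its value row.
fn online_softmax_step(
    max_score: &mut f32,
    sum_exp: &mut f32,
    out: &mut [f32],
    score: f32,
    value: &[f32],
) {
    let new_max = max_score.max(score);
    let factor = (*max_score - new_max).exp();
    let exp_score = (score - new_max).exp();
    *max_score = new_max;
    *sum_exp = *sum_exp * factor + exp_score;
    for (o, v) in out.iter_mut().zip(value) {
        *o = *o * factor + exp_score * v;
    }
}

The sdpa_vector_2pass_1 kernel added in the next hunk runs this same update, but only over its 1/32nd slice of the key sequence.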
@@ -163,6 +171,164 @@ template <typename T, int D>
   }
 }
 
+template <typename T, int D>
+[[kernel]] void sdpa_vector_2pass_1(
+    const device T* queries [[buffer(0)]],
+    const device T* keys [[buffer(1)]],
+    const device T* values [[buffer(2)]],
+    device float* out [[buffer(3)]],
+    device float* sums [[buffer(4)]],
+    device float* maxs [[buffer(5)]],
+    const constant int& gqa_factor,
+    const constant int& N,
+    const constant size_t& k_stride,
+    const constant size_t& v_stride,
+    const constant float& scale,
+    const constant float& softcapping,
+    const device bool* mask [[function_constant(sdpa_vector_has_mask)]],
+    const constant int& mask_seq_stride [[function_constant(sdpa_vector_has_mask)]],
+    const constant int& mask_head_stride [[function_constant(sdpa_vector_has_mask)]],
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  constexpr int BN = 8;
+  constexpr int BD = 32;
+  constexpr int elem_per_thread = D / BD;
+  constexpr int stride = BN * D;
+  constexpr int blocks = 32;
+
+  typedef float U;
+
+  thread U q[elem_per_thread];
+  thread U k[elem_per_thread];
+  thread U o[elem_per_thread];
+
+  threadgroup U outputs[BN * BD];
+  threadgroup U max_scores[BN];
+  threadgroup U sum_exp_scores[BN];
+
+  // Adjust positions
+  const int block_idx = tid.z;
+  const int head_idx = tid.y;
+  const int kv_head_idx = head_idx / gqa_factor;
+  queries += head_idx * D + simd_lid * elem_per_thread;
+  keys += kv_head_idx * k_stride + (block_idx * BN + simd_gid) * D +
+      simd_lid * elem_per_thread;
+  values += kv_head_idx * v_stride + (block_idx * BN + simd_gid) * D +
+      simd_lid * elem_per_thread;
+  out += head_idx * blocks * D + block_idx * D + simd_lid * elem_per_thread;
+  if (sdpa_vector_has_mask) {
+    mask += head_idx * mask_head_stride +
+        (block_idx * BN + simd_gid) * mask_seq_stride;
+  }
+  sums += head_idx * blocks + block_idx;
+  maxs += head_idx * blocks + block_idx;
+
+  // Read the query and 0 the output accumulator
+  for (int i = 0; i < elem_per_thread; i++) {
+    q[i] = static_cast<U>(scale) * queries[i];
+  }
+  for (int i = 0; i < elem_per_thread; i++) {
+    o[i] = 0;
+  }
+
+  U max_score = -1e9;
+  U sum_exp_score = 0;
+
+  // For each key
+  for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
+    if (!sdpa_vector_has_mask || mask[0]) {
+      // Read the key
+      for (int j = 0; j < elem_per_thread; j++) {
+        k[j] = keys[j];
+      }
+
+      // Compute the i-th score
+      U score = 0;
+      for (int j = 0; j < elem_per_thread; j++) {
+        score += q[j] * k[j];
+      }
+      score = simd_sum(score);
+      if (softcapping != 1.) {
+        score = precise::tanh(score);
+        score = score * softcapping;
+      }
+
+      // Update the accumulators
+      U new_max = max(max_score, score);
+      U factor = fast::exp(max_score - new_max);
+      U exp_score = fast::exp(score - new_max);
+
+      max_score = new_max;
+      sum_exp_score = sum_exp_score * factor + exp_score;
+
+      // Update the output accumulator
+      for (int j = 0; j < elem_per_thread; j++) {
+        o[j] = o[j] * factor + exp_score * values[j];
+      }
+    }
+
+    // Move the pointers to the next kv
+    keys += blocks * stride;
+    values += blocks * stride;
+    if (sdpa_vector_has_mask) {
+      mask += BN * blocks * mask_seq_stride;
+    }
+  }
+
+  // Each thread has a partial part of the output so we need to combine them.
+
+  // First let's communicate the max and sum_exp
+  if (simd_lid == 0) {
+    max_scores[simd_gid] = max_score;
+    sum_exp_scores[simd_gid] = sum_exp_score;
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+  max_score = (simd_lid < BN) ? max_scores[simd_lid] : -1e9;
+  U new_max = simd_max(max_score);
+  U factor = fast::exp(max_score - new_max);
+  sum_exp_score = (simd_lid < BN) ? sum_exp_scores[simd_lid] : 0;
+  sum_exp_score = simd_sum(sum_exp_score * factor);
+
+  // Write the sum and new max for this block
+  if (simd_gid == 0) {
+    sums[0] = sum_exp_score;
+    maxs[0] = new_max;
+  }
+
+  // Now we need to aggregate all the outputs
+  for (int i = 0; i < elem_per_thread; i++) {
+    outputs[simd_lid * BN + simd_gid] =
+        o[i] * fast::exp(max_scores[simd_gid] - new_max);
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // And write the unnormalized partial output
+    if (simd_gid == 0) {
+      U output = outputs[simd_lid * BN];
+      for (int j = 1; j < BN; j++) {
+        output += outputs[simd_lid * BN + j];
+      }
+      out[i] = static_cast<T>(output);
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+  }
+}
+
+template <typename T, int D>
+[[kernel]] void sdpa_vector_2pass_2(
+    const device float* partials [[buffer(0)]],
+    const device float* sums [[buffer(1)]],
+    const device float* maxs [[buffer(2)]],
+    device T* out [[buffer(3)]],
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  constexpr int BN = 32;
+  constexpr int BD = 32;
+  constexpr int elem_per_thread = D / BD;
+  constexpr int blocks = 32;
+
+  typedef float U;
+
+  thread U o[elem_per_thread];
+  threadgroup U outputs[BN * BD];
+
+  // Adjust positions
+  const int head_idx = tid.y;
+  partials += head_idx * blocks * D + simd_gid * D + simd_lid * elem_per_thread;
+  sums += head_idx * blocks;
+  maxs += head_idx * blocks;
+  out += head_idx * D + simd_gid * elem_per_thread;
+
+  // First everybody reads the max and sum_exp
+  U max_score = maxs[simd_lid];
+  U new_max = simd_max(max_score);
+  U factor = fast::exp(max_score - new_max);
+  U sum_exp_score = simd_sum(sums[simd_lid] * factor);
+
+  // Now read the block into registers and then use shared memory to transpose
+  // it
+  for (int i = 0; i < elem_per_thread; i++) {
+    o[i] = partials[i];
+  }
+  for (int i = 0; i < elem_per_thread; i++) {
+    outputs[simd_lid * BD + simd_gid] = o[i];
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score;
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+  }
+
+  // And write the output
+  if (simd_lid == 0) {
+    for (int i = 0; i < elem_per_thread; i++) {
+      out[i] = static_cast<T>(o[i]);
+    }
+  }
+}
+
 // ============ "mlx/backend/metal/kernels/steel/defines.h"
 
 #define STEEL_CONST static constant constexpr const
@@ -1238,9 +1404,41 @@ instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 256, 2, 2);
       const constant size_t& v_stride, \
       const constant float& scale, \
       const constant float& softcapping, \
+      const device bool* mask [[function_constant(sdpa_vector_has_mask)]], \
+      const constant int& mask_seq_stride [[function_constant(sdpa_vector_has_mask)]], \
+      const constant int& mask_head_stride [[function_constant(sdpa_vector_has_mask)]], \
+      uint3 tid [[threadgroup_position_in_grid]], \
+      uint simd_gid [[simdgroup_index_in_threadgroup]], \
+      uint simd_lid [[thread_index_in_simdgroup]]); \
+  template [[host_name("sdpa_vector_2pass_1_" #type "_" #head_dim)]] \
+  [[kernel]] void sdpa_vector_2pass_1<type, head_dim>( \
+      const device type* queries [[buffer(0)]], \
+      const device type* keys [[buffer(1)]], \
+      const device type* values [[buffer(2)]], \
+      device float* out [[buffer(3)]], \
+      device float* sums [[buffer(4)]], \
+      device float* maxs [[buffer(5)]], \
+      const constant int& gqa_factor, \
+      const constant int& N, \
+      const constant size_t& k_stride, \
+      const constant size_t& v_stride, \
+      const constant float& scale, \
+      const constant float& softcapping, \
+      const device bool* mask [[function_constant(sdpa_vector_has_mask)]], \
+      const constant int& mask_seq_stride [[function_constant(sdpa_vector_has_mask)]], \
+      const constant int& mask_head_stride [[function_constant(sdpa_vector_has_mask)]], \
+      uint3 tid [[threadgroup_position_in_grid]], \
+      uint simd_gid [[simdgroup_index_in_threadgroup]], \
+      uint simd_lid [[thread_index_in_simdgroup]]); \
+  template [[host_name("sdpa_vector_2pass_2_" #type "_" #head_dim)]] \
+  [[kernel]] void sdpa_vector_2pass_2<type, head_dim>( \
+      const device float* partials [[buffer(0)]], \
+      const device float* sums [[buffer(1)]], \
+      const device float* maxs [[buffer(2)]], \
+      device type* out [[buffer(3)]], \
       uint3 tid [[threadgroup_position_in_grid]], \
       uint simd_gid [[simdgroup_index_in_threadgroup]], \
-      uint simd_lid [[thread_index_in_simdgroup]]);
+      uint simd_lid [[thread_index_in_simdgroup]]);
 
 #define instantiate_sdpa_vector_heads(type) \
   instantiate_sdpa_vector(type, 32) \
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index c84e297b..d7f88a0b 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -1074,27 +1074,80 @@ impl candle::CustomOp3 for Sdpa {
         let command_buffer = q.device().command_buffer()?;
         if supports_sdpa_vector {
-            command_buffer.set_label("vector_attention");
-            candle_metal_kernels::call_sdpa_vector(
-                q.device().device(),
-                &command_buffer,
-                q.device().kernels(),
-                q_l.start_offset(),
-                q_l.dims(),
-                q.buffer(),
-                k_l.start_offset(),
-                k_l.dims(),
-                k_l.stride(),
-                k.buffer(),
-                v_l.start_offset(),
-                v_l.stride(),
-                v.buffer(),
-                &output,
-                self.scale,
-                self.softcapping,
-                itype,
-            )
-            .map_err(candle::Error::wrap)?;
+            // Route to the 2-pass fused attention if the k seqlen is large.
+            // https://github.com/ml-explore/mlx/pull/1597
+            const TWO_PASS_K_THRESHOLD: usize = 1024;
+            if k_l.dim(2)? >= TWO_PASS_K_THRESHOLD {
+                let mut intermediate_shape = [
+                    &out_dims[0..out_dims.len() - 2],
+                    &[candle_metal_kernels::SDPA_2PASS_BLOCKS],
+                    &[out_dims[out_dims.len() - 1]],
+                ]
+                .concat();
+                let intermediate = device.new_buffer(
+                    intermediate_shape.iter().product::<usize>(),
+                    DType::F32,
+                    "sdpa_2pass_intermediate",
+                )?;
+                let _ = intermediate_shape.pop().unwrap();
+                let sums = device.new_buffer(
+                    intermediate_shape.iter().product::<usize>(),
+                    DType::F32,
+                    "sdpa_2pass_sums",
+                )?;
+                let maxs = device.new_buffer(
+                    intermediate_shape.iter().product::<usize>(),
+                    DType::F32,
+                    "sdpa_2pass_maxs",
+                )?;
+
+                command_buffer.set_label("vector_attention");
+                candle_metal_kernels::call_sdpa_vector_2pass(
+                    q.device().device(),
+                    &command_buffer,
+                    q.device().kernels(),
+                    q_l.start_offset(),
+                    q_l.dims(),
+                    q.buffer(),
+                    k_l.start_offset(),
+                    k_l.dims(),
+                    k_l.stride(),
+                    k.buffer(),
+                    v_l.start_offset(),
+                    v_l.stride(),
+                    v.buffer(),
+                    &output,
+                    &intermediate,
+                    &sums,
+                    &maxs,
+                    self.scale,
+                    self.softcapping,
+                    itype,
+                )
+                .map_err(candle::Error::wrap)?;
+            } else {
+                command_buffer.set_label("vector_attention");
+                candle_metal_kernels::call_sdpa_vector(
+                    q.device().device(),
+                    &command_buffer,
+                    q.device().kernels(),
+                    q_l.start_offset(),
+                    q_l.dims(),
+                    q.buffer(),
+                    k_l.start_offset(),
+                    k_l.dims(),
+                    k_l.stride(),
+                    k.buffer(),
+                    v_l.start_offset(),
+                    v_l.stride(),
+                    v.buffer(),
+                    &output,
+                    self.scale,
+                    self.softcapping,
+                    itype,
+                )
+                .map_err(candle::Error::wrap)?;
+            }
         } else if supports_sdpa_full {
             if q_l.dim(2)? != k_l.dim(2)? {
                 candle::bail!(
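On sizing the scratch buffers allocated in ops.rs above: for out_dims = (bs, heads, seq, head_dim) with a query sequence length of 1, the intermediate holds one head_dim-wide f32 row per (batch × head, block) pair, while sums and maxs hold a single f32 per pair. A small sketch of that arithmetic (scratch_sizes is an assumed helper for illustration, not an API from the patch):

const SDPA_2PASS_BLOCKS: usize = 32;

// Element counts for the pass-1 scratch buffers, mirroring the
// intermediate_shape construction above: drop the (seq, head_dim) tail,
// append the block count, and (for the partials) the head dim again.
fn scratch_sizes(out_dims: &[usize]) -> (usize, usize) {
    let head_dim = out_dims[out_dims.len() - 1];
    let rows = out_dims[..out_dims.len() - 2].iter().product::<usize>() * SDPA_2PASS_BLOCKS;
    (rows * head_dim, rows) // (intermediate elements, sums/maxs elements)
}

fn main() {
    // Decode step: batch 1, 32 query heads, q seq len 1, head dim 128.
    let (intermediate, per_block) = scratch_sizes(&[1, 32, 1, 128]);
    assert_eq!(intermediate, 32 * SDPA_2PASS_BLOCKS * 128);
    assert_eq!(per_block, 32 * SDPA_2PASS_BLOCKS);
}

The TWO_PASS_K_THRESHOLD = 1024 cutoff reflects the usual trade-off for split-K style kernels: chopping the KV sequence into 32 blocks restores GPU occupancy when batch × heads alone cannot fill the device, at the cost of an extra merge pass that only pays off once the key sequence is long.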