-rw-r--r--  candle-metal-kernels/src/lib.rs                               188
-rw-r--r--  candle-metal-kernels/src/scaled_dot_product_attention.metal   252
-rw-r--r--  candle-nn/src/ops.rs                                           95
3 files changed, 486 insertions, 49 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 5f948cbf..818e4a02 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -1906,7 +1906,12 @@ pub fn call_sdpa_vector(
alpha
};
- let pipeline = kernels.load_pipeline(device, Source::Sdpa, name)?;
+ let constants = Some(ConstantValues::new(vec![(
+ 20,
+ Value::Bool(/* sdpa_vector_has_mask */ false),
+ )]));
+
+ let pipeline = kernels.load_pipeline_with_constants(device, Source::Sdpa, name, constants)?;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
@@ -1948,6 +1953,187 @@ pub fn call_sdpa_vector(
Ok(())
}
+pub const SDPA_2PASS_BLOCKS: usize = 32;
+
+/// SDPA vector 2-pass is supported when:
+/// - q head dim == 32, 64, 96, 128, 256
+/// - no mask
+/// - q,k,v are contiguous
+#[allow(clippy::too_many_arguments)]
+pub fn call_sdpa_vector_2pass(
+ device: &Device,
+ ep: impl EncoderProvider,
+ kernels: &Kernels,
+ q_offset: usize,
+ q_shape: &[usize],
+ q_buffer: &Buffer,
+ k_offset: usize,
+ k_shape: &[usize],
+ k_stride: &[usize],
+ k_buffer: &Buffer,
+ v_offset: usize,
+ v_stride: &[usize],
+ v_buffer: &Buffer,
+ output: &Buffer,
+ intermediate: &Buffer,
+ sums: &Buffer,
+ maxs: &Buffer,
+ alpha: f32,
+ softcapping: f32,
+ itype: SdpaDType,
+) -> Result<(), MetalKernelError> {
+ let bk = q_shape.last().unwrap();
+
+ // First pass
+ {
+ let name_pass1 = match (bk, itype) {
+ (32, SdpaDType::F16) => "sdpa_vector_2pass_1_float16_t_32",
+ (64, SdpaDType::F16) => "sdpa_vector_2pass_1_float16_t_64",
+ (96, SdpaDType::F16) => "sdpa_vector_2pass_1_float16_t_96",
+ (128, SdpaDType::F16) => "sdpa_vector_2pass_1_float16_t_128",
+ (256, SdpaDType::F16) => "sdpa_vector_2pass_1_float16_t_256",
+ (32, SdpaDType::BF16) => "sdpa_vector_2pass_1_bfloat16_t_32",
+ (64, SdpaDType::BF16) => "sdpa_vector_2pass_1_bfloat16_t_64",
+ (96, SdpaDType::BF16) => "sdpa_vector_2pass_1_bfloat16_t_96",
+ (128, SdpaDType::BF16) => "sdpa_vector_2pass_1_bfloat16_t_128",
+ (256, SdpaDType::BF16) => "sdpa_vector_2pass_1_bfloat16_t_256",
+ (32, SdpaDType::F32) => "sdpa_vector_2pass_1_float_32",
+ (64, SdpaDType::F32) => "sdpa_vector_2pass_1_float_64",
+ (96, SdpaDType::F32) => "sdpa_vector_2pass_1_float_96",
+ (128, SdpaDType::F32) => "sdpa_vector_2pass_1_float_128",
+ (256, SdpaDType::F32) => "sdpa_vector_2pass_1_float_256",
+ (other, _) => {
+ return Err(MetalKernelError::SdpaHeadSizeMismatch {
+ variation: "vector_2pass_1",
+ got: *other,
+ expected: vec![32, 64, 96, 128, 256],
+ })
+ }
+ };
+
+ let gqa_factor = (q_shape[1] / k_shape[1]) as i32;
+ let n = k_shape[2] as i32;
+ let b = (q_shape[0] * q_shape[1]) as i32;
+ let kstride = k_stride[1];
+ let vstride = v_stride[1];
+
+ let alpha = if softcapping != 1. {
+ alpha / softcapping
+ } else {
+ alpha
+ };
+
+ let constants = Some(ConstantValues::new(vec![(
+ 20,
+ Value::Bool(/* sdpa_vector_has_mask */ false),
+ )]));
+
+ let pipeline =
+ kernels.load_pipeline_with_constants(device, Source::Sdpa, name_pass1, constants)?;
+ let encoder = ep.encoder();
+ let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
+ encoder.set_compute_pipeline_state(&pipeline);
+
+ // q = (bs, qhead, seq, hidden)
+ // k/v = (bs, kv_head, kv_seq, hidden)
+
+ set_params!(
+ encoder,
+ (
+ (q_buffer, q_offset),
+ (k_buffer, k_offset),
+ (v_buffer, v_offset),
+ intermediate,
+ sums,
+ maxs,
+ gqa_factor,
+ n,
+ kstride,
+ vstride,
+ alpha,
+ softcapping
+ )
+ );
+
+ let grid_dims = MTLSize {
+ width: 1,
+ height: b as u64,
+ depth: SDPA_2PASS_BLOCKS as u64,
+ };
+ let group_dims = MTLSize {
+ width: 8 * 32,
+ height: 1,
+ depth: 1,
+ };
+ encoder.use_resource(q_buffer, metal::MTLResourceUsage::Read);
+ encoder.use_resource(k_buffer, metal::MTLResourceUsage::Read);
+ encoder.use_resource(v_buffer, metal::MTLResourceUsage::Read);
+ encoder.use_resource(intermediate, metal::MTLResourceUsage::Write);
+ encoder.use_resource(sums, metal::MTLResourceUsage::Write);
+ encoder.use_resource(maxs, metal::MTLResourceUsage::Write);
+
+ encoder.dispatch_thread_groups(grid_dims, group_dims);
+ }
+
+ // Final pass
+ {
+ let name_pass2 = match (bk, itype) {
+ (32, SdpaDType::F16) => "sdpa_vector_2pass_2_float16_t_32",
+ (64, SdpaDType::F16) => "sdpa_vector_2pass_2_float16_t_64",
+ (96, SdpaDType::F16) => "sdpa_vector_2pass_2_float16_t_96",
+ (128, SdpaDType::F16) => "sdpa_vector_2pass_2_float16_t_128",
+ (256, SdpaDType::F16) => "sdpa_vector_2pass_2_float16_t_256",
+ (32, SdpaDType::BF16) => "sdpa_vector_2pass_2_bfloat16_t_32",
+ (64, SdpaDType::BF16) => "sdpa_vector_2pass_2_bfloat16_t_64",
+ (96, SdpaDType::BF16) => "sdpa_vector_2pass_2_bfloat16_t_96",
+ (128, SdpaDType::BF16) => "sdpa_vector_2pass_2_bfloat16_t_128",
+ (256, SdpaDType::BF16) => "sdpa_vector_2pass_2_bfloat16_t_256",
+ (32, SdpaDType::F32) => "sdpa_vector_2pass_2_float_32",
+ (64, SdpaDType::F32) => "sdpa_vector_2pass_2_float_64",
+ (96, SdpaDType::F32) => "sdpa_vector_2pass_2_float_96",
+ (128, SdpaDType::F32) => "sdpa_vector_2pass_2_float_128",
+ (256, SdpaDType::F32) => "sdpa_vector_2pass_2_float_256",
+ (other, _) => {
+ return Err(MetalKernelError::SdpaHeadSizeMismatch {
+ variation: "vector_2pass_2",
+ got: *other,
+ expected: vec![32, 64, 96, 128, 256],
+ })
+ }
+ };
+
+ let b = (q_shape[0] * q_shape[1]) as i32;
+
+ let pipeline = kernels.load_pipeline(device, Source::Sdpa, name_pass2)?;
+ let encoder = ep.encoder();
+ let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
+ encoder.set_compute_pipeline_state(&pipeline);
+
+ // q = (bs, qhead, seq, hidden)
+ // k/v = (bs, kv_head, kv_seq, hidden)
+
+ set_params!(encoder, (intermediate, sums, maxs, output));
+
+ let grid_dims = MTLSize {
+ width: 1,
+ height: b as u64,
+ depth: 1,
+ };
+ let group_dims = MTLSize {
+ width: 1024,
+ height: 1,
+ depth: 1,
+ };
+ encoder.use_resource(intermediate, metal::MTLResourceUsage::Read);
+ encoder.use_resource(sums, metal::MTLResourceUsage::Read);
+ encoder.use_resource(maxs, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
+
+ encoder.dispatch_thread_groups(grid_dims, group_dims);
+ }
+ Ok(())
+}
+
#[allow(clippy::too_many_arguments)]
pub fn call_im2col1d_strided(
device: &Device,
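Aside on the math the two kernels dispatched above implement: pass 1 runs the usual online-softmax recurrence independently over SDPA_2PASS_BLOCKS interleaved slices of the keys, producing a per-block unnormalized output plus each block's running max and sum of exponentials; pass 2 rescales every block by exp(m_b - M) against the global max M and normalizes. A minimal CPU sketch in Rust of that recombination, for a single head of scalar f32 data, with a hypothetical helper name (not part of the crate; scale and softcapping omitted for brevity):

/// Reference two-pass attention for one query vector (hypothetical helper).
/// Block b owns keys b, b + blocks, b + 2*blocks, ... mirroring the kernel's
/// interleaved key assignment.
fn sdpa_two_pass_ref(q: &[f32], keys: &[Vec<f32>], values: &[Vec<f32>], blocks: usize) -> Vec<f32> {
    let d = q.len();
    // Pass 1: per-block (unnormalized output, sum_exp, max).
    let mut partials = Vec::with_capacity(blocks);
    for b in 0..blocks {
        let (mut m, mut s, mut o) = (f32::NEG_INFINITY, 0f32, vec![0f32; d]);
        let mut i = b;
        while i < keys.len() {
            let score: f32 = q.iter().zip(&keys[i]).map(|(qi, ki)| qi * ki).sum();
            let new_m = m.max(score);
            let factor = (m - new_m).exp();
            let e = (score - new_m).exp();
            for (oj, vj) in o.iter_mut().zip(&values[i]) {
                *oj = *oj * factor + e * vj;
            }
            s = s * factor + e;
            m = new_m;
            i += blocks;
        }
        partials.push((o, s, m));
    }
    // Pass 2: rescale each block to the global max and normalize.
    let gm = partials.iter().map(|p| p.2).fold(f32::NEG_INFINITY, f32::max);
    let total: f32 = partials.iter().map(|(_, s, m)| s * (m - gm).exp()).sum();
    let mut out = vec![0f32; d];
    for (o, _, m) in &partials {
        let f = (m - gm).exp() / total;
        for (oj, out_j) in o.iter().zip(out.iter_mut()) {
            *out_j += f * oj;
        }
    }
    out
}

Blocks that see no keys contribute a zero sum of exponentials, so they drop out of both sums in pass 2; this is why the host can dispatch all 32 blocks unconditionally regardless of the key sequence length.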
diff --git a/candle-metal-kernels/src/scaled_dot_product_attention.metal b/candle-metal-kernels/src/scaled_dot_product_attention.metal
index 1abb9f08..0453e0d1 100644
--- a/candle-metal-kernels/src/scaled_dot_product_attention.metal
+++ b/candle-metal-kernels/src/scaled_dot_product_attention.metal
@@ -47,6 +47,8 @@ struct MLXScaledDotProductAttentionParams {
// ============ "mlx/backend/metal/kernels/scaled_dot_product_attention_params.sdpa_vector"
+constant bool sdpa_vector_has_mask [[function_constant(20)]];
+
template <typename T, int D>
[[kernel]] void sdpa_vector(
const device T* queries [[buffer(0)]],
@@ -59,14 +61,16 @@ template <typename T, int D>
const constant size_t& v_stride,
const constant float& scale,
const constant float& softcapping,
+ const device bool* mask [[function_constant(sdpa_vector_has_mask)]],
+ const constant int& mask_seq_stride [[function_constant(sdpa_vector_has_mask)]],
+ const constant int& mask_head_stride [[function_constant(sdpa_vector_has_mask)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int BN = 32;
constexpr int BD = 32;
constexpr int elem_per_thread = D / BD;
-
- const int stride = BN * D;
+ constexpr int stride = BN * D;
typedef float U;
@@ -84,6 +88,9 @@ template <typename T, int D>
queries += head_idx * D + simd_lid * elem_per_thread;
keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
values += kv_head_idx * v_stride + simd_gid * D + simd_lid * elem_per_thread;
+ if (sdpa_vector_has_mask) {
+ mask += head_idx * mask_head_stride + simd_gid * mask_seq_stride;
+ }
out += head_idx * D + simd_gid * elem_per_thread;
// Read the query and 0 the output accumulator
@@ -99,40 +106,41 @@ template <typename T, int D>
// For each key
for (int i = simd_gid; i < N; i += BN) {
- // Read the key
- for (int i = 0; i < elem_per_thread; i++) {
- k[i] = keys[i];
- }
+ if (!sdpa_vector_has_mask || mask[0]) {
+ // Read the key
+ for (int j = 0; j < elem_per_thread; j++) {
+ k[j] = keys[j];
+ }
- // Compute the i-th score
- U score = 0;
- for (int i = 0; i < elem_per_thread; i++) {
- score += q[i] * k[i];
- }
- score = simd_sum(score);
- if (softcapping != 1.) {
- score = precise::tanh(score);
- score = score * softcapping;
- }
+ // Compute the i-th score
+ U score = 0;
+ for (int j = 0; j < elem_per_thread; j++) {
+ score += q[j] * k[j];
+ }
+ score = simd_sum(score);
+ if (softcapping != 1.) {
+ score = precise::tanh(score);
+ score = score * softcapping;
+ }
- // Update the accumulators
- U new_max = max(max_score, score);
- U factor = fast::exp(max_score - new_max);
- U exp_score = fast::exp(score - new_max);
+ // Update the accumulators
+ U new_max = max(max_score, score);
+ U factor = fast::exp(max_score - new_max);
+ U exp_score = fast::exp(score - new_max);
- max_score = new_max;
- sum_exp_score = sum_exp_score * factor + exp_score;
+ max_score = new_max;
+ sum_exp_score = sum_exp_score * factor + exp_score;
- // Update the output accumulator
- for (int i = 0; i < elem_per_thread; i++) {
- o[i] = o[i] * factor + exp_score * values[i];
+ // Update the output accumulator
+ for (int j = 0; j < elem_per_thread; j++) {
+ o[j] = o[j] * factor + exp_score * values[j];
+ }
}
// Move the pointers to the next kv
keys += stride;
values += stride;
}
- threadgroup_barrier(mem_flags::mem_threadgroup);
// Each thread has a partial part of the output so we need to combine them.
@@ -163,6 +171,164 @@ template <typename T, int D>
}
}
+template <typename T, int D>
+[[kernel]] void sdpa_vector_2pass_1(
+ const device T* queries [[buffer(0)]],
+ const device T* keys [[buffer(1)]],
+ const device T* values [[buffer(2)]],
+ device float* out [[buffer(3)]],
+ device float* sums [[buffer(4)]],
+ device float* maxs [[buffer(5)]],
+ const constant int& gqa_factor,
+ const constant int& N,
+ const constant size_t& k_stride,
+ const constant size_t& v_stride,
+ const constant float& scale,
+ const constant float& softcapping,
+ const device bool* mask [[function_constant(sdpa_vector_has_mask)]],
+ const constant int& mask_seq_stride [[function_constant(sdpa_vector_has_mask)]],
+ const constant int& mask_head_stride [[function_constant(sdpa_vector_has_mask)]],
+ uint3 tid [[threadgroup_position_in_grid]],
+ uint simd_gid [[simdgroup_index_in_threadgroup]],
+ uint simd_lid [[thread_index_in_simdgroup]]) {
+ constexpr int BN = 8;
+ constexpr int BD = 32;
+ constexpr int elem_per_thread = D / BD;
+ constexpr int stride = BN * D;
+ constexpr int blocks = 32;
+
+ typedef float U;
+
+ thread U q[elem_per_thread];
+ thread U k[elem_per_thread];
+ thread U o[elem_per_thread];
+
+ threadgroup U outputs[BN * BD];
+ threadgroup U max_scores[BN];
+ threadgroup U sum_exp_scores[BN];
+
+ // Adjust positions
+ const int block_idx = tid.z;
+ const int head_idx = tid.y;
+ const int kv_head_idx = head_idx / gqa_factor;
+ queries += head_idx * D + simd_lid * elem_per_thread;
+ keys += kv_head_idx * k_stride + (block_idx * BN + simd_gid) * D +
+ simd_lid * elem_per_thread;
+ values += kv_head_idx * v_stride + (block_idx * BN + simd_gid) * D +
+ simd_lid * elem_per_thread;
+ out += head_idx * blocks * D + block_idx * D + simd_lid * elem_per_thread;
+ if (sdpa_vector_has_mask) {
+ mask += head_idx * mask_head_stride +
+ (block_idx * BN + simd_gid) * mask_seq_stride;
+ }
+ sums += head_idx * blocks + block_idx;
+ maxs += head_idx * blocks + block_idx;
+
+ // Read the query and 0 the output accumulator
+ for (int i = 0; i < elem_per_thread; i++) {
+ q[i] = static_cast<U>(scale) * queries[i];
+ }
+ for (int i = 0; i < elem_per_thread; i++) {
+ o[i] = 0;
+ }
+
+ U max_score = -1e9;
+ U sum_exp_score = 0;
+
+ // For each key
+ for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
+ if (!sdpa_vector_has_mask || mask[0]) {
+ // Read the key
+ for (int j = 0; j < elem_per_thread; j++) {
+ k[j] = keys[j];
+ }
+
+ // Compute the i-th score
+ U score = 0;
+ for (int j = 0; j < elem_per_thread; j++) {
+ score += q[j] * k[j];
+ }
+ score = simd_sum(score);
+ if (softcapping != 1.) {
+ score = precise::tanh(score);
+ score = score * softcapping;
+ }
+
+ // Update the accumulators
+ U new_max = max(max_score, score);
+ U factor = fast::exp(max_score - new_max);
+ U exp_score = fast::exp(score - new_max);
+
+ max_score = new_max;
+ sum_exp_score = sum_exp_score * factor + exp_score;
+
+ // Update the output accumulator
+ for (int j = 0; j < elem_per_thread; j++) {
+ o[j] = o[j] * factor + exp_score * values[j];
+ }
+ }
+
+ // Move the pointers to the next kv
+ keys += blocks * stride;
+ values += blocks * stride;
+ if (sdpa_vector_has_mask) {
+ mask += BN * blocks * mask_seq_stride;
+ }
+ }
+
+ // Each thread has a partial part of the output so we need to combine them.
+
+ // First let's communicate the max and sum_exp
+ if (simd_lid == 0) {
+ max_scores[simd_gid] = max_score;
+ sum_exp_scores[simd_gid] = sum_exp_score;
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ max_score = (simd_lid < BN) ? max_scores[simd_lid] : -1e9;
+ U new_max = simd_max(max_score);
+ sum_exp_score = (simd_lid < BN) ? sum_exp_scores[simd_lid] : 0;
+ sum_exp_score = simd_sum(sum_exp_score * fast::exp(max_score - new_max));
+
+ // Write the sum and new max for this block
+ if (simd_gid == 0) {
+ sums[0] = sum_exp_score;
+ maxs[0] = new_max;
+ }
+
+ // Now we need to aggregate all the outputs
+ for (int i = 0; i < elem_per_thread; i++) {
+ outputs[simd_lid * BN + simd_gid] =
+ o[i] * fast::exp(max_scores[simd_gid] - new_max);
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ // And write the unnormalized partial output
+ if (simd_gid == 0) {
+ U output = outputs[simd_lid * BN];
+ for (int j = 1; j < BN; j++) {
+ output += outputs[simd_lid * BN + j];
+ }
+ out[i] = output;
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
+}
+
+template <typename T, int D>
+[[kernel]] void sdpa_vector_2pass_2(
+ const device float* partials [[buffer(0)]],
+ const device float* sums [[buffer(1)]],
+ const device float* maxs [[buffer(2)]],
+ device T* out [[buffer(3)]],
+ uint3 tid [[threadgroup_position_in_grid]],
+ uint simd_gid [[simdgroup_index_in_threadgroup]],
+ uint simd_lid [[thread_index_in_simdgroup]]) {
+ constexpr int BN = 32;
+ constexpr int BD = 32;
+ constexpr int elem_per_thread = D / BD;
+ constexpr int blocks = 32;
+
+ typedef float U;
+
+ thread U o[elem_per_thread];
+ threadgroup U outputs[BN * BD];
+
+ // Adjust positions
+ const int head_idx = tid.y;
+ partials += head_idx * blocks * D + simd_gid * D + simd_lid * elem_per_thread;
+ sums += head_idx * blocks;
+ maxs += head_idx * blocks;
+ out += head_idx * D + simd_gid * elem_per_thread;
+
+ // First everybody reads the max and sum_exp
+ U max_score = maxs[simd_lid];
+ U new_max = simd_max(max_score);
+ U factor = fast::exp(max_score - new_max);
+ U sum_exp_score = simd_sum(sums[simd_lid] * factor);
+
+ // Now read the block into registers and then use shared memory to transpose
+ // it
+ for (int i = 0; i < elem_per_thread; i++) {
+ o[i] = partials[i];
+ }
+ for (int i = 0; i < elem_per_thread; i++) {
+ outputs[simd_lid * BD + simd_gid] = o[i];
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score;
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
+
+ // And write the output
+ if (simd_lid == 0) {
+ for (int i = 0; i < elem_per_thread; i++) {
+ out[i] = static_cast<T>(o[i]);
+ }
+ }
+}
+
// ============ "mlx/backend/metal/kernels/steel/defines.h"
#define STEEL_CONST static constant constexpr const
@@ -1238,9 +1404,41 @@ instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 256, 2, 2);
const constant size_t& v_stride, \
const constant float& scale, \
const constant float& softcapping, \
+ const device bool* mask [[function_constant(sdpa_vector_has_mask)]], \
+ const constant int& mask_seq_stride [[function_constant(sdpa_vector_has_mask)]], \
+ const constant int& mask_head_stride [[function_constant(sdpa_vector_has_mask)]], \
+ uint3 tid [[threadgroup_position_in_grid]], \
+ uint simd_gid [[simdgroup_index_in_threadgroup]], \
+ uint simd_lid [[thread_index_in_simdgroup]]); \
+ template [[host_name("sdpa_vector_2pass_1_" #type "_" #head_dim)]] \
+ [[kernel]] void sdpa_vector_2pass_1<type, head_dim>( \
+ const device type* queries [[buffer(0)]], \
+ const device type* keys [[buffer(1)]], \
+ const device type* values [[buffer(2)]], \
+ device float* out [[buffer(3)]], \
+ device float* sums [[buffer(4)]], \
+ device float* maxs [[buffer(5)]], \
+ const constant int& gqa_factor, \
+ const constant int& N, \
+ const constant size_t& k_stride, \
+ const constant size_t& v_stride, \
+ const constant float& scale, \
+ const constant float& softcapping, \
+ const device bool* mask [[function_constant(sdpa_vector_has_mask)]], \
+ const constant int& mask_seq_stride [[function_constant(sdpa_vector_has_mask)]], \
+ const constant int& mask_head_stride [[function_constant(sdpa_vector_has_mask)]], \
+ uint3 tid [[threadgroup_position_in_grid]], \
+ uint simd_gid [[simdgroup_index_in_threadgroup]], \
+ uint simd_lid [[thread_index_in_simdgroup]]); \
+ template [[host_name("sdpa_vector_2pass_2_" #type "_" #head_dim)]] \
+ [[kernel]] void sdpa_vector_2pass_2<type, head_dim>( \
+ const device float* partials [[buffer(0)]], \
+ const device float* sums [[buffer(1)]], \
+ const device float* maxs [[buffer(2)]], \
+ device type* out [[buffer(3)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
- uint simd_lid [[thread_index_in_simdgroup]]);
+ uint simd_lid [[thread_index_in_simdgroup]]); \
#define instantiate_sdpa_vector_heads(type) \
instantiate_sdpa_vector(type, 32) \
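A note on the function-constant plumbing in the Metal changes above: the mask pointer and its strides are declared with [[function_constant(sdpa_vector_has_mask)]], so they only exist in pipeline specializations compiled with that constant set to true, and the host side in this diff always pins it to false. A minimal sketch of requesting such a specialization through metal-rs (assumed API and signature; candle's ConstantValues wrapper in lib.rs does the equivalent internally):

use metal::{FunctionConstantValues, MTLDataType};

// Build the constant table for [[function_constant(20)]]; the index must
// match the declaration in the shader source. Metal copies the value, so
// the pointer only needs to live for the duration of the call.
fn sdpa_mask_constants(has_mask: bool) -> FunctionConstantValues {
    let values = FunctionConstantValues::new();
    values.set_constant_value_at_index(
        &has_mask as *const bool as *const std::ffi::c_void,
        MTLDataType::Bool,
        20,
    );
    values
}

Because the constant is false here, the mask buffer and the two stride arguments are compiled out entirely, so the encoder never has to bind anything at those argument slots.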
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index c84e297b..d7f88a0b 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -1074,27 +1074,80 @@ impl candle::CustomOp3 for Sdpa {
let command_buffer = q.device().command_buffer()?;
if supports_sdpa_vector {
- command_buffer.set_label("vector_attention");
- candle_metal_kernels::call_sdpa_vector(
- q.device().device(),
- &command_buffer,
- q.device().kernels(),
- q_l.start_offset(),
- q_l.dims(),
- q.buffer(),
- k_l.start_offset(),
- k_l.dims(),
- k_l.stride(),
- k.buffer(),
- v_l.start_offset(),
- v_l.stride(),
- v.buffer(),
- &output,
- self.scale,
- self.softcapping,
- itype,
- )
- .map_err(candle::Error::wrap)?;
+ // Route to the 2-pass fused attention kernel when the k/v seqlen is large.
+ // https://github.com/ml-explore/mlx/pull/1597
+ const TWO_PASS_K_THRESHOLD: usize = 1024;
+ if k_l.dim(2)? >= TWO_PASS_K_THRESHOLD {
+ let mut intermediate_shape = [
+ &out_dims[0..out_dims.len() - 2],
+ &[candle_metal_kernels::SDPA_2PASS_BLOCKS],
+ &[out_dims[out_dims.len() - 1]],
+ ]
+ .concat();
+ let intermediate = device.new_buffer(
+ intermediate_shape.iter().product::<usize>(),
+ DType::F32,
+ "sdpa_2pass_intermediate",
+ )?;
+ let _ = intermediate_shape.pop().unwrap();
+ let sums = device.new_buffer(
+ intermediate_shape.iter().product::<usize>(),
+ DType::F32,
+ "sdpa_2pass_sums",
+ )?;
+ let maxs = device.new_buffer(
+ intermediate_shape.iter().product::<usize>(),
+ DType::F32,
+ "sdpa_2pass_maxs",
+ )?;
+
+ command_buffer.set_label("vector_attention");
+ candle_metal_kernels::call_sdpa_vector_2pass(
+ q.device().device(),
+ &command_buffer,
+ q.device().kernels(),
+ q_l.start_offset(),
+ q_l.dims(),
+ q.buffer(),
+ k_l.start_offset(),
+ k_l.dims(),
+ k_l.stride(),
+ k.buffer(),
+ v_l.start_offset(),
+ v_l.stride(),
+ v.buffer(),
+ &output,
+ &intermediate,
+ &sums,
+ &maxs,
+ self.scale,
+ self.softcapping,
+ itype,
+ )
+ .map_err(candle::Error::wrap)?;
+ } else {
+ command_buffer.set_label("vector_attention");
+ candle_metal_kernels::call_sdpa_vector(
+ q.device().device(),
+ &command_buffer,
+ q.device().kernels(),
+ q_l.start_offset(),
+ q_l.dims(),
+ q.buffer(),
+ k_l.start_offset(),
+ k_l.dims(),
+ k_l.stride(),
+ k.buffer(),
+ v_l.start_offset(),
+ v_l.stride(),
+ v.buffer(),
+ &output,
+ self.scale,
+ self.softcapping,
+ itype,
+ )
+ .map_err(candle::Error::wrap)?;
+ }
} else if supports_sdpa_full {
if q_l.dim(2)? != k_l.dim(2)? {
candle::bail!(
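For reference, the scratch allocations in the 2-pass branch above size out as follows: pass 1 produces one unnormalized partial output per (batch, head, block) plus one running sum and one running max per block, i.e. shapes [b, heads, SDPA_2PASS_BLOCKS, head_dim] and [b, heads, SDPA_2PASS_BLOCKS]. A small sketch of the shape arithmetic (hypothetical helper mirroring the code above):

const SDPA_2PASS_BLOCKS: usize = 32; // must match the kernels' `blocks`

/// f32 element counts for (intermediate, sums/maxs), given the attention
/// output shape [b, heads, seq, head_dim]; seq is 1 on the vector path.
fn sdpa_2pass_scratch_sizes(out_dims: &[usize]) -> (usize, usize) {
    let head_dim = out_dims[out_dims.len() - 1];
    let rows: usize = out_dims[..out_dims.len() - 2].iter().product();
    (rows * SDPA_2PASS_BLOCKS * head_dim, rows * SDPA_2PASS_BLOCKS)
}

For a typical decode step with out_dims = [1, 32, 1, 128] this gives 131072 f32 elements of partials and 1024 elements each for the sums and maxs buffers, small enough that allocating them per call is cheap relative to the attention itself.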