summaryrefslogtreecommitdiff
path: root/candle-metal-kernels/src/scaled_dot_product_attention.metal
diff options
context:
space:
mode:
Diffstat (limited to 'candle-metal-kernels/src/scaled_dot_product_attention.metal')
-rw-r--r--candle-metal-kernels/src/scaled_dot_product_attention.metal252
1 files changed, 225 insertions, 27 deletions
diff --git a/candle-metal-kernels/src/scaled_dot_product_attention.metal b/candle-metal-kernels/src/scaled_dot_product_attention.metal
index 1abb9f08..0453e0d1 100644
--- a/candle-metal-kernels/src/scaled_dot_product_attention.metal
+++ b/candle-metal-kernels/src/scaled_dot_product_attention.metal
@@ -47,6 +47,8 @@ struct MLXScaledDotProductAttentionParams {
// ============ "mlx/backend/metal/kernels/scaled_dot_product_attention_params.sdpa_vector"
+constant bool sdpa_vector_has_mask [[function_constant(20)]];
+
template <typename T, int D>
[[kernel]] void sdpa_vector(
const device T* queries [[buffer(0)]],
@@ -59,14 +61,16 @@ template <typename T, int D>
const constant size_t& v_stride,
const constant float& scale,
const constant float& softcapping,
+ const device bool* mask [[function_constant(sdpa_vector_has_mask)]],
+ const constant int& mask_seq_stride [[function_constant(sdpa_vector_has_mask)]],
+ const constant int& mask_head_stride [[function_constant(sdpa_vector_has_mask)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int BN = 32;
constexpr int BD = 32;
constexpr int elem_per_thread = D / BD;
-
- const int stride = BN * D;
+ constexpr int stride = BN * D;
typedef float U;
@@ -84,6 +88,9 @@ template <typename T, int D>
queries += head_idx * D + simd_lid * elem_per_thread;
keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
values += kv_head_idx * v_stride + simd_gid * D + simd_lid * elem_per_thread;
+ if (sdpa_vector_has_mask) {
+ mask += head_idx * mask_head_stride + simd_gid * mask_seq_stride;
+ }
out += head_idx * D + simd_gid * elem_per_thread;
// Read the query and 0 the output accumulator
@@ -99,40 +106,41 @@ template <typename T, int D>
// For each key
for (int i = simd_gid; i < N; i += BN) {
- // Read the key
- for (int i = 0; i < elem_per_thread; i++) {
- k[i] = keys[i];
- }
+ if (!sdpa_vector_has_mask || mask[0]) {
+ // Read the key
+ for (int j = 0; j < elem_per_thread; j++) {
+ k[j] = keys[j];
+ }
- // Compute the i-th score
- U score = 0;
- for (int i = 0; i < elem_per_thread; i++) {
- score += q[i] * k[i];
- }
- score = simd_sum(score);
- if (softcapping != 1.) {
- score = precise::tanh(score);
- score = score * softcapping;
- }
+ // Compute the i-th score
+ U score = 0;
+ for (int j = 0; j < elem_per_thread; j++) {
+ score += q[j] * k[j];
+ }
+ score = simd_sum(score);
+ if (softcapping != 1.) {
+ score = precise::tanh(score);
+ score = score * softcapping;
+ }
- // Update the accumulators
- U new_max = max(max_score, score);
- U factor = fast::exp(max_score - new_max);
- U exp_score = fast::exp(score - new_max);
+ // Update the accumulators
+ U new_max = max(max_score, score);
+ U factor = fast::exp(max_score - new_max);
+ U exp_score = fast::exp(score - new_max);
- max_score = new_max;
- sum_exp_score = sum_exp_score * factor + exp_score;
+ max_score = new_max;
+ sum_exp_score = sum_exp_score * factor + exp_score;
- // Update the output accumulator
- for (int i = 0; i < elem_per_thread; i++) {
- o[i] = o[i] * factor + exp_score * values[i];
+ // Update the output accumulator
+ for (int j = 0; j < elem_per_thread; j++) {
+ o[j] = o[j] * factor + exp_score * values[j];
+ }
}
// Move the pointers to the next kv
keys += stride;
values += stride;
}
- threadgroup_barrier(mem_flags::mem_threadgroup);
// Each thread has a partial part of the output so we need to combine them.
@@ -163,6 +171,164 @@ template <typename T, int D>
}
}
+template <typename T, int D>
+[[kernel]] void sdpa_vector_2pass_1(
+ const device T* queries [[buffer(0)]],
+ const device T* keys [[buffer(1)]],
+ const device T* values [[buffer(2)]],
+ device float* out [[buffer(3)]],
+ device float* sums [[buffer(4)]],
+ device float* maxs [[buffer(5)]],
+ const constant int& gqa_factor,
+ const constant int& N,
+ const constant size_t& k_stride,
+ const constant size_t& v_stride,
+ const constant float& scale,
+ const constant float& softcapping,
+ const device bool* mask [[function_constant(sdpa_vector_has_mask)]],
+ const constant int& mask_seq_stride [[function_constant(sdpa_vector_has_mask)]],
+ const constant int& mask_head_stride [[function_constant(sdpa_vector_has_mask)]],
+ uint3 tid [[threadgroup_position_in_grid]],
+ uint simd_gid [[simdgroup_index_in_threadgroup]],
+ uint simd_lid [[thread_index_in_simdgroup]]) {
+ constexpr int BN = 8;
+ constexpr int BD = 32;
+ constexpr int elem_per_thread = D / BD;
+ constexpr int stride = BN * D;
+ constexpr int blocks = 32;
+
+ typedef float U;
+
+ thread U q[elem_per_thread];
+ thread U k[elem_per_thread];
+ thread U o[elem_per_thread];
+
+ threadgroup U outputs[BN * BD];
+ threadgroup U max_scores[BN];
+ threadgroup U sum_exp_scores[BN];
+
+ // Adjust positions
+ const int block_idx = tid.z;
+ const int head_idx = tid.y;
+ const int kv_head_idx = head_idx / gqa_factor;
+ queries += head_idx * D + simd_lid * elem_per_thread;
+ keys += kv_head_idx * k_stride + (block_idx * BN + simd_gid) * D +
+ simd_lid * elem_per_thread;
+ values += kv_head_idx * v_stride + (block_idx * BN + simd_gid) * D +
+ simd_lid * elem_per_thread;
+ out += head_idx * blocks * D + block_idx * D + simd_lid * elem_per_thread;
+ if (sdpa_vector_has_mask) {
+ mask += head_idx * mask_head_stride +
+ (block_idx * BN + simd_gid) * mask_seq_stride;
+ }
+ sums += head_idx * blocks + block_idx;
+ maxs += head_idx * blocks + block_idx;
+
+ // Read the query and 0 the output accumulator
+ for (int i = 0; i < elem_per_thread; i++) {
+ q[i] = static_cast<U>(scale) * queries[i];
+ }
+ for (int i = 0; i < elem_per_thread; i++) {
+ o[i] = 0;
+ }
+
+ U max_score = -1e9;
+ U sum_exp_score = 0;
+
+ // For each key
+ for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
+ if (!sdpa_vector_has_mask || mask[0]) {
+ // Read the key
+ for (int i = 0; i < elem_per_thread; i++) {
+ k[i] = keys[i];
+ }
+
+ // Compute the i-th score
+ U score = 0;
+ for (int i = 0; i < elem_per_thread; i++) {
+ score += q[i] * k[i];
+ }
+ score = simd_sum(score);
+ if (softcapping != 1.) {
+ score = precise::tanh(score);
+ score = score * softcapping;
+ }
+
+ // Update the accumulators
+ U new_max = max(max_score, score);
+ U factor = fast::exp(max_score - new_max);
+ U exp_score = fast::exp(score - new_max);
+
+ max_score = new_max;
+ sum_exp_score = sum_exp_score * factor + exp_score;
+
+ // Update the output accumulator
+ for (int i = 0; i < elem_per_thread; i++) {
+ o[i] = o[i] * factor + exp_score * values[i];
+ }
+ }
+
+ // Move the pointers to the next kv
+ keys += blocks * stride;
+ values += blocks * stride;
+ if (sdpa_vector_has_mask) {
+ mask += BN * blocks * mask_seq_stride;
+ }
+ }
+}
+
+template <typename T, int D>
+[[kernel]] void sdpa_vector_2pass_2(
+ const device float* partials [[buffer(0)]],
+ const device float* sums [[buffer(1)]],
+ const device float* maxs [[buffer(2)]],
+ device T* out [[buffer(3)]],
+ uint3 tid [[threadgroup_position_in_grid]],
+ uint simd_gid [[simdgroup_index_in_threadgroup]],
+ uint simd_lid [[thread_index_in_simdgroup]]) {
+ constexpr int BN = 32;
+ constexpr int BD = 32;
+ constexpr int elem_per_thread = D / BD;
+ constexpr int blocks = 32;
+
+ typedef float U;
+
+ thread U o[elem_per_thread];
+ threadgroup U outputs[BN * BD];
+
+ // Adjust positions
+ const int head_idx = tid.y;
+ partials += head_idx * blocks * D + simd_gid * D + simd_lid * elem_per_thread;
+ sums += head_idx * blocks;
+ maxs += head_idx * blocks;
+ out += head_idx * D + simd_gid * elem_per_thread;
+
+ // First everybody reads the max and sum_exp
+ U max_score = maxs[simd_lid];
+ U new_max = simd_max(max_score);
+ U factor = fast::exp(max_score - new_max);
+ U sum_exp_score = simd_sum(sums[simd_lid] * factor);
+
+ // Now read the block into registers and then use shared memory to transpose
+ // it
+ for (int i = 0; i < elem_per_thread; i++) {
+ o[i] = partials[i];
+ }
+ for (int i = 0; i < elem_per_thread; i++) {
+ outputs[simd_lid * BD + simd_gid] = o[i];
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score;
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
+
+ // And write the output
+ if (simd_lid == 0) {
+ for (int i = 0; i < elem_per_thread; i++) {
+ out[i] = static_cast<T>(o[i]);
+ }
+ }
+}
+
// ============ "mlx/backend/metal/kernels/steel/defines.h"
#define STEEL_CONST static constant constexpr const
@@ -1238,9 +1404,41 @@ instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 256, 2, 2);
const constant size_t& v_stride, \
const constant float& scale, \
const constant float& softcapping, \
+ const device bool* mask [[function_constant(sdpa_vector_has_mask)]],, \
+ const constant int& mask_seq_stride [[function_constant(sdpa_vector_has_mask)]], \
+ const constant int& mask_head_stride [[function_constant(sdpa_vector_has_mask)]], \
+ uint3 tid [[threadgroup_position_in_grid]], \
+ uint simd_gid [[simdgroup_index_in_threadgroup]], \
+ uint simd_lid [[thread_index_in_simdgroup]]); \
+ template [[host_name("sdpa_vector_2pass_1_" #type "_" #head_dim)]] \
+ [[kernel]] void sdpa_vector_2pass_1<type, head_dim>( \
+ const device type* queries [[buffer(0)]], \
+ const device type* keys [[buffer(1)]], \
+ const device type* values [[buffer(2)]], \
+ device float* out [[buffer(3)]], \
+ device float* sums [[buffer(4)]], \
+ device float* maxs [[buffer(5)]], \
+ const constant int& gqa_factor, \
+ const constant int& N, \
+ const constant size_t& k_stride, \
+ const constant size_t& v_stride, \
+ const constant float& scale, \
+ const constant float& softcapping, \
+ const device bool* mask [[function_constant(sdpa_vector_has_mask)]],, \
+ const constant int& mask_seq_stride [[function_constant(sdpa_vector_has_mask)]], \
+ const constant int& mask_head_stride [[function_constant(sdpa_vector_has_mask)]], \
+ uint3 tid [[threadgroup_position_in_grid]], \
+ uint simd_gid [[simdgroup_index_in_threadgroup]], \
+ uint simd_lid [[thread_index_in_simdgroup]]); \
+ template [[host_name("sdpa_vector_2pass_2_" #type "_" #head_dim)]] \
+ [[kernel]] void sdpa_vector_2pass_2<type, head_dim>( \
+ const device float* partials [[buffer(0)]], \
+ const device float* sums [[buffer(1)]], \
+ const device float* maxs [[buffer(2)]], \
+ device type* out [[buffer(3)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
- uint simd_lid [[thread_index_in_simdgroup]]);
+ uint simd_lid [[thread_index_in_simdgroup]]); \
#define instantiate_sdpa_vector_heads(type) \
instantiate_sdpa_vector(type, 32) \