path: root/candle-metal-kernels
author    Eric Buehler <65165915+EricLBuehler@users.noreply.github.com>    2024-11-05 03:28:00 -0500
committer GitHub <noreply@github.com>    2024-11-05 09:28:00 +0100
commit    e2b6b367fa852ed30ac532f8d77cd8479c7ed092 (patch)
tree      41321e646a0ee9abef88122b202bd940240ecae6 /candle-metal-kernels
parent    6454597943599dd6df787a0d5f2446c5724d850a (diff)
Add some fast Metal MLX SDPA kernels (#2584)
* Add some fast Metal MLX SDPA kernels (#32)
* Sketch the sdpa kernel
* Add full sdpa kernel,
* Add test
* Add vectorized kernel for decoding
* Update tests
* Add some docs
* Fix sdpa_vector names
* Add softcapping for vectorized sdpa
* Add softcapping for full sdpa
* Add support for head dim 32, 96, 256
* Add support for head dim 32, 96, 256
* Update docs
* Add update notice
* Clippy and format
* Conditional compilation for bf16
* Use it in quantized llama
* Some review comments
* Use set_params!
* Remove unused
* Remove feature
* Fix metal sdpa for v stride
* Remove comma
* Add the dim method to layout and shape.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
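For reference, here is a minimal CPU sketch (Rust) of the streaming-softmax decode attention with logit softcapping that the new sdpa_vector kernel implements. It is illustrative only, not part of the diff, and the function and variable names are made up for the example.

```rust
/// Single-query SDPA with online softmax and optional logit softcapping.
/// CPU mirror of the math in the `sdpa_vector` kernel; illustrative only.
fn sdpa_decode_ref(
    q: &[f32],
    keys: &[Vec<f32>],
    values: &[Vec<f32>],
    scale: f32,
    softcapping: f32,
) -> Vec<f32> {
    let d = q.len();
    let mut max_score = f32::NEG_INFINITY;
    let mut sum_exp = 0.0f32;
    let mut out = vec![0.0f32; d];
    for (k, v) in keys.iter().zip(values.iter()) {
        // Raw scaled dot product q . k.
        let mut score: f32 = q.iter().zip(k.iter()).map(|(a, b)| a * b).sum::<f32>() * scale;
        if softcapping != 1.0 {
            // Logit softcapping: the host folds `scale / softcapping` into alpha
            // and the kernel multiplies the tanh output back by `softcapping`,
            // which is algebraically the same as this.
            score = (score / softcapping).tanh() * softcapping;
        }
        // Online softmax update, same recurrence as the kernel's accumulators.
        let new_max = max_score.max(score);
        let factor = (max_score - new_max).exp();
        let exp_score = (score - new_max).exp();
        max_score = new_max;
        sum_exp = sum_exp * factor + exp_score;
        for i in 0..d {
            out[i] = out[i] * factor + exp_score * v[i];
        }
    }
    out.iter().map(|o| o / sum_exp).collect()
}

fn main() {
    let q = vec![1.0f32, 0.0, 0.0, 0.0];
    let keys = vec![vec![1.0f32, 0.0, 0.0, 0.0], vec![0.0, 1.0, 0.0, 0.0]];
    let values = vec![vec![1.0f32, 2.0, 3.0, 4.0], vec![5.0, 6.0, 7.0, 8.0]];
    println!("{:?}", sdpa_decode_ref(&q, &keys, &values, 0.5, 1.0));
}
```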
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r--  candle-metal-kernels/src/lib.rs                                323
-rw-r--r--  candle-metal-kernels/src/scaled_dot_product_attention.metal  1257
2 files changed, 1579 insertions(+), 1 deletion(-)
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 222ae8ad..0843cc11 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -8,7 +8,7 @@ use std::sync::RwLock;
pub mod utils;
pub use utils::BufferOffset;
-use utils::{get_block_dims, linear_split, EncoderProvider};
+use utils::{get_block_dims, linear_split, EncoderParam, EncoderProvider};
const AFFINE: &str = include_str!("affine.metal");
const BINARY: &str = include_str!("binary.metal");
@@ -25,6 +25,7 @@ const REDUCE: &str = include_str!("reduce.metal");
const SORT: &str = include_str!("sort.metal");
const TERNARY: &str = include_str!("ternary.metal");
const UNARY: &str = include_str!("unary.metal");
+const SDPA: &str = include_str!("scaled_dot_product_attention.metal");
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Source {
@@ -42,6 +43,7 @@ pub enum Source {
Sort,
Ternary,
Unary,
+ Sdpa,
}
pub mod copy2d {
@@ -159,6 +161,17 @@ pub enum MetalKernelError {
rhs_stride: Vec<usize>,
mnk: (usize, usize, usize),
},
+ #[error("Sdpa {variation} head size was {got}, expectd {expected:?}")]
+ SdpaHeadSizeMismatch {
+ variation: &'static str,
+ got: usize,
+ expected: Vec<usize>,
+ },
+ #[error("Sdpa {variation} got dtype {got:?}")]
+ SdpaHeadDTypeMismatch {
+ variation: &'static str,
+ got: SdpaDType,
+ },
}
impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
@@ -207,6 +220,7 @@ impl Kernels {
Source::Sort => SORT,
Source::Ternary => TERNARY,
Source::Unary => UNARY,
+ Source::Sdpa => SDPA,
Source::Mfa => panic!("Invalid lib"),
}
}
@@ -1627,6 +1641,313 @@ pub fn call_gemm(
Ok(())
}
+#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
+pub enum SdpaDType {
+ BF16,
+ F16,
+ F32,
+}
+
+/// SDPA full is supported when:
+/// - q head dim == 32, 64, 96, 128, 256
+/// - no mask
+/// - q heads == kv heads
+/// - final type != bf16 (TODO maybe just template this kernel too?)
+/// - q,k,v are contiguous
+#[allow(clippy::too_many_arguments)]
+pub fn call_sdpa_full(
+ device: &Device,
+ ep: impl EncoderProvider,
+ kernels: &Kernels,
+ q_offset: usize,
+ q_shape: &[usize],
+ q_buffer: &Buffer,
+ k_offset: usize,
+ k_buffer: &Buffer,
+ v_offset: usize,
+ v_buffer: &Buffer,
+ output: &Buffer,
+ alpha: f32,
+ softcapping: f32,
+ itype: SdpaDType,
+) -> Result<(), MetalKernelError> {
+ #[derive(Debug)]
+ #[repr(C)]
+ struct MLXFastAttentionParams {
+ m: i32,
+ n: i32,
+ k: i32,
+
+ ldq: i32, // ldq == ldo
+ ldk: i32,
+ ldv: i32,
+ lds: i32,
+ ldo: i32,
+
+ tiles_n: i32,
+ tiles_m: i32,
+
+ batch_stride_q: i32,
+ batch_stride_k: i32,
+ batch_stride_v: i32,
+ batch_stride_o: i32,
+
+ swizzle_log: i32,
+ gemm_n_iterations_aligned: i32,
+ gemm_k_iterations_aligned: i32,
+ gemm_sv_m_block_iterations: i32,
+
+ batch_ndim: i32,
+ alpha: f32,
+ softcapping: f32,
+ }
+
+ let bk = q_shape.last().unwrap();
+
+ const BN: usize = 16;
+ const BM: usize = 16;
+ const WM: usize = 2;
+ const WN: usize = 2;
+
+ let name = match (bk, itype) {
+ (32, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_32_itype_half",
+ (64, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_64_itype_half",
+ (96, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_96_itype_half",
+ (128, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_128_itype_half",
+ (256, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_256_itype_half",
+ (32, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_32_itype_float",
+ (64, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_64_itype_float",
+ (96, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_96_itype_float",
+ (128, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_128_itype_float",
+ (256, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_256_itype_float",
+ (other, SdpaDType::F16 | SdpaDType::F32) => {
+ return Err(MetalKernelError::SdpaHeadSizeMismatch {
+ variation: "full",
+ got: *other,
+ expected: vec![32, 64, 96, 128, 256],
+ })
+ }
+ (_, SdpaDType::BF16) => {
+ return Err(MetalKernelError::SdpaHeadDTypeMismatch {
+ variation: "full",
+ got: SdpaDType::BF16,
+ })
+ }
+ };
+
+ let pipeline = kernels.load_pipeline(device, Source::Sdpa, &name)?;
+ let encoder = ep.encoder();
+ let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
+ encoder.set_compute_pipeline_state(&pipeline);
+
+ // q = (bs, qhead, seq, hidden)
+ // k/v = (bs, kv_head, seq, hidden)
+
+ let qseq = q_shape[q_shape.len() - 2];
+
+ let m = q_shape[q_shape.len() - 2];
+ let n = m;
+ let k = q_shape[q_shape.len() - 1];
+ let bs_out = q_shape[0] * q_shape[1];
+
+ let batch_shape = [q_shape[0] * q_shape[1]];
+ let dk = q_shape[q_shape.len() - 1];
+ let ldq = dk;
+ let ldk = dk;
+ let ldv = dk;
+ let lds = BN;
+ let ldo = dk;
+
+ let tn = 1;
+ let tm = (m + BM - 1) / BM;
+
+ let b_stride_q = dk * qseq;
+ let b_stride_k = dk * qseq;
+ let b_stride_v = dk * qseq;
+ let b_stride_o = dk * qseq;
+ let swizzle_log = 0;
+ let gemm_n_iterations_aligned = (n + BN - 1) / BN;
+ let gemm_k_iterations_aligned = (k + bk - 1) / bk;
+ let gemm_sv_m_block_iterations = (m + BM - 1) / BM;
+ let batch_ndim = batch_shape.len();
+
+ let alpha = if softcapping != 1. {
+ alpha / softcapping
+ } else {
+ alpha
+ };
+
+ let params = MLXFastAttentionParams {
+ m: m as i32,
+ n: n as i32,
+ k: k as i32,
+ ldq: ldq as i32,
+ ldk: ldk as i32,
+ ldv: ldv as i32,
+ lds: lds as i32,
+ ldo: ldo as i32,
+ tiles_n: tn,
+ tiles_m: tm as i32,
+ batch_stride_q: b_stride_q as i32,
+ batch_stride_k: b_stride_k as i32,
+ batch_stride_v: b_stride_v as i32,
+ batch_stride_o: b_stride_o as i32,
+ swizzle_log,
+ gemm_n_iterations_aligned: gemm_n_iterations_aligned as i32,
+ gemm_k_iterations_aligned: gemm_k_iterations_aligned as i32,
+ gemm_sv_m_block_iterations: gemm_sv_m_block_iterations as i32,
+ batch_ndim: batch_ndim as i32,
+ alpha,
+ softcapping,
+ };
+ let batch_strides = [b_stride_q, b_stride_k, b_stride_v, b_stride_o];
+
+ impl EncoderParam for MLXFastAttentionParams {
+ fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
+ encoder.set_bytes(
+ position,
+ core::mem::size_of::<MLXFastAttentionParams>() as u64,
+ &data as *const MLXFastAttentionParams as *const c_void,
+ );
+ }
+ }
+
+ set_params!(
+ encoder,
+ (
+ (q_buffer, q_offset),
+ (k_buffer, k_offset),
+ (v_buffer, v_offset),
+ output,
+ params,
+ &batch_shape[..],
+ &batch_strides[..]
+ )
+ );
+
+ let grid_dims = MTLSize {
+ width: 1,
+ height: tm as u64,
+ depth: bs_out as u64,
+ };
+ let group_dims = MTLSize {
+ width: 32,
+ height: WM as u64,
+ depth: WN as u64,
+ };
+ encoder.use_resource(q_buffer, metal::MTLResourceUsage::Read);
+ encoder.use_resource(k_buffer, metal::MTLResourceUsage::Read);
+ encoder.use_resource(v_buffer, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
+ encoder.dispatch_thread_groups(grid_dims, group_dims);
+ Ok(())
+}
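As a hedged illustration of the constraints documented above, a caller could pre-validate shapes roughly as follows before dispatching call_sdpa_full. The helper is hypothetical; only the conditions themselves come from the doc comment and the match arms.

```rust
/// Hypothetical pre-flight check for `call_sdpa_full`; the function name is
/// invented, but the conditions mirror the doc comment and match arms above.
fn sdpa_full_supported(q_shape: &[usize], kv_shape: &[usize], itype: SdpaDType) -> bool {
    let head_dim = *q_shape.last().unwrap();
    // Only the head dims with compiled steel_gemm_attention variants.
    let dim_ok = matches!(head_dim, 32 | 64 | 96 | 128 | 256);
    // The full kernel has no GQA handling: q heads must equal kv heads.
    let heads_ok = q_shape.get(1) == kv_shape.get(1);
    // bf16 is only wired up for the vector kernel.
    let dtype_ok = itype != SdpaDType::BF16;
    dim_ok && heads_ok && dtype_ok
}
```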
+
+/// SDPA vector is supported when:
+/// - q head dim == 32, 64, 96, 128, 256
+/// - no mask
+/// - q,k,v are contiguous
+#[allow(clippy::too_many_arguments)]
+pub fn call_sdpa_vector(
+ device: &Device,
+ ep: impl EncoderProvider,
+ kernels: &Kernels,
+ q_offset: usize,
+ q_shape: &[usize],
+ q_buffer: &Buffer,
+ k_offset: usize,
+ k_shape: &[usize],
+ k_stride: &[usize],
+ k_buffer: &Buffer,
+ v_offset: usize,
+ v_stride: &[usize],
+ v_buffer: &Buffer,
+ output: &Buffer,
+ alpha: f32,
+ softcapping: f32,
+ itype: SdpaDType,
+) -> Result<(), MetalKernelError> {
+ let bk = q_shape.last().unwrap();
+
+ let gqa_factor = (q_shape[1] / k_shape[1]) as i32;
+ let n = k_shape[2] as i32;
+ let b = (q_shape[0] * q_shape[1]) as i32;
+ let kstride = k_stride[1];
+ let vstride = v_stride[1];
+
+ let name = match (bk, itype) {
+ (32, SdpaDType::F16) => "sdpa_vector_float16_t_32",
+ (64, SdpaDType::F16) => "sdpa_vector_float16_t_64",
+ (96, SdpaDType::F16) => "sdpa_vector_float16_t_96",
+ (128, SdpaDType::F16) => "sdpa_vector_float16_t_128",
+ (256, SdpaDType::F16) => "sdpa_vector_float16_t_256",
+ (32, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_32",
+ (64, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_64",
+ (96, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_96",
+ (128, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_128",
+ (256, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_256",
+ (32, SdpaDType::F32) => "sdpa_vector_float_32",
+ (64, SdpaDType::F32) => "sdpa_vector_float_64",
+ (96, SdpaDType::F32) => "sdpa_vector_float_96",
+ (128, SdpaDType::F32) => "sdpa_vector_float_128",
+ (256, SdpaDType::F32) => "sdpa_vector_float_256",
+ (other, _) => {
+ return Err(MetalKernelError::SdpaHeadSizeMismatch {
+ variation: "vector",
+ got: *other,
+ expected: vec![32, 64, 96, 128, 256],
+ })
+ }
+ };
+
+ let alpha = if softcapping != 1. {
+ alpha / softcapping
+ } else {
+ alpha
+ };
+
+ let pipeline = kernels.load_pipeline(device, Source::Sdpa, &name)?;
+ let encoder = ep.encoder();
+ let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
+ encoder.set_compute_pipeline_state(&pipeline);
+
+ // q = (bs, qhead, seq, hidden)
+ // k/v = (bs, kv_head, kv_seq, hidden)
+
+ set_params!(
+ encoder,
+ (
+ (q_buffer, q_offset),
+ (k_buffer, k_offset),
+ (v_buffer, v_offset),
+ output,
+ gqa_factor,
+ n,
+ kstride,
+ vstride,
+ alpha,
+ softcapping
+ )
+ );
+
+ let grid_dims = MTLSize {
+ width: 1,
+ height: b as u64,
+ depth: 1,
+ };
+ let group_dims = MTLSize {
+ width: 1024,
+ height: 1,
+ depth: 1,
+ };
+ encoder.use_resource(q_buffer, metal::MTLResourceUsage::Read);
+ encoder.use_resource(k_buffer, metal::MTLResourceUsage::Read);
+ encoder.use_resource(v_buffer, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
+ encoder.dispatch_thread_groups(grid_dims, group_dims);
+ Ok(())
+}
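For orientation, a rough sketch of the launch geometry this function sets up, assuming q = (bs, q_heads, 1, head_dim) and k/v = (bs, kv_heads, kv_seq, head_dim). The struct and function are illustrative only; the numbers mirror the code above.

```rust
/// Illustrative summary of the dispatch computed by `call_sdpa_vector`.
struct SdpaVectorLaunch {
    gqa_factor: usize,
    kv_len: usize,
    threadgroups: usize,
    threads_per_tg: usize,
}

fn sdpa_vector_launch(q_shape: &[usize], k_shape: &[usize]) -> SdpaVectorLaunch {
    SdpaVectorLaunch {
        // How many query heads share one kv head.
        gqa_factor: q_shape[1] / k_shape[1],
        // N: keys/values each threadgroup iterates over.
        kv_len: k_shape[2],
        // One threadgroup per (batch, query head) pair.
        threadgroups: q_shape[0] * q_shape[1],
        // 32 simdgroups of 32 lanes, matching BN/BD in the Metal kernel.
        threads_per_tg: 1024,
    }
}
```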
+
#[allow(clippy::too_many_arguments)]
pub fn call_im2col1d_strided(
device: &Device,
diff --git a/candle-metal-kernels/src/scaled_dot_product_attention.metal b/candle-metal-kernels/src/scaled_dot_product_attention.metal
new file mode 100644
index 00000000..1abb9f08
--- /dev/null
+++ b/candle-metal-kernels/src/scaled_dot_product_attention.metal
@@ -0,0 +1,1257 @@
+// Updated from MLX commit hash f70764a
+
+#include <metal_stdlib>
+#include <metal_simdgroup>
+
+using namespace metal;
+
+// ============ "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h"
+
+struct MLXFastAttentionParams {
+ const int M;
+ const int N;
+ const int K;
+
+ const int ldq; // ldq == ldo
+ const int ldk;
+ const int ldv;
+ const int lds;
+ const int ldo;
+
+ const int tiles_n;
+ const int tiles_m;
+
+ const int batch_stride_q;
+ const int batch_stride_k;
+ const int batch_stride_v;
+ const int batch_stride_o;
+
+ const int swizzle_log;
+ const int gemm_n_iterations_aligned;
+ const int gemm_k_iterations_aligned;
+ const int gemm_sv_m_block_iterations;
+
+ const int batch_ndim;
+ const float alpha;
+ const float softcapping;
+};
+
+struct MLXScaledDotProductAttentionParams {
+ // Associated dimensions & transposition information
+ const uint QUERY_SEQUENCE_LENGTH = 1;
+ const uint N_Q_HEADS = 32;
+ const uint N_KV_HEADS = 32;
+ const uint KV_TILES = 1;
+ const float INV_ALPHA = 0.08838834764831843f;
+};
+
+// ============ "mlx/backend/metal/kernels/scaled_dot_product_attention_params.sdpa_vector"
+
+template <typename T, int D>
+[[kernel]] void sdpa_vector(
+ const device T* queries [[buffer(0)]],
+ const device T* keys [[buffer(1)]],
+ const device T* values [[buffer(2)]],
+ device T* out [[buffer(3)]],
+ const constant int& gqa_factor,
+ const constant int& N,
+ const constant size_t& k_stride,
+ const constant size_t& v_stride,
+ const constant float& scale,
+ const constant float& softcapping,
+ uint3 tid [[threadgroup_position_in_grid]],
+ uint simd_gid [[simdgroup_index_in_threadgroup]],
+ uint simd_lid [[thread_index_in_simdgroup]]) {
+ constexpr int BN = 32;
+ constexpr int BD = 32;
+ constexpr int elem_per_thread = D / BD;
+
+ const int stride = BN * D;
+
+ typedef float U;
+
+ thread U q[elem_per_thread];
+ thread U k[elem_per_thread];
+ thread U o[elem_per_thread];
+
+ threadgroup U outputs[BN * BD];
+ threadgroup U max_scores[BN];
+ threadgroup U sum_exp_scores[BN];
+
+ // Adjust positions
+ const int head_idx = tid.y;
+ const int kv_head_idx = head_idx / gqa_factor;
+ queries += head_idx * D + simd_lid * elem_per_thread;
+ keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
+ values += kv_head_idx * v_stride + simd_gid * D + simd_lid * elem_per_thread;
+ out += head_idx * D + simd_gid * elem_per_thread;
+
+ // Read the query and 0 the output accumulator
+ for (int i = 0; i < elem_per_thread; i++) {
+ q[i] = static_cast<U>(scale) * queries[i];
+ }
+ for (int i = 0; i < elem_per_thread; i++) {
+ o[i] = 0;
+ }
+
+ U max_score = -INFINITY;
+ U sum_exp_score = 0;
+
+ // For each key
+ for (int i = simd_gid; i < N; i += BN) {
+ // Read the key
+ for (int i = 0; i < elem_per_thread; i++) {
+ k[i] = keys[i];
+ }
+
+ // Compute the i-th score
+ U score = 0;
+ for (int i = 0; i < elem_per_thread; i++) {
+ score += q[i] * k[i];
+ }
+ score = simd_sum(score);
+ if (softcapping != 1.) {
+ score = precise::tanh(score);
+ score = score * softcapping;
+ }
+
+ // Update the accumulators
+ U new_max = max(max_score, score);
+ U factor = fast::exp(max_score - new_max);
+ U exp_score = fast::exp(score - new_max);
+
+ max_score = new_max;
+ sum_exp_score = sum_exp_score * factor + exp_score;
+
+ // Update the output accumulator
+ for (int i = 0; i < elem_per_thread; i++) {
+ o[i] = o[i] * factor + exp_score * values[i];
+ }
+
+ // Move the pointers to the next kv
+ keys += stride;
+ values += stride;
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ // Each thread has a partial part of the output so we need to combine them.
+
+ // First let's communicate the max and sum_exp
+ if (simd_lid == 0) {
+ max_scores[simd_gid] = max_score;
+ sum_exp_scores[simd_gid] = sum_exp_score;
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ max_score = max_scores[simd_lid];
+ U new_max = simd_max(max_score);
+ U factor = fast::exp(max_score - new_max);
+ sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor);
+
+ // Now we need to aggregate all the outputs
+ for (int i = 0; i < elem_per_thread; i++) {
+ outputs[simd_lid * BD + simd_gid] = o[i];
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score;
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
+
+ // And write the output
+ if (simd_lid == 0) {
+ for (int i = 0; i < elem_per_thread; i++) {
+ out[i] = static_cast<T>(o[i]);
+ }
+ }
+}
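The reduction at the end of the kernel above merges one partial (max, sum_exp, weighted output) per simdgroup. A CPU analogue of that merge, with illustrative names and types, is sketched below.

```rust
/// Merge per-simdgroup partial softmax accumulators into one output row.
/// CPU analogue of the cross-simdgroup reduction in `sdpa_vector`; illustrative only.
fn merge_partials(partials: &[(f32, f32, Vec<f32>)]) -> Vec<f32> {
    let d = partials[0].2.len();
    // Global running max across all partials.
    let global_max = partials.iter().fold(f32::NEG_INFINITY, |m, p| m.max(p.0));
    let mut sum_exp = 0.0f32;
    let mut out = vec![0.0f32; d];
    for (max_i, sum_i, o_i) in partials {
        // Rescale each partial to the common max before summing.
        let factor = (max_i - global_max).exp();
        sum_exp += sum_i * factor;
        for j in 0..d {
            out[j] += o_i[j] * factor;
        }
    }
    out.iter().map(|o| o / sum_exp).collect()
}
```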
+
+// ============ "mlx/backend/metal/kernels/steel/defines.h"
+
+#define STEEL_CONST static constant constexpr const
+#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
+
+// ============ "mlx/backend/metal/kernels/steel/gemm/transforms.h"
+
+template <typename OutT, typename InT>
+struct TransformNone {
+ static METAL_FUNC OutT apply(InT x) {
+ return static_cast<OutT>(x);
+ }
+
+ static METAL_FUNC OutT apply(InT x, OutT) {
+ return static_cast<OutT>(x);
+ }
+};
+
+template <typename OutT, typename InT>
+struct TransformAdd {
+ TransformAdd(const float, const float) {}
+
+ static METAL_FUNC OutT apply(InT x) {
+ return static_cast<OutT>(x);
+ }
+
+ static METAL_FUNC OutT apply(InT x, OutT c) {
+ return static_cast<OutT>(x) + c;
+ }
+};
+
+template <typename OutT, typename InT>
+struct TransformAxpby {
+ const float alpha;
+ const float beta;
+
+ TransformAxpby(const float alpha_, const float beta_)
+ : alpha(alpha_), beta(beta_) {}
+
+ static METAL_FUNC OutT apply(InT x) {
+ return static_cast<OutT>(x);
+ }
+
+ METAL_FUNC OutT apply(InT x, OutT c) const {
+ return static_cast<OutT>(x * alpha + (beta * c));
+ }
+};
+
+template <typename T>
+struct AccumHelper {
+ typedef float accum_type;
+};
+
+struct BlockSwizzle {
+ static METAL_FUNC int2
+ swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) {
+ const int tid_x = (tid.x) >> swizzle_log;
+ const int tid_y =
+ ((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1));
+ return int2(tid_x, tid_y);
+ }
+};
+
+// ============ "mlx/backend/metal/kernels/utils.h"
+
+#if defined(__HAVE_BFLOAT__)
+typedef bfloat bfloat16_t;
+#endif
+typedef half float16_t;
+
+METAL_FUNC ulong2 elem_to_loc_broadcast(
+ uint elem,
+ constant const int* shape,
+ constant const size_t* a_strides,
+ constant const size_t* b_strides,
+ int ndim) {
+ ulong loc_a{0};
+ ulong loc_b{0};
+ for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
+ int pos_in_dim = (elem % shape[i]);
+ elem /= shape[i];
+ loc_a += pos_in_dim * a_strides[i];
+ loc_b += pos_in_dim * b_strides[i];
+ }
+ return ulong2(loc_a, loc_b);
+}
+
+METAL_FUNC ulong3 elem_to_loc_broadcast(
+ uint elem,
+ constant const int* shape,
+ constant const size_t* a_strides,
+ constant const size_t* b_strides,
+ constant const size_t* c_strides,
+ int ndim) {
+ ulong loc_a{0};
+ ulong loc_b{0};
+ ulong loc_c{0};
+ for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
+ int pos_in_dim = (elem % shape[i]);
+ elem /= shape[i];
+ loc_a += pos_in_dim * a_strides[i];
+ loc_b += pos_in_dim * b_strides[i];
+ loc_c += pos_in_dim * c_strides[i];
+ }
+ return ulong3(loc_a, loc_b, loc_c);
+}
+
+// ============ "mlx/backend/metal/kernels/scaled_dot_product_attention_params.metal"
+
+template <
+ typename T,
+ short BROWS,
+ short BCOLS,
+ short dst_ld,
+ short reduction_dim,
+ short tgp_size,
+ short alignment = 1,
+ short n_reads = (BCOLS * BROWS) / (tgp_size),
+ short TCOLS = BCOLS / n_reads,
+ short TROWS = tgp_size / TCOLS>
+struct BlockLoaderFA {
+ STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
+ STEEL_CONST short vec_size = n_reads;
+
+ // Leading dimension for src
+ const int src_ld;
+ const int tile_stride;
+
+ // Thread location indices
+ const short thread_idx;
+ const short bi;
+ const short bj;
+
+ // threadgroup and device memory
+ threadgroup T* dst;
+ const device T* src;
+
+ struct alignas(alignment * sizeof(T)) ReadVector {
+ uint8_t v[sizeof(T) * vec_size];
+ };
+
+ /* Constructor */
+ METAL_FUNC BlockLoaderFA(
+ const device T* src_,
+ const int src_ld_,
+ threadgroup T* dst_,
+ ushort simd_group_id [[simdgroup_index_in_threadgroup]],
+ ushort simd_lane_id [[thread_index_in_simdgroup]])
+ : src_ld(src_ld_),
+ tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
+ thread_idx(simd_group_id * 32 + simd_lane_id),
+ bi(thread_idx / TCOLS),
+ bj(vec_size * (thread_idx % TCOLS)),
+ dst(dst_ + bi * dst_ld + bj),
+ src(src_ + bi * src_ld + bj) {}
+
+ /* Load from device memory into threadgroup memory - without bound checking */
+ METAL_FUNC void load_unsafe() const {
+ STEEL_PRAGMA_UNROLL
+ for (short i = 0; i < BROWS; i += TROWS) {
+ *((threadgroup ReadVector*)(&dst[i * dst_ld])) =
+ *((const device ReadVector*)(&src[i * src_ld]));
+ }
+ }
+
+ /* Load from device memory into threadgroup memory - with bound checking */
+ METAL_FUNC void load_safe(short2 src_tile_dim) const {
+ src_tile_dim = src_tile_dim - short2(bj, bi);
+
+ // Skip loading if thread has no valid reads
+ if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
+ STEEL_PRAGMA_UNROLL
+ for (short i = 0; i < BROWS; i += TROWS) {
+ STEEL_PRAGMA_UNROLL
+ for (short j = 0; j < vec_size; j++) {
+ dst[i * dst_ld + j] = T(0);
+ }
+ }
+ return;
+ }
+
+ // Use fast thread memory for bound checks
+ bool tmp_idx[vec_size];
+ T tmp_val[vec_size];
+
+ STEEL_PRAGMA_UNROLL
+ for (short i = 0; i < BROWS; i += TROWS) {
+ // Make sure tmp_idx only contains valid indices
+ STEEL_PRAGMA_UNROLL
+ for (short j = 0; j < vec_size; j++) {
+ tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
+ }
+
+ // Read valid indices into tmp_val
+ STEEL_PRAGMA_UNROLL
+ for (short j = 0; j < vec_size; j++) {
+ tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
+ }
+
+ // Zero out unneeded values
+ STEEL_PRAGMA_UNROLL
+ for (short j = 0; j < vec_size; j++) {
+ tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
+ }
+
+ // Copy values to threadgroup memory
+ STEEL_PRAGMA_UNROLL
+ for (short j = 0; j < vec_size; j++) {
+ dst[i * dst_ld + j] = tmp_val[j];
+ }
+ }
+ }
+
+ /* Iteration helper */
+ METAL_FUNC void next() {
+ src += tile_stride;
+ }
+ METAL_FUNC void next(short n) {
+ src += n * tile_stride;
+ }
+};
+
+template <bool M_aligned, bool N_aligned, bool K_aligned>
+struct LoopAlignment {};
+
+template <
+ typename T,
+ typename U,
+ int BM,
+ int BN,
+ int BK,
+ int WM,
+ int WN,
+ bool transpose_a,
+ bool transpose_b,
+ short lda_tgp,
+ short ldb_tgp,
+ typename AccumType = float,
+ typename Epilogue = TransformNone<U, AccumType>>
+struct BlockMMAFA {
+ // Warp tile simdgroup matrix strides along M
+ STEEL_CONST short TM_stride = 8 * WM;
+ // Warp tile simdgroup matrix strides along M
+ STEEL_CONST short TN_stride = 8 * WN;
+
+ // Warp tile size along M
+ STEEL_CONST short TM = BM / TM_stride;
+ // Warp tile size along N
+ STEEL_CONST short TN = BN / TN_stride;
+
+ // Strides of A, B along reduction axis
+ STEEL_CONST short simd_stride_a = {
+ transpose_a ? TM_stride : TM_stride * lda_tgp};
+ STEEL_CONST short simd_stride_b = {
+ transpose_b ? TN_stride * ldb_tgp : TN_stride};
+
+ // Jump between elements
+ STEEL_CONST short jump_a = {transpose_a ? lda_tgp : 1};
+ STEEL_CONST short jump_b = {transpose_b ? ldb_tgp : 1};
+
+ STEEL_CONST short tile_stride_a = {transpose_a ? 8 * lda_tgp : 8};
+ STEEL_CONST short tile_stride_b = {transpose_b ? 8 : 8 * ldb_tgp};
+
+ // Simdgroup matrices
+ simdgroup_matrix<AccumType, 8, 8> Asimd[TM];
+ simdgroup_matrix<AccumType, 8, 8> Bsimd[TN];
+ simdgroup_matrix<AccumType, 8, 8> results[TM * TN] = {
+ simdgroup_matrix<AccumType, 8, 8>(0)};
+
+ // Offsets within threadgroup
+ const short tm;
+ const short tn;
+
+ short sm;
+ short sn;
+
+ ushort sid;
+ ushort slid;
+
+ short As_offset;
+ short Bs_offset;
+
+ /* Constructor */
+ METAL_FUNC BlockMMAFA(
+ ushort simd_group_id [[simdgroup_index_in_threadgroup]],
+ ushort simd_lane_id [[thread_index_in_simdgroup]])
+ : tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) {
+ // Determine thread position in simdgroup matrix
+ short qid = simd_lane_id / 4;
+ slid = simd_lane_id;
+ sid = simd_group_id;
+
+ sm = (qid & 4) + (simd_lane_id / 2) % 4;
+ sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
+
+ // Determine thread and simdgroup offset
+ As_offset =
+ transpose_a ? ((sn)*lda_tgp + (tm + sm)) : ((sn) + (tm + sm) * lda_tgp);
+ Bs_offset =
+ transpose_b ? ((tn + sn) * ldb_tgp + (sm)) : ((sm)*ldb_tgp + (tn + sn));
+ }
+
+ /* (BM, BK) X (BK, BN) multiply accumulate function */
+ METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
+ // Adjust for simdgroup and thread location
+ As += As_offset;
+ Bs += Bs_offset;
+
+ // Iterate over BK in blocks of 8
+ STEEL_PRAGMA_UNROLL
+ for (short kk = 0; kk < BK; kk += 8) {
+ simdgroup_barrier(mem_flags::mem_none);
+
+ // Load elements from threadgroup A as simdgroup matrices
+ STEEL_PRAGMA_UNROLL
+ for (short i = 0; i < TM; i++) {
+ Asimd[i].thread_elements()[0] =
+ static_cast<AccumType>(As[i * simd_stride_a + 0]);
+ Asimd[i].thread_elements()[1] =
+ static_cast<AccumType>(As[i * simd_stride_a + jump_a]);
+ }
+
+ simdgroup_barrier(mem_flags::mem_none);
+
+ // Load elements from threadgroup B as simdgroup matrices
+ STEEL_PRAGMA_UNROLL
+ for (short j = 0; j < TN; j++) {
+ Bsimd[j].thread_elements()[0] =
+ static_cast<AccumType>(Bs[j * simd_stride_b + 0]);
+ Bsimd[j].thread_elements()[1] =
+ static_cast<AccumType>(Bs[j * simd_stride_b + jump_b]);
+ }
+
+ simdgroup_barrier(mem_flags::mem_none);
+
+ // Multiply and accumulate into result simdgroup matrices
+ STEEL_PRAGMA_UNROLL
+ for (short i = 0; i < TM; i++) {
+ STEEL_PRAGMA_UNROLL
+ for (short j = 0; j < TN; j++) {
+ short j_serp = (i % 2) ? (TN - 1 - j) : j;
+
+ simdgroup_multiply_accumulate(
+ results[i * TN + j_serp],
+ Asimd[i],
+ Bsimd[j_serp],
+ results[i * TN + j_serp]);
+ }
+ }
+
+ // Progress to next simdgroup tile
+ As += tile_stride_a;
+ Bs += tile_stride_b;
+ }
+ }
+
+ METAL_FUNC void rescale_output(const threadgroup float* Corrections) {
+ // Loop over all simdgroup tiles
+
+ STEEL_PRAGMA_UNROLL
+ for (short i = 0; i < TM; i++) {
+ short row = sm + tm + i * TM_stride;
+ float scale_value = Corrections[row];
+
+ STEEL_PRAGMA_UNROLL
+ for (short j = 0; j < TN; j++) {
+ // Get accumulated result and associated offset in C
+ thread auto& accum = results[i * TN + j].thread_elements();
+ // int offset = (i * TM_stride) * ldc + (j * TN_stride);
+ accum[0] *= scale_value;
+ accum[1] *= scale_value;
+ }
+ }
+ }
+
+ /* Store results from simdgroup_matrix results into device memory */
+ METAL_FUNC void store_result(device U* C, const int ldc) const {
+ // Adjust for simdgroup and thread location
+ C += (sm + tm) * ldc + tn + sn;
+
+ // Loop over all simdgroup tiles
+ STEEL_PRAGMA_UNROLL
+ for (short i = 0; i < TM; i++) {
+ STEEL_PRAGMA_UNROLL
+ for (short j = 0; j < TN; j++) {
+ // Get accumulated result and associated offset in C
+ thread const auto& accum = results[i * TN + j].thread_elements();
+ int offset = (i * TM_stride) * ldc + (j * TN_stride);
+
+ // Apply epilogue
+ U outs[2] = {Epilogue::apply(accum[0]), Epilogue::apply(accum[1])};
+
+ // Write out C
+ C[offset] = outs[0];
+ C[offset + 1] = outs[1];
+ }
+ }
+ }
+
+ METAL_FUNC void store_result_to_tgp_memory(
+ threadgroup U* C,
+ const int ldc,
+ short2 dst_tile_dims) const {
+ // Adjust for simdgroup and thread location
+ C += (sm + tm) * ldc + (tn + sn);
+ dst_tile_dims -= short2(tn + sn, sm + tm);
+
+ STEEL_PRAGMA_UNROLL
+ for (int i = 0; i < TM; i++) {
+ if (i * TM_stride < dst_tile_dims.y) {
+ STEEL_PRAGMA_UNROLL
+ for (int j = 0; j < TN; j++) {
+ // Get accumulated result and associated offset in C
+ thread const auto& accum = results[i * TN + j].thread_elements();
+ int offset = (i * TM_stride) * ldc + (j * TN_stride);
+
+ // Apply epilogue and output C
+ if (j * TN_stride < dst_tile_dims.x) {
+ C[offset] = Epilogue::apply(accum[0]);
+ }
+
+ if (j * TN_stride + 1 < dst_tile_dims.x) {
+ C[offset + 1] = Epilogue::apply(accum[1]);
+ }
+ }
+ }
+ }
+ }
+
+ METAL_FUNC void
+ store_result_safe(device U* C, const int ldc, short2 dst_tile_dims) const {
+ // Adjust for simdgroup and thread location
+ C += (sm + tm) * ldc + (tn + sn);
+ dst_tile_dims -= short2(tn + sn, sm + tm);
+
+ STEEL_PRAGMA_UNROLL
+ for (int i = 0; i < TM; i++) {
+ if (i * TM_stride < dst_tile_dims.y) {
+ STEEL_PRAGMA_UNROLL
+ for (int j = 0; j < TN; j++) {
+ // Get accumulated result and associated offset in C
+ thread const auto& accum = results[i * TN + j].thread_elements();
+ int offset = (i * TM_stride) * ldc + (j * TN_stride);
+
+ // Apply epilogue and output C
+ if (j * TN_stride < dst_tile_dims.x) {
+ C[offset] = Epilogue::apply(accum[0]);
+ }
+
+ if (j * TN_stride + 1 < dst_tile_dims.x) {
+ C[offset + 1] = Epilogue::apply(accum[1]);
+ }
+ }
+ }
+ }
+ }
+
+ /* Store results from simdgroup_matrix results into device memory */
+ METAL_FUNC void store_result(
+ device U* D,
+ const int ldd,
+ const device U* C,
+ const int ldc,
+ const int fdc,
+ thread const Epilogue& epilogue_op) const {
+ // Adjust for simdgroup and thread location
+ C += (sm + tm) * ldc + (tn + sn) * fdc;
+ D += (sm + tm) * ldd + tn + sn;
+
+ // Loop over all simdgroup tiles
+ STEEL_PRAGMA_UNROLL
+ for (short i = 0; i < TM; i++) {
+ STEEL_PRAGMA_UNROLL
+ for (short j = 0; j < TN; j++) {
+ // Get accumulated result and associated offset in C
+ thread const auto& accum = results[i * TN + j].thread_elements();
+ int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
+ int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
+
+ // Apply epilogue
+ U outs[2] = {
+ epilogue_op.apply(accum[0], C[offset_c]),
+ epilogue_op.apply(accum[1], C[offset_c + fdc])};
+
+ // Write out D
+ D[offset_d] = outs[0];
+ D[offset_d + 1] = outs[1];
+ }
+ }
+ }
+
+ METAL_FUNC void store_result_safe(
+ device U* D,
+ const int ldd,
+ const device U* C,
+ const int ldc,
+ const int fdc,
+ short2 dst_tile_dims,
+ thread const Epilogue& epilogue_op) const {
+ // Adjust for simdgroup and thread location
+ C += (sm + tm) * ldc + (tn + sn) * fdc;
+ D += (sm + tm) * ldd + tn + sn;
+ dst_tile_dims -= short2(tn + sn, sm + tm);
+
+ STEEL_PRAGMA_UNROLL
+ for (int i = 0; i < TM; i++) {
+ if (i * TM_stride < dst_tile_dims.y) {
+ STEEL_PRAGMA_UNROLL
+ for (int j = 0; j < TN; j++) {
+ // Get accumulated result and associated offset in C
+ thread const auto& accum = results[i * TN + j].thread_elements();
+ int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
+ int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
+
+ // Apply epilogue and output C
+ if (j * TN_stride < dst_tile_dims.x) {
+ D[offset_d] = epilogue_op.apply(accum[0], C[offset_c]);
+ }
+
+ if (j * TN_stride + 1 < dst_tile_dims.x) {
+ D[offset_d + 1] = epilogue_op.apply(accum[1], C[offset_c + fdc]);
+ }
+ }
+ }
+ }
+ }
+
+ METAL_FUNC void clear_results() {
+ STEEL_PRAGMA_UNROLL
+ for (int i = 0; i < TM; i++) {
+ STEEL_PRAGMA_UNROLL
+ for (int j = 0; j < TN; j++) {
+ results[i * TN + j] = simdgroup_matrix<AccumType, 8, 8>(0);
+ }
+ }
+ }
+};
+
+template <
+ typename T,
+ typename U,
+ int BM,
+ int BN,
+ int BK,
+ int WM,
+ int WN,
+ bool transpose_q,
+ bool transpose_k,
+ bool transpose_v,
+ bool MN_aligned,
+ bool K_aligned,
+ typename AccumType = typename AccumHelper<T>::accum_type,
+ typename Epilogue = TransformNone<U, AccumType>>
+struct FastAttentionKernel {
+ STEEL_CONST short tgp_padding = 16 / sizeof(T);
+ STEEL_CONST short float_padding = 16 / sizeof(float);
+ STEEL_CONST short tgp_mem_size_q =
+ transpose_q ? BK * (BM + tgp_padding) : BM * (BK + tgp_padding);
+ STEEL_CONST short tgp_mem_size_k =
+ transpose_k ? BK * (BN + tgp_padding) : BN * (BK + tgp_padding);
+ STEEL_CONST short tgp_mem_size_v =
+ transpose_v ? BK * (BN + tgp_padding) : BN * (BK + tgp_padding);
+ STEEL_CONST short tgp_mem_size_s = BM * (BN + tgp_padding);
+
+ // maxes, rowsums, rescale
+ STEEL_CONST short tgp_mem_size_corrections =
+ 4 * (BM * sizeof(float) + float_padding);
+
+ STEEL_CONST bool share_kv_smem = transpose_k != transpose_v;
+
+ STEEL_CONST short tgp_mem_size = share_kv_smem
+ ? tgp_mem_size_q + tgp_mem_size_k + tgp_mem_size_s +
+ tgp_mem_size_corrections
+ : tgp_mem_size_q + tgp_mem_size_k + tgp_mem_size_s +
+ tgp_mem_size_corrections + tgp_mem_size_v;
+
+ STEEL_CONST short tgp_size = WM * WN * 32;
+
+ static_assert(transpose_q == false, "Expected Q not transposed.");
+ static_assert(transpose_k == true, "Expected K transposed.");
+ static_assert(transpose_v == false, "Expected V not transposed.");
+ static_assert(tgp_mem_size <= 32768, "Excessive tgp memory requested.");
+
+ using loader_q_t = BlockLoaderFA<
+ T,
+ transpose_q ? BK : BM,
+ transpose_q ? BM : BK,
+ transpose_q ? BM + tgp_padding : BK + tgp_padding,
+ !transpose_q,
+ tgp_size>;
+
+ using loader_k_t = BlockLoaderFA<
+ T,
+ transpose_k ? BN : BK,
+ transpose_k ? BK : BN,
+ transpose_k ? BK + tgp_padding : BN + tgp_padding,
+ transpose_k,
+ tgp_size>;
+
+ using loader_v_t = BlockLoaderFA<
+ T,
+ transpose_v ? BK : BN,
+ transpose_v ? BN : BK,
+ transpose_v ? BN + tgp_padding : BK + tgp_padding,
+ transpose_v,
+ tgp_size>;
+
+ using mma_qk_t = BlockMMAFA<
+ T,
+ U,
+ BM,
+ BN,
+ BK,
+ WM,
+ WN,
+ transpose_q,
+ transpose_k,
+ transpose_q ? BM + tgp_padding : BK + tgp_padding,
+ transpose_k ? BK + tgp_padding : BN + tgp_padding,
+ AccumType,
+ Epilogue>;
+
+ using mma_sv_t = BlockMMAFA<
+ T,
+ U,
+ BM,
+ BK,
+ BN,
+ WM,
+ WN,
+ false,
+ transpose_v,
+ BN + tgp_padding,
+ BK + tgp_padding,
+ AccumType,
+ Epilogue>;
+
+ /* Main kernel function */
+ template <bool M_aligned, bool N_aligned, bool K_aligned_>
+ static METAL_FUNC void gemm_loop(
+ threadgroup T* As [[threadgroup(0)]],
+ threadgroup T* Bs [[threadgroup(1)]],
+ const int gemm_k_iterations,
+ thread loader_k_t& loader_b,
+ thread mma_qk_t& mma_op,
+ thread const short& tgp_bm,
+ thread const short& tgp_bn,
+ LoopAlignment<M_aligned, N_aligned, K_aligned_> l = {}) {
+ // Appease the compiler
+ (void)l;
+ (void)tgp_bm;
+
+ short2 tile_dims_B = transpose_k ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
+
+ // not valid for gemm_k_iterations > 1 (so, BK == d_k)
+ for (int k = 0; k < gemm_k_iterations; k++) {
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if (N_aligned) {
+ loader_b.load_unsafe();
+ } else {
+ loader_b.load_safe(tile_dims_B);
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ // Multiply and accumulate threadgroup elements
+ mma_op.mma(As, Bs);
+ }
+ }
+
+ static METAL_FUNC void initialize_corrections(
+ threadgroup float* C,
+ uint simd_lane_id,
+ uint simd_group_id) {
+ if (simd_group_id == 0) {
+ threadgroup float* maxes = C;
+ threadgroup float* sums = C + (BM + float_padding);
+ threadgroup float* o_rescale = sums + (BM + float_padding);
+ threadgroup float* output_rescale = o_rescale + (BM + float_padding);
+
+ if (simd_lane_id < BM) {
+ maxes[simd_lane_id] = -INFINITY; // m_i
+ sums[simd_lane_id] = 0.f; // l_i
+ o_rescale[simd_lane_id] = 1.f; // li * exp(mi - mi_new)
+ output_rescale[simd_lane_id] = 1.f; // 1.0 / l_i
+ }
+ }
+ }
+
+ static METAL_FUNC void rescale_ss(
+ threadgroup T* Ss,
+ threadgroup float* Corrections,
+ uint simd_group_id,
+ uint simd_lane_id,
+ short2 local_blocks,
+ float alpha,
+ float softcapping) {
+ if (simd_group_id == 0) {
+ short row_offset = BM + float_padding;
+ threadgroup float* maxes = Corrections;
+ threadgroup float* sums = Corrections + row_offset;
+ threadgroup float* o_rescale = sums + row_offset;
+ threadgroup float* output_scales = o_rescale + row_offset;
+
+ if (simd_lane_id < uint(local_blocks.y)) {
+ float m_i_old = maxes[simd_lane_id];
+ float l_i_old = sums[simd_lane_id];
+
+ float m_i_new = m_i_old;
+ float l_i_new = l_i_old;
+
+ short offset = simd_lane_id * (BN + tgp_padding);
+
+ float m_ij = -INFINITY;
+
+ for (short j = 0; j < local_blocks.x; j++) {
+ float val = alpha * float(Ss[offset + j]);
+ if (softcapping != 1.) {
+ val = precise::tanh(val);
+ val = val * softcapping;
+ }
+ m_ij = max(m_ij, val);
+ }
+
+ m_i_new = max(m_ij, m_i_new);
+
+ float rowsum = 0.f; // lij
+
+ for (short j = 0; j < local_blocks.x; j++) {
+ float val = alpha * float(Ss[offset + j]);
+ if (softcapping != 1.) {
+ val = precise::tanh(val);
+ val = val * softcapping;
+ }
+ float P_i_j = exp(val - m_ij);
+ rowsum += P_i_j;
+ P_i_j = P_i_j * exp(m_ij - m_i_new);
+ Ss[offset + j] = T(P_i_j);
+ }
+
+ l_i_new =
+ exp(m_i_old - m_i_new) * l_i_old + exp(m_ij - m_i_new) * rowsum;
+ maxes[simd_lane_id] = m_i_new;
+ sums[simd_lane_id] = l_i_new;
+ float rescale = l_i_old * exp(m_i_old - m_i_new);
+ o_rescale[simd_lane_id] = rescale;
+ output_scales[simd_lane_id] = 1.0 / l_i_new;
+ }
+ }
+ }
+
+ /* Main kernel function */
+ static METAL_FUNC void run(
+ const device T* Q [[buffer(0)]],
+ const device T* K [[buffer(1)]],
+ const device T* V [[buffer(2)]],
+ device U* O [[buffer(3)]],
+ const constant MLXFastAttentionParams* params [[buffer(4)]],
+ threadgroup T* Qs [[threadgroup(0)]],
+ threadgroup T* Ks [[threadgroup(1)]],
+ threadgroup T* Ss [[threadgroup(2)]],
+ threadgroup T* Vs [[threadgroup(3)]],
+ threadgroup float* Corrections [[threadgroup(4)]],
+ uint simd_lane_id [[thread_index_in_simdgroup]],
+ uint simd_group_id [[simdgroup_index_in_threadgroup]],
+ uint3 tid [[threadgroup_position_in_grid]],
+ uint3 lid [[thread_position_in_threadgroup]]) {
+ // Pacifying compiler
+ (void)lid;
+
+ const int tid_y = ((tid.y) << params->swizzle_log) +
+ ((tid.x) & ((1 << params->swizzle_log) - 1));
+ const int tid_x = (tid.x) >> params->swizzle_log;
+
+ if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
+ return;
+ }
+
+ threadgroup_barrier(mem_flags::mem_none);
+
+ // Find block in Q, O; and head in K, V.
+ const int c_row = tid_y * BM;
+
+ Q += transpose_q ? c_row : c_row * params->ldq;
+ thread loader_q_t loader_q(Q, params->ldq, Qs, simd_group_id, simd_lane_id);
+
+ short tgp_bm = min(BM, params->M - c_row);
+ short2 tile_dims_Q = transpose_q ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
+
+ loader_q.load_safe(tile_dims_Q);
+
+ initialize_corrections(Corrections, simd_lane_id, simd_group_id);
+
+ O += c_row * params->ldo;
+
+ // Prepare threadgroup mma operation
+ thread mma_qk_t mma_qk_op(simd_group_id, simd_lane_id);
+ thread mma_sv_t mma_softmax_sv_op(simd_group_id, simd_lane_id);
+ thread loader_k_t loader_k(K, params->ldk, Ks, simd_group_id, simd_lane_id);
+ thread loader_v_t loader_v(V, params->ldv, Vs, simd_group_id, simd_lane_id);
+
+ for (short n_block = 0; n_block < params->gemm_n_iterations_aligned;
+ n_block++) {
+ short c_col = BN;
+
+ // Prepare threadgroup loading operations
+ short gemm_k_iterations = params->gemm_k_iterations_aligned;
+ short tgp_bn_qk = min(BN, params->N - c_col * n_block);
+ threadgroup_barrier(mem_flags::mem_none);
+
+ ///////////////////////////////////////////////////////////////////////////////
+ { // Loop over K - unaligned case
+
+ if (tgp_bm == BM && tgp_bn_qk == BN) {
+ gemm_loop<true, true, K_aligned>(
+ Qs,
+ Ks,
+ gemm_k_iterations,
+ loader_k,
+ mma_qk_op,
+ tgp_bm,
+ tgp_bn_qk);
+ } else if (tgp_bn_qk == BN) {
+ gemm_loop<false, true, K_aligned>(
+ Qs,
+ Ks,
+ gemm_k_iterations,
+ loader_k,
+ mma_qk_op,
+ tgp_bm,
+ tgp_bn_qk);
+
+ } else if (tgp_bm == BM) {
+ gemm_loop<true, false, K_aligned>(
+ Qs,
+ Ks,
+ gemm_k_iterations,
+ loader_k,
+ mma_qk_op,
+ tgp_bm,
+ tgp_bn_qk);
+
+ } else {
+ gemm_loop<false, false, K_aligned>(
+ Qs,
+ Ks,
+ gemm_k_iterations,
+ loader_k,
+ mma_qk_op,
+ tgp_bm,
+ tgp_bn_qk);
+ }
+ }
+
+ mma_qk_op.store_result_to_tgp_memory(
+ Ss, BN + tgp_padding, short2(BN, BM));
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ rescale_ss(
+ Ss,
+ Corrections,
+ simd_group_id,
+ simd_lane_id,
+ short2(tgp_bn_qk, tgp_bm),
+ params->alpha,
+ params->softcapping);
+
+ loader_v.load_safe(short2(BK, tgp_bn_qk));
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ threadgroup float* o_scales = Corrections + 2 * (BM + float_padding);
+ mma_softmax_sv_op.rescale_output(o_scales);
+
+ mma_softmax_sv_op.mma(Ss, Vs);
+
+ threadgroup float* final_output_scales =
+ Corrections + 3 * (BM + float_padding);
+
+ mma_softmax_sv_op.rescale_output(final_output_scales);
+
+ loader_v.next();
+ loader_k.next(BN);
+
+ mma_qk_op.clear_results();
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ mma_softmax_sv_op.store_result_safe(O, params->ldo, short2(BK, tgp_bm));
+ }
+};
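The four correction rows above (m_i, l_i, o_rescale, 1/l_i) implement the usual streaming-softmax block update. A compact CPU sketch of one row's update for a block of already alpha-scaled, softcapped scores follows; the struct, function, and variable names are illustrative only.

```rust
/// One row's streaming-softmax state, mirroring the maxes/sums correction
/// buffers used by FastAttentionKernel. Illustrative CPU sketch only.
struct RowState {
    m_i: f32, // running max of scaled scores
    l_i: f32, // running softmax denominator
}

/// Update one row for a new block of scores. Returns (o_rescale, output_scale,
/// probs): the previous output accumulator is multiplied by `o_rescale` before
/// adding this block's P @ V, and the result is normalized by `output_scale`.
fn update_row(state: &mut RowState, block_scores: &[f32]) -> (f32, f32, Vec<f32>) {
    let m_ij = block_scores.iter().fold(f32::NEG_INFINITY, |m, s| m.max(*s));
    let m_new = state.m_i.max(m_ij);
    let mut rowsum = 0.0f32;
    let probs: Vec<f32> = block_scores
        .iter()
        .map(|s| {
            let p = (s - m_ij).exp();
            rowsum += p;
            // Rescale each probability to the new running max.
            p * (m_ij - m_new).exp()
        })
        .collect();
    let l_new = (state.m_i - m_new).exp() * state.l_i + (m_ij - m_new).exp() * rowsum;
    let o_rescale = state.l_i * (state.m_i - m_new).exp();
    let output_scale = 1.0 / l_new;
    state.m_i = m_new;
    state.l_i = l_new;
    (o_rescale, output_scale, probs)
}
```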
+
+template <
+ typename T,
+ int BM,
+ int BN,
+ int BK,
+ int WM,
+ int WN,
+ bool transpose_q,
+ bool transpose_k,
+ bool transpose_v,
+ bool MN_aligned,
+ bool K_aligned>
+[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void attention(
+ const device T* Q [[buffer(0)]],
+ const device T* K [[buffer(1)]],
+ const device T* V [[buffer(2)]],
+ device T* O [[buffer(3)]],
+ const constant MLXFastAttentionParams* params [[buffer(4)]],
+ const constant int* batch_shape [[buffer(6)]],
+ const constant size_t* batch_strides [[buffer(7)]],
+ uint simd_lane_id [[thread_index_in_simdgroup]],
+ uint simd_group_id [[simdgroup_index_in_threadgroup]],
+ uint3 tid [[threadgroup_position_in_grid]],
+ uint3 lid [[thread_position_in_threadgroup]]) {
+ using attention_kernel = FastAttentionKernel<
+ T,
+ T,
+ BM,
+ BN,
+ BK,
+ WM,
+ WN,
+ transpose_q,
+ transpose_k,
+ transpose_v,
+ MN_aligned,
+ K_aligned>;
+
+ // Adjust for batch
+ if (params->batch_ndim > 1) {
+ const constant size_t* Q_bstrides = batch_strides;
+ const constant size_t* KV_bstrides = batch_strides + params->batch_ndim;
+
+ ulong2 batch_offsets = elem_to_loc_broadcast(
+ tid.z, batch_shape, Q_bstrides, KV_bstrides, params->batch_ndim);
+
+ Q += batch_offsets.x;
+ K += batch_offsets.y;
+ V += batch_offsets.y;
+
+ } else {
+ Q += params->batch_stride_q * tid.z;
+ K += params->batch_stride_k * tid.z;
+ V += params->batch_stride_v * tid.z;
+ }
+
+ // same shape as input
+ O += params->batch_stride_o * tid.z;
+ threadgroup T Qs[attention_kernel::tgp_mem_size_q];
+ threadgroup T Ss[attention_kernel::tgp_mem_size_s];
+ threadgroup float Corrections[attention_kernel::tgp_mem_size_corrections];
+
+ if (attention_kernel::share_kv_smem) {
+ threadgroup T Ks[attention_kernel::tgp_mem_size_k];
+ threadgroup T* Vs = Ks; //[attention_kernel::tgp_mem_size_v];
+ attention_kernel::run(
+ Q,
+ K,
+ V,
+ O,
+ params,
+ Qs,
+ Ks,
+ Ss,
+ Vs,
+ Corrections,
+ simd_lane_id,
+ simd_group_id,
+ tid,
+ lid);
+ } else {
+ threadgroup T Ks[attention_kernel::tgp_mem_size_k];
+ threadgroup T Vs[attention_kernel::tgp_mem_size_v];
+ attention_kernel::run(
+ Q,
+ K,
+ V,
+ O,
+ params,
+ Qs,
+ Ks,
+ Ss,
+ Vs,
+ Corrections,
+ simd_lane_id,
+ simd_group_id,
+ tid,
+ lid);
+ }
+}
+
+// clang-format off
+
+// SDPA full instantiations
+#define instantiate_fast_inference_self_attention_kernel( \
+ itype, otype, bm, bn, bk, wm, wn) \
+ template [[host_name("steel_gemm_attention_bm_" #bm "_bn_" #bn "_bk_" #bk \
+ "_itype_" #itype)]] [[kernel]] void \
+ attention<itype, bm, bn, bk, wm, wn, false, true, false, false, true>( \
+ const device itype* Q [[buffer(0)]], \
+ const device itype* K [[buffer(1)]], \
+ const device itype* V [[buffer(2)]], \
+ device otype* O [[buffer(3)]], \
+ const constant MLXFastAttentionParams* params [[buffer(4)]], \
+ const constant int* batch_shape [[buffer(5)]], \
+ const constant size_t* batch_strides [[buffer(6)]], \
+ uint simd_lane_id [[thread_index_in_simdgroup]], \
+ uint simd_group_id [[simdgroup_index_in_threadgroup]], \
+ uint3 tid [[threadgroup_position_in_grid]], \
+ uint3 lid [[thread_position_in_threadgroup]]);
+
+instantiate_fast_inference_self_attention_kernel(
+ float,
+ float,
+ 16,
+ 16,
+ 32,
+ 2,
+ 2);
+instantiate_fast_inference_self_attention_kernel(
+ float,
+ float,
+ 16,
+ 16,
+ 64,
+ 2,
+ 2);
+instantiate_fast_inference_self_attention_kernel(
+ float,
+ float,
+ 16,
+ 16,
+ 96,
+ 2,
+ 2);
+instantiate_fast_inference_self_attention_kernel(
+ float,
+ float,
+ 16,
+ 16,
+ 128,
+ 2,
+ 2);
+instantiate_fast_inference_self_attention_kernel(
+ float,
+ float,
+ 16,
+ 16,
+ 256,
+ 2,
+ 2);
+instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 32, 2, 2);
+instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 64, 2, 2);
+instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 96, 2, 2);
+instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 128, 2, 2);
+instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 256, 2, 2);
+
+// SDPA vector instantiations
+#define instantiate_sdpa_vector(type, head_dim) \
+ template [[host_name("sdpa_vector_" #type "_" #head_dim)]] \
+ [[kernel]] void sdpa_vector<type, head_dim>( \
+ const device type* queries [[buffer(0)]], \
+ const device type* keys [[buffer(1)]], \
+ const device type* values [[buffer(2)]], \
+ device type* out [[buffer(3)]], \
+ const constant int& gqa_factor, \
+ const constant int& N, \
+ const constant size_t& k_stride, \
+ const constant size_t& v_stride, \
+ const constant float& scale, \
+ const constant float& softcapping, \
+ uint3 tid [[threadgroup_position_in_grid]], \
+ uint simd_gid [[simdgroup_index_in_threadgroup]], \
+ uint simd_lid [[thread_index_in_simdgroup]]);
+
+#define instantiate_sdpa_vector_heads(type) \
+ instantiate_sdpa_vector(type, 32) \
+ instantiate_sdpa_vector(type, 64) \
+ instantiate_sdpa_vector(type, 96) \
+ instantiate_sdpa_vector(type, 128) \
+ instantiate_sdpa_vector(type, 256)
+
+instantiate_sdpa_vector_heads(float)
+#if defined(__HAVE_BFLOAT__)
+instantiate_sdpa_vector_heads(bfloat16_t)
+#endif
+instantiate_sdpa_vector_heads(float16_t)
+ // clang-format on