summaryrefslogtreecommitdiff
path: root/candle-metal-kernels/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-metal-kernels/src/lib.rs')
-rw-r--r--candle-metal-kernels/src/lib.rs323
1 files changed, 322 insertions, 1 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 222ae8ad..0843cc11 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -8,7 +8,7 @@ use std::sync::RwLock;
pub mod utils;
pub use utils::BufferOffset;
-use utils::{get_block_dims, linear_split, EncoderProvider};
+use utils::{get_block_dims, linear_split, EncoderParam, EncoderProvider};
const AFFINE: &str = include_str!("affine.metal");
const BINARY: &str = include_str!("binary.metal");
@@ -25,6 +25,7 @@ const REDUCE: &str = include_str!("reduce.metal");
const SORT: &str = include_str!("sort.metal");
const TERNARY: &str = include_str!("ternary.metal");
const UNARY: &str = include_str!("unary.metal");
+const SDPA: &str = include_str!("scaled_dot_product_attention.metal");
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Source {
@@ -42,6 +43,7 @@ pub enum Source {
Sort,
Ternary,
Unary,
+ Sdpa,
}
pub mod copy2d {
@@ -159,6 +161,17 @@ pub enum MetalKernelError {
rhs_stride: Vec<usize>,
mnk: (usize, usize, usize),
},
+ #[error("Sdpa {variation} head size was {got}, expectd {expected:?}")]
+ SdpaHeadSizeMismatch {
+ variation: &'static str,
+ got: usize,
+ expected: Vec<usize>,
+ },
+ #[error("Sdpa {variation} got dtype {got:?}")]
+ SdpaHeadDTypeMismatch {
+ variation: &'static str,
+ got: SdpaDType,
+ },
}
impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
@@ -207,6 +220,7 @@ impl Kernels {
Source::Sort => SORT,
Source::Ternary => TERNARY,
Source::Unary => UNARY,
+ Source::Sdpa => SDPA,
Source::Mfa => panic!("Invalid lib"),
}
}
@@ -1627,6 +1641,313 @@ pub fn call_gemm(
Ok(())
}
+#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
+pub enum SdpaDType {
+ BF16,
+ F16,
+ F32,
+}
+
+/// SDPA full is supported when:
+/// - q head dim == 64, 128
+/// - no mask
+/// - q heads == kv heads
+/// - final type != bf16 (TODO maybe just template this kernel too?)
+/// - q,k,v are contiguous
+#[allow(clippy::too_many_arguments)]
+pub fn call_sdpa_full(
+ device: &Device,
+ ep: impl EncoderProvider,
+ kernels: &Kernels,
+ q_offset: usize,
+ q_shape: &[usize],
+ q_buffer: &Buffer,
+ k_offset: usize,
+ k_buffer: &Buffer,
+ v_offset: usize,
+ v_buffer: &Buffer,
+ output: &Buffer,
+ alpha: f32,
+ softcapping: f32,
+ itype: SdpaDType,
+) -> Result<(), MetalKernelError> {
+ #[derive(Debug)]
+ #[repr(C)]
+ struct MLXFastAttentionParams {
+ m: i32,
+ n: i32,
+ k: i32,
+
+ ldq: i32, // ldq == ldo
+ ldk: i32,
+ ldv: i32,
+ lds: i32,
+ ldo: i32,
+
+ tiles_n: i32,
+ tiles_m: i32,
+
+ batch_stride_q: i32,
+ batch_stride_k: i32,
+ batch_stride_v: i32,
+ batch_stride_o: i32,
+
+ swizzle_log: i32,
+ gemm_n_iterations_aligned: i32,
+ gemm_k_iterations_aligned: i32,
+ gemm_sv_m_block_iterations: i32,
+
+ batch_ndim: i32,
+ alpha: f32,
+ softcapping: f32,
+ }
+
+ let bk = q_shape.last().unwrap();
+
+ const BN: usize = 16;
+ const BM: usize = 16;
+ const WM: usize = 2;
+ const WN: usize = 2;
+
+ let name = match (bk, itype) {
+ (32, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_32_itype_half",
+ (64, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_64_itype_half",
+ (96, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_96_itype_half",
+ (128, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_128_itype_half",
+ (256, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_256_itype_half",
+ (32, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_32_itype_float",
+ (64, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_64_itype_float",
+ (96, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_96_itype_float",
+ (128, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_128_itype_float",
+ (256, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_256_itype_float",
+ (other, SdpaDType::F16 | SdpaDType::F32) => {
+ return Err(MetalKernelError::SdpaHeadSizeMismatch {
+ variation: "full",
+ got: *other,
+ expected: vec![32, 64, 96, 128, 256],
+ })
+ }
+ (_, SdpaDType::BF16) => {
+ return Err(MetalKernelError::SdpaHeadDTypeMismatch {
+ variation: "full",
+ got: SdpaDType::BF16,
+ })
+ }
+ };
+
+ let pipeline = kernels.load_pipeline(device, Source::Sdpa, &name)?;
+ let encoder = ep.encoder();
+ let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
+ encoder.set_compute_pipeline_state(&pipeline);
+
+ // q = (bs, qhead, seq, hidden)
+ // k/v = (bs, kv_head, seq, hidden)
+
+ let qseq = q_shape[q_shape.len() - 2];
+
+ let m = q_shape[q_shape.len() - 2];
+ let n = m;
+ let k = q_shape[q_shape.len() - 1];
+ let bs_out = q_shape[0] * q_shape[1];
+
+ let batch_shape = [q_shape[0] * q_shape[1]];
+ let dk = q_shape[q_shape.len() - 1];
+ let ldq = dk;
+ let ldk = dk;
+ let ldv = dk;
+ let lds = BN;
+ let ldo = dk;
+
+ let tn = 1;
+ let tm = (m + BM - 1) / BM;
+
+ let b_stride_q = dk * qseq;
+ let b_stride_k = dk * qseq;
+ let b_stride_v = dk * qseq;
+ let b_stride_o = dk * qseq;
+ let swizzle_log = 0;
+ let gemm_n_iterations_aligned = (n + BN - 1) / BN;
+ let gemm_k_iterations_aligned = (k + bk - 1) / bk;
+ let gemm_sv_m_block_iterations = (m + BM - 1) / BM;
+ let batch_ndim = batch_shape.len();
+
+ let alpha = if softcapping != 1. {
+ alpha / softcapping
+ } else {
+ alpha
+ };
+
+ let params = MLXFastAttentionParams {
+ m: m as i32,
+ n: n as i32,
+ k: k as i32,
+ ldq: ldq as i32,
+ ldk: ldk as i32,
+ ldv: ldv as i32,
+ lds: lds as i32,
+ ldo: ldo as i32,
+ tiles_n: tn,
+ tiles_m: tm as i32,
+ batch_stride_q: b_stride_q as i32,
+ batch_stride_k: b_stride_k as i32,
+ batch_stride_v: b_stride_v as i32,
+ batch_stride_o: b_stride_o as i32,
+ swizzle_log,
+ gemm_n_iterations_aligned: gemm_n_iterations_aligned as i32,
+ gemm_k_iterations_aligned: gemm_k_iterations_aligned as i32,
+ gemm_sv_m_block_iterations: gemm_sv_m_block_iterations as i32,
+ batch_ndim: batch_ndim as i32,
+ alpha,
+ softcapping,
+ };
+ let batch_strides = [b_stride_q, b_stride_k, b_stride_v, b_stride_o];
+
+ impl EncoderParam for MLXFastAttentionParams {
+ fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
+ encoder.set_bytes(
+ position,
+ core::mem::size_of::<MLXFastAttentionParams>() as u64,
+ &data as *const MLXFastAttentionParams as *const c_void,
+ );
+ }
+ }
+
+ set_params!(
+ encoder,
+ (
+ (q_buffer, q_offset),
+ (k_buffer, k_offset),
+ (v_buffer, v_offset),
+ output,
+ params,
+ &batch_shape[..],
+ &batch_strides[..]
+ )
+ );
+
+ let grid_dims = MTLSize {
+ width: 1,
+ height: tm as u64,
+ depth: bs_out as u64,
+ };
+ let group_dims = MTLSize {
+ width: 32,
+ height: WM as u64,
+ depth: WN as u64,
+ };
+ encoder.use_resource(q_buffer, metal::MTLResourceUsage::Read);
+ encoder.use_resource(k_buffer, metal::MTLResourceUsage::Read);
+ encoder.use_resource(v_buffer, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
+ encoder.dispatch_thread_groups(grid_dims, group_dims);
+ Ok(())
+}
+
+/// SDPA full is supported when:
+/// - q head dim == 64, 96, 128
+/// - no mask
+/// - q,k,v are contiguous
+#[allow(clippy::too_many_arguments)]
+pub fn call_sdpa_vector(
+ device: &Device,
+ ep: impl EncoderProvider,
+ kernels: &Kernels,
+ q_offset: usize,
+ q_shape: &[usize],
+ q_buffer: &Buffer,
+ k_offset: usize,
+ k_shape: &[usize],
+ k_stride: &[usize],
+ k_buffer: &Buffer,
+ v_offset: usize,
+ v_stride: &[usize],
+ v_buffer: &Buffer,
+ output: &Buffer,
+ alpha: f32,
+ softcapping: f32,
+ itype: SdpaDType,
+) -> Result<(), MetalKernelError> {
+ let bk = q_shape.last().unwrap();
+
+ let gqa_factor = (q_shape[1] / k_shape[1]) as i32;
+ let n = k_shape[2] as i32;
+ let b = (q_shape[0] * q_shape[1]) as i32;
+ let kstride = k_stride[1];
+ let vstride = v_stride[1];
+
+ let name = match (bk, itype) {
+ (32, SdpaDType::F16) => "sdpa_vector_float16_t_32",
+ (64, SdpaDType::F16) => "sdpa_vector_float16_t_64",
+ (96, SdpaDType::F16) => "sdpa_vector_float16_t_96",
+ (128, SdpaDType::F16) => "sdpa_vector_float16_t_128",
+ (256, SdpaDType::F16) => "sdpa_vector_float16_t_256",
+ (32, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_32",
+ (64, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_64",
+ (96, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_96",
+ (128, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_128",
+ (256, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_256",
+ (32, SdpaDType::F32) => "sdpa_vector_float_32",
+ (64, SdpaDType::F32) => "sdpa_vector_float_64",
+ (96, SdpaDType::F32) => "sdpa_vector_float_96",
+ (128, SdpaDType::F32) => "sdpa_vector_float_128",
+ (256, SdpaDType::F32) => "sdpa_vector_float_256",
+ (other, _) => {
+ return Err(MetalKernelError::SdpaHeadSizeMismatch {
+ variation: "vector",
+ got: *other,
+ expected: vec![32, 64, 96, 128, 256],
+ })
+ }
+ };
+
+ let alpha = if softcapping != 1. {
+ alpha / softcapping
+ } else {
+ alpha
+ };
+
+ let pipeline = kernels.load_pipeline(device, Source::Sdpa, &name)?;
+ let encoder = ep.encoder();
+ let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
+ encoder.set_compute_pipeline_state(&pipeline);
+
+ // q = (bs, qhead, seq, hidden)
+ // k/v = (bs, kv_head, kv_seq, hidden)
+
+ set_params!(
+ encoder,
+ (
+ (q_buffer, q_offset),
+ (k_buffer, k_offset),
+ (v_buffer, v_offset),
+ output,
+ gqa_factor,
+ n,
+ kstride,
+ vstride,
+ alpha,
+ softcapping
+ )
+ );
+
+ let grid_dims = MTLSize {
+ width: 1,
+ height: b as u64,
+ depth: 1 as u64,
+ };
+ let group_dims = MTLSize {
+ width: 1024,
+ height: 1,
+ depth: 1,
+ };
+ encoder.use_resource(q_buffer, metal::MTLResourceUsage::Read);
+ encoder.use_resource(k_buffer, metal::MTLResourceUsage::Read);
+ encoder.use_resource(v_buffer, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
+ encoder.dispatch_thread_groups(grid_dims, group_dims);
+ Ok(())
+}
+
#[allow(clippy::too_many_arguments)]
pub fn call_im2col1d_strided(
device: &Device,