diff options
author | Eric Buehler <65165915+EricLBuehler@users.noreply.github.com> | 2024-11-05 03:28:00 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-11-05 09:28:00 +0100 |
commit | e2b6b367fa852ed30ac532f8d77cd8479c7ed092 (patch) | |
tree | 41321e646a0ee9abef88122b202bd940240ecae6 /candle-nn | |
parent | 6454597943599dd6df787a0d5f2446c5724d850a (diff) | |
download | candle-e2b6b367fa852ed30ac532f8d77cd8479c7ed092.tar.gz candle-e2b6b367fa852ed30ac532f8d77cd8479c7ed092.tar.bz2 candle-e2b6b367fa852ed30ac532f8d77cd8479c7ed092.zip |
Add some fast Metal MLX SDPA kernels (#2584)
* Add some fast Metal MLX SDPA kernels (#32)
* Sketch the sdpa kernel
* Add full sdpa kernel,
* Add test
* Add vectorized kernel for decoding
* Update tests
* Add some docs
* Fix sdpa_vector names
* Add softcapping for vectorized sdpa
* Add softcapping for full sdpa
* Add support for head dim 32, 96, 256
* Add support for head dim 32, 96, 256
* Update docs
* Add update notice
* Clippy and format
* Conditional compilation for bf16
* Use it in quantized llama
* Some review comments
* Use set_params!
* Remove unused
* Remove feature
* Fix metal sdpa for v stride
* Remove comma
* Add the dim method to layout and shape.
---------
Co-authored-by: Laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-nn')
-rw-r--r-- | candle-nn/src/ops.rs | 190 | ||||
-rw-r--r-- | candle-nn/tests/sdpa.rs | 206 |
2 files changed, 396 insertions, 0 deletions
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index 8a3c19fe..0f35285d 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -964,3 +964,193 @@ impl Module for Identity { Ok(xs.clone()) } } + +#[allow(dead_code)] +struct Sdpa { + scale: f32, + softcapping: f32, +} + +impl candle::CustomOp3 for Sdpa { + fn name(&self) -> &'static str { + "metal-sdpa" + } + + fn cpu_fwd( + &self, + _s1: &CpuStorage, + _l1: &Layout, + _s2: &CpuStorage, + _l2: &Layout, + _s3: &CpuStorage, + _l3: &Layout, + ) -> Result<(CpuStorage, Shape)> { + candle::bail!("SDPA has no cpu impl") + } + + #[cfg(feature = "metal")] + fn metal_fwd( + &self, + q: &candle::MetalStorage, + q_l: &Layout, + k: &candle::MetalStorage, + k_l: &Layout, + v: &candle::MetalStorage, + v_l: &Layout, + ) -> Result<(candle::MetalStorage, Shape)> { + use candle::backend::BackendStorage; + use candle_metal_kernels::SdpaDType; + + let device = q.device(); + + let out_dims = vec![q_l.dim(0)?, q_l.dim(1)?, q_l.dim(2)?, v_l.dim(3)?]; + let elem_count: usize = out_dims.iter().product(); + + let output = device.new_buffer(elem_count, q.dtype(), "sdpa_o")?; + + // q,k must have matching emb dim + if q_l.dim(D::Minus1)? != k_l.dim(D::Minus1)? { + candle::bail!("`q` and `k` last dims must match"); + } + + // k,v must have matching n kv heads + if v_l.dim(D::Minus(3))? != k_l.dim(D::Minus(3))? { + candle::bail!("`k` and `v` head dims must match"); + } + + // n_heads % n_kv_heads == 0; n_heads >= 1, n_kv_heads >= 1. + if q_l.dim(D::Minus(3))? % k_l.dim(D::Minus(3))? != 0 { + candle::bail!("query `n_heads` must be a multiple of `n_kv_heads`"); + } + + let k_head = k_l.dim(D::Minus1)?; + let q_head = q_l.dim(D::Minus1)?; + let q_seq = q_l.dim(2)?; + + let mut implementation_supports_use_case = q_head == k_head; + let supported_head_dim = + q_head == 32 || q_head == 64 || q_head == 96 || q_head == 128 || q_head == 256; + + const SDPA_FULL_THRESHOLD: usize = 2; + + let supports_sdpa_full = + q_seq >= SDPA_FULL_THRESHOLD && supported_head_dim && q_head == k_head; + let supports_sdpa_vector = q_seq == 1 && supported_head_dim; + + implementation_supports_use_case &= supports_sdpa_full || supports_sdpa_vector; + + if !supported_head_dim { + candle::bail!( + "Meta SDPA does not support q head dim {q_head}: q dims {:?}, k dims {:?}, v dims {:?}.", + q_l.dims(), + k_l.dims(), + v_l.dims() + ); + } + if !implementation_supports_use_case { + candle::bail!( + "Meta SDPA does not support q dims {:?}, k dims {:?}, v dims {:?}.", + q_l.dims(), + k_l.dims(), + v_l.dims() + ); + } + + for t in [k.dtype(), v.dtype()] { + if q.dtype() != t { + candle::bail!("all q, k, v dtypes must match."); + } + } + + let itype = match q.dtype() { + DType::BF16 => SdpaDType::BF16, + DType::F16 => SdpaDType::F16, + DType::F32 => SdpaDType::F32, + other => candle::bail!("unsupported sdpa type {other:?}"), + }; + + let command_buffer = q.device().command_buffer()?; + if supports_sdpa_vector { + command_buffer.set_label("vector_attention"); + candle_metal_kernels::call_sdpa_vector( + q.device().device(), + &command_buffer, + q.device().kernels(), + q_l.start_offset(), + q_l.dims(), + q.buffer(), + k_l.start_offset(), + k_l.dims(), + k_l.stride(), + k.buffer(), + v_l.start_offset(), + v_l.stride(), + v.buffer(), + &output, + self.scale, + self.softcapping, + itype, + ) + .map_err(candle::Error::wrap)?; + } else if supports_sdpa_full { + if q_l.dim(2)? != k_l.dim(2)? { + candle::bail!( + "query and key sequence length must be equal if using full metal sdpa" + ) + } + + command_buffer.set_label("full_attention"); + candle_metal_kernels::call_sdpa_full( + q.device().device(), + &command_buffer, + q.device().kernels(), + q_l.start_offset(), + q_l.dims(), + q.buffer(), + k_l.start_offset(), + k.buffer(), + v_l.start_offset(), + v.buffer(), + &output, + self.scale, + self.softcapping, + itype, + ) + .map_err(candle::Error::wrap)?; + } else { + candle::bail!("must be vector or full sdpa kernel"); + } + + let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, q.dtype()); + Ok((newstorage, Shape::from_dims(&out_dims))) + } +} + +/// Scaled dot product attention with a fused kernel. +/// +/// Computes softmax(qk^T*scale)v. +/// +/// **Inputs shapes:** +/// - `q`: (bs, qhead, seq, hidden) +/// - `k`: (bs, kv_head, kv_seq, hidden) +/// - `k`: (bs, kv_head, kv_seq, v_hidden) +/// - `scale` is applied before softmax. +/// - If `softcapping` != 1.0: +/// - Computation is: softmax(tanh(qk^T*scale/cap)*cap)v +/// +/// **Output shape:** (bs, qhead, seq, v_hidden) +/// +/// **Supported head dims:** 32, 64, 96, 128, 256. +/// +/// ## On Metal: +/// - If `seq` == 1: +/// - Use a vectorized kernel +/// - Supports `seq` != `kv_seq` (cross attn. support) +/// - Supports GQA when `qhead` is a multiple of `kv_head` +/// - Otherwise: +/// - Use an alternate kernel +/// - Requires `seq` == `kv_seq` +/// - GQA is not supported (requires `qhead` == `kv_head`) +pub fn sdpa(q: &Tensor, k: &Tensor, v: &Tensor, scale: f32, softcapping: f32) -> Result<Tensor> { + q.apply_op3_no_bwd(k, v, &Sdpa { scale, softcapping }) +} diff --git a/candle-nn/tests/sdpa.rs b/candle-nn/tests/sdpa.rs new file mode 100644 index 00000000..67ad3816 --- /dev/null +++ b/candle-nn/tests/sdpa.rs @@ -0,0 +1,206 @@ +#[cfg(feature = "metal")] +mod metal_sdpa_tests { + #[test] + fn sdpa_full() -> candle::Result<()> { + use candle::{DType, Device, Tensor}; + + // Force seqlen = 100 + const BS: usize = 4; + const R: usize = 4; + const L: usize = 4; + const DK: usize = 64; + const H: usize = 3; + let scale: f64 = f64::from(DK as u32).sqrt().recip(); + + let device = Device::new_metal(0)?; + + let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?; + let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?; + let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?; + + let ground_truth = { + let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?; + let att = candle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)? + .to_dtype(q.dtype())?; + att.matmul(&v.clone())? + }; + + let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?; + + assert_eq!(ground_truth.shape(), sdpa_output.shape()); + + let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)? + .sum_all()? + .to_scalar()?; + + assert!(error <= 0.0005, "{}", error); + + Ok(()) + } + + #[test] + fn sdpa_vector() -> candle::Result<()> { + use candle::{DType, Device, Tensor}; + + // Allow vectorized, seqlen = 1 + const BS: usize = 4; + const R: usize = 1; + const L: usize = 1; + const DK: usize = 64; + const H: usize = 3; + let scale: f64 = f64::from(DK as u32).sqrt().recip(); + + let device = Device::new_metal(0)?; + + let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?; + let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?; + let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?; + + let ground_truth = { + let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?; + let att = candle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)? + .to_dtype(q.dtype())?; + att.matmul(&v.clone())? + }; + + let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?; + + assert_eq!(ground_truth.shape(), sdpa_output.shape()); + + let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)? + .sum_all()? + .to_scalar()?; + + assert!(error <= 0.0001, "{}", error); + + Ok(()) + } + + #[test] + fn sdpa_full_softcapping() -> candle::Result<()> { + use candle::{DType, Device, Tensor}; + use std::ops::{Div, Mul}; + + // Allow vectorized, seqlen = 1 + const BS: usize = 4; + const R: usize = 4; + const L: usize = 4; + const DK: usize = 64; + const H: usize = 3; + const SOFTCAP: f64 = 50.; + let scale: f64 = f64::from(DK as u32).sqrt().recip(); + + let device = Device::new_metal(0)?; + + let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?; + let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?; + let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?; + + let ground_truth = { + let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?; + let att = candle_nn::ops::softmax_last_dim( + &att.to_dtype(DType::F32)? + .div(SOFTCAP)? + .tanh()? + .mul(SOFTCAP)?, + )? + .to_dtype(q.dtype())?; + att.matmul(&v.clone())? + }; + + let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, SOFTCAP as f32)?; + + assert_eq!(ground_truth.shape(), sdpa_output.shape()); + + let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)? + .sum_all()? + .to_scalar()?; + + assert!(error <= 0.0004, "{}", error); + + Ok(()) + } + + #[test] + fn sdpa_vector_softcapping() -> candle::Result<()> { + use candle::{DType, Device, Tensor}; + use std::ops::{Div, Mul}; + + // Allow vectorized, seqlen = 1 + const BS: usize = 4; + const R: usize = 1; + const L: usize = 1; + const DK: usize = 64; + const H: usize = 3; + const SOFTCAP: f64 = 50.; + let scale: f64 = f64::from(DK as u32).sqrt().recip(); + + let device = Device::new_metal(0)?; + + let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?; + let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?; + let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?; + + let ground_truth = { + let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?; + let att = candle_nn::ops::softmax_last_dim( + &att.to_dtype(DType::F32)? + .div(SOFTCAP)? + .tanh()? + .mul(SOFTCAP)?, + )? + .to_dtype(q.dtype())?; + att.matmul(&v.clone())? + }; + + let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, SOFTCAP as f32)?; + + assert_eq!(ground_truth.shape(), sdpa_output.shape()); + + let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)? + .sum_all()? + .to_scalar()?; + + assert!(error <= 0.0001, "{}", error); + + Ok(()) + } + + #[test] + fn sdpa_vector_cross() -> candle::Result<()> { + use candle::{DType, Device, Tensor}; + + // Allow vectorized, seqlen = 1. Simulat cross attention case where R != L, R = 1 + const BS: usize = 4; + const R: usize = 1; + const L: usize = 24; + const DK: usize = 64; + const H: usize = 3; + let scale: f64 = f64::from(DK as u32).sqrt().recip(); + + let device = Device::new_metal(0)?; + + let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?; + let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?; + let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?; + + let ground_truth = { + let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?; + let att = candle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)? + .to_dtype(q.dtype())?; + att.matmul(&v.clone())? + }; + + let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?; + + assert_eq!(ground_truth.shape(), sdpa_output.shape()); + + let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)? + .sum_all()? + .to_scalar()?; + + assert!(error <= 0.0013, "{}", error); + + Ok(()) + } +} |