summaryrefslogtreecommitdiff
path: root/candle-nn
diff options
context:
space:
mode:
authorEric Buehler <65165915+EricLBuehler@users.noreply.github.com>2024-11-05 03:28:00 -0500
committerGitHub <noreply@github.com>2024-11-05 09:28:00 +0100
commite2b6b367fa852ed30ac532f8d77cd8479c7ed092 (patch)
tree41321e646a0ee9abef88122b202bd940240ecae6 /candle-nn
parent6454597943599dd6df787a0d5f2446c5724d850a (diff)
downloadcandle-e2b6b367fa852ed30ac532f8d77cd8479c7ed092.tar.gz
candle-e2b6b367fa852ed30ac532f8d77cd8479c7ed092.tar.bz2
candle-e2b6b367fa852ed30ac532f8d77cd8479c7ed092.zip
Add some fast Metal MLX SDPA kernels (#2584)
* Add some fast Metal MLX SDPA kernels (#32) * Sketch the sdpa kernel * Add full sdpa kernel, * Add test * Add vectorized kernel for decoding * Update tests * Add some docs * Fix sdpa_vector names * Add softcapping for vectorized sdpa * Add softcapping for full sdpa * Add support for head dim 32, 96, 256 * Add support for head dim 32, 96, 256 * Update docs * Add update notice * Clippy and format * Conditional compilation for bf16 * Use it in quantized llama * Some review comments * Use set_params! * Remove unused * Remove feature * Fix metal sdpa for v stride * Remove comma * Add the dim method to layout and shape. --------- Co-authored-by: Laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-nn')
-rw-r--r--candle-nn/src/ops.rs190
-rw-r--r--candle-nn/tests/sdpa.rs206
2 files changed, 396 insertions, 0 deletions
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index 8a3c19fe..0f35285d 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -964,3 +964,193 @@ impl Module for Identity {
Ok(xs.clone())
}
}
+
+#[allow(dead_code)] // the fields are only read by `metal_fwd`, which is compiled only with the `metal` feature
+struct Sdpa {
+ scale: f32, // forwarded to the Metal kernels; the `sdpa` docs describe it as applied before the softmax
+ softcapping: f32, // forwarded to the Metal kernels; the `sdpa` docs treat 1.0 as "no softcapping"
+}
+
+/// Fused scaled dot product attention as a three-input custom op over `(q, k, v)`.
+///
+/// Only the Metal backend is implemented; `cpu_fwd` always returns an error.
+impl candle::CustomOp3 for Sdpa {
+    fn name(&self) -> &'static str {
+        "metal-sdpa"
+    }
+
+    /// Always fails: there is no CPU implementation of this op.
+    fn cpu_fwd(
+        &self,
+        _s1: &CpuStorage,
+        _l1: &Layout,
+        _s2: &CpuStorage,
+        _l2: &Layout,
+        _s3: &CpuStorage,
+        _l3: &Layout,
+    ) -> Result<(CpuStorage, Shape)> {
+        candle::bail!("SDPA has no cpu impl")
+    }
+
+    /// Validates the `q`/`k`/`v` layouts and dtypes, then runs one of two Metal kernels:
+    /// - `call_sdpa_vector` when the query sequence length (dim 2 of `q`) is 1,
+    /// - `call_sdpa_full` when it is >= 2, which additionally requires `q` and `k` to have the
+    ///   same sequence length.
+    ///
+    /// Returns the output storage together with its shape
+    /// `(q.dim(0), q.dim(1), q.dim(2), v.dim(3))`.
+    ///
+    /// # Errors
+    /// Fails if an input has fewer than four dims, if the shapes are inconsistent, if the head
+    /// dim is not one of 32/64/96/128/256, if the dtypes differ or are not BF16/F16/F32, or if
+    /// the kernel call itself fails.
+    #[cfg(feature = "metal")]
+    fn metal_fwd(
+        &self,
+        q: &candle::MetalStorage,
+        q_l: &Layout,
+        k: &candle::MetalStorage,
+        k_l: &Layout,
+        v: &candle::MetalStorage,
+        v_l: &Layout,
+    ) -> Result<(candle::MetalStorage, Shape)> {
+        use candle::backend::BackendStorage;
+        use candle_metal_kernels::SdpaDType;
+
+        let device = q.device();
+
+        let out_dims = vec![q_l.dim(0)?, q_l.dim(1)?, q_l.dim(2)?, v_l.dim(3)?];
+        let elem_count: usize = out_dims.iter().product();
+
+        // q,k must have matching emb dim
+        let q_head = q_l.dim(D::Minus1)?;
+        if q_head != k_l.dim(D::Minus1)? {
+            candle::bail!("`q` and `k` last dims must match");
+        }
+
+        // k,v must have matching n kv heads
+        let n_kv_heads = k_l.dim(D::Minus(3))?;
+        if v_l.dim(D::Minus(3))? != n_kv_heads {
+            candle::bail!("`k` and `v` must have the same number of kv heads");
+        }
+
+        // k,v must have matching kv sequence length: each key position needs its value row, and
+        // the vector kernel is only given the dims of `k`.
+        if v_l.dim(2)? != k_l.dim(2)? {
+            candle::bail!("`k` and `v` sequence lengths must match");
+        }
+
+        // n_heads % n_kv_heads == 0; reject n_kv_heads == 0 first, the remainder would panic.
+        if n_kv_heads == 0 {
+            candle::bail!("`k` must have at least one kv head");
+        }
+        if q_l.dim(D::Minus(3))? % n_kv_heads != 0 {
+            candle::bail!("query `n_heads` must be a multiple of `n_kv_heads`");
+        }
+
+        if !matches!(q_head, 32 | 64 | 96 | 128 | 256) {
+            candle::bail!(
+                "Metal SDPA does not support q head dim {q_head}: q dims {:?}, k dims {:?}, v dims {:?}.",
+                q_l.dims(),
+                k_l.dims(),
+                v_l.dims()
+            );
+        }
+
+        // Kernel selection depends only on the query sequence length; the head dims of `q` and
+        // `k` are already known to be equal and supported at this point.
+        const SDPA_FULL_THRESHOLD: usize = 2;
+        let q_seq = q_l.dim(2)?;
+        let supports_sdpa_full = q_seq >= SDPA_FULL_THRESHOLD;
+        let supports_sdpa_vector = q_seq == 1;
+
+        if !(supports_sdpa_full || supports_sdpa_vector) {
+            candle::bail!(
+                "Metal SDPA does not support q dims {:?}, k dims {:?}, v dims {:?}.",
+                q_l.dims(),
+                k_l.dims(),
+                v_l.dims()
+            );
+        }
+
+        if k.dtype() != q.dtype() || v.dtype() != q.dtype() {
+            candle::bail!("all q, k, v dtypes must match.");
+        }
+
+        let itype = match q.dtype() {
+            DType::BF16 => SdpaDType::BF16,
+            DType::F16 => SdpaDType::F16,
+            DType::F32 => SdpaDType::F32,
+            other => candle::bail!("unsupported sdpa type {other:?}"),
+        };
+
+        // Allocate the output only once every check has passed.
+        let output = device.new_buffer(elem_count, q.dtype(), "sdpa_o")?;
+
+        // NOTE(review): only start offsets (plus the k/v strides on the vector path) are handed
+        // to the kernels, so they presumably assume contiguous `q` (and contiguous `k`/`v` on
+        // the full path) — confirm against the kernel sources.
+        let command_buffer = q.device().command_buffer()?;
+        if supports_sdpa_vector {
+            command_buffer.set_label("vector_attention");
+            candle_metal_kernels::call_sdpa_vector(
+                q.device().device(),
+                &command_buffer,
+                q.device().kernels(),
+                q_l.start_offset(),
+                q_l.dims(),
+                q.buffer(),
+                k_l.start_offset(),
+                k_l.dims(),
+                k_l.stride(),
+                k.buffer(),
+                v_l.start_offset(),
+                v_l.stride(),
+                v.buffer(),
+                &output,
+                self.scale,
+                self.softcapping,
+                itype,
+            )
+            .map_err(candle::Error::wrap)?;
+        } else if supports_sdpa_full {
+            if q_l.dim(2)? != k_l.dim(2)? {
+                candle::bail!(
+                    "query and key sequence length must be equal if using full metal sdpa"
+                )
+            }
+
+            command_buffer.set_label("full_attention");
+            candle_metal_kernels::call_sdpa_full(
+                q.device().device(),
+                &command_buffer,
+                q.device().kernels(),
+                q_l.start_offset(),
+                q_l.dims(),
+                q.buffer(),
+                k_l.start_offset(),
+                k.buffer(),
+                v_l.start_offset(),
+                v.buffer(),
+                &output,
+                self.scale,
+                self.softcapping,
+                itype,
+            )
+            .map_err(candle::Error::wrap)?;
+        } else {
+            // Unreachable after the check above; kept as a defensive error instead of a panic.
+            candle::bail!("must be vector or full sdpa kernel");
+        }
+
+        let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, q.dtype());
+        Ok((newstorage, Shape::from_dims(&out_dims)))
+    }
+}
+
+/// Scaled dot product attention with a fused kernel.
+///
+/// Computes softmax(qk^T*scale)v. Only a Metal implementation exists: on CPU the op fails with "SDPA has no cpu impl".
+///
+/// **Input shapes:**
+/// - `q`: (bs, qhead, seq, hidden)
+/// - `k`: (bs, kv_head, kv_seq, hidden)
+/// - `v`: (bs, kv_head, kv_seq, v_hidden)
+/// - `scale` is applied before softmax.
+/// - If `softcapping` != 1.0:
+///   - Computation is: softmax(tanh(qk^T*scale/cap)*cap)v
+///
+/// **Output shape:** (bs, qhead, seq, v_hidden)
+///
+/// **Supported head dims:** 32, 64, 96, 128, 256. All of `q`, `k`, `v` must share one dtype among BF16, F16, F32.
+///
+/// ## On Metal:
+/// - If `seq` == 1:
+///   - Use a vectorized kernel
+///   - Supports `seq` != `kv_seq` (cross attn. support)
+///   - Supports GQA when `qhead` is a multiple of `kv_head`
+/// - Otherwise:
+///   - Use an alternate kernel
+///   - Requires `seq` == `kv_seq`
+///   - GQA is not supported (requires `qhead` == `kv_head`). NOTE(review): the shape checks in `metal_fwd` only require a multiple — confirm.
+pub fn sdpa(q: &Tensor, k: &Tensor, v: &Tensor, scale: f32, softcapping: f32) -> Result<Tensor> {
+ q.apply_op3_no_bwd(k, v, &Sdpa { scale, softcapping })
+}
diff --git a/candle-nn/tests/sdpa.rs b/candle-nn/tests/sdpa.rs
new file mode 100644
index 00000000..67ad3816
--- /dev/null
+++ b/candle-nn/tests/sdpa.rs
@@ -0,0 +1,206 @@
+#[cfg(feature = "metal")]
+mod metal_sdpa_tests {
+ #[test]
+ fn sdpa_full() -> candle::Result<()> {
+ use candle::{DType, Device, Tensor};
+
+ // Query seqlen R = 4 (> 1), so `sdpa` takes the full-attention path rather than the vectorized one.
+ const BS: usize = 4;
+ const R: usize = 4;
+ const L: usize = 4;
+ const DK: usize = 64;
+ const H: usize = 3;
+ let scale: f64 = f64::from(DK as u32).sqrt().recip(); // 1 / sqrt(DK)
+
+ let device = Device::new_metal(0)?;
+
+ let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?;
+ let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
+ let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
+
+ let ground_truth = { // unfused reference: softmax((q * scale) k^T) v
+ let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?;
+ let att = candle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)?
+ .to_dtype(q.dtype())?;
+ att.matmul(&v.clone())?
+ };
+
+ let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?; // softcapping = 1.0, i.e. the plain softmax formula per the `sdpa` docs
+
+ assert_eq!(ground_truth.shape(), sdpa_output.shape());
+
+ // Sum over all elements of |truth - out| / |truth|. NOTE(review): a reference element near zero can inflate this — confirm the tolerance is not flaky.
+ let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)?
+ .sum_all()?
+ .to_scalar()?;
+
+ assert!(error <= 0.0005, "{}", error);
+
+ Ok(())
+ }
+
+ #[test]
+ fn sdpa_vector() -> candle::Result<()> {
+ use candle::{DType, Device, Tensor};
+
+ // Query seqlen R = 1, so `sdpa` takes the vectorized path; the key/value seqlen L is 1 as well.
+ const BS: usize = 4;
+ const R: usize = 1;
+ const L: usize = 1;
+ const DK: usize = 64;
+ const H: usize = 3;
+ let scale: f64 = f64::from(DK as u32).sqrt().recip(); // 1 / sqrt(DK)
+
+ let device = Device::new_metal(0)?;
+
+ let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?;
+ let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
+ let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
+
+ let ground_truth = { // unfused reference: softmax((q * scale) k^T) v
+ let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?;
+ let att = candle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)?
+ .to_dtype(q.dtype())?;
+ att.matmul(&v.clone())?
+ };
+
+ let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?; // softcapping = 1.0, i.e. the plain softmax formula per the `sdpa` docs
+
+ assert_eq!(ground_truth.shape(), sdpa_output.shape());
+
+ // Sum over all elements of |truth - out| / |truth|. NOTE(review): a reference element near zero can inflate this — confirm the tolerance is not flaky.
+ let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)?
+ .sum_all()?
+ .to_scalar()?;
+
+ assert!(error <= 0.0001, "{}", error);
+
+ Ok(())
+ }
+
+ #[test]
+ fn sdpa_full_softcapping() -> candle::Result<()> {
+ use candle::{DType, Device, Tensor};
+ use std::ops::{Div, Mul};
+
+ // Query seqlen R = 4 (> 1), so `sdpa` takes the full-attention path; softcapping is enabled (SOFTCAP != 1).
+ const BS: usize = 4;
+ const R: usize = 4;
+ const L: usize = 4;
+ const DK: usize = 64;
+ const H: usize = 3;
+ const SOFTCAP: f64 = 50.;
+ let scale: f64 = f64::from(DK as u32).sqrt().recip(); // 1 / sqrt(DK)
+
+ let device = Device::new_metal(0)?;
+
+ let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?;
+ let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
+ let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
+
+ let ground_truth = { // unfused reference: softmax(tanh((q * scale) k^T / SOFTCAP) * SOFTCAP) v
+ let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?;
+ let att = candle_nn::ops::softmax_last_dim(
+ &att.to_dtype(DType::F32)?
+ .div(SOFTCAP)?
+ .tanh()?
+ .mul(SOFTCAP)?,
+ )?
+ .to_dtype(q.dtype())?;
+ att.matmul(&v.clone())?
+ };
+
+ let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, SOFTCAP as f32)?;
+
+ assert_eq!(ground_truth.shape(), sdpa_output.shape());
+
+ // Sum over all elements of |truth - out| / |truth|. NOTE(review): a reference element near zero can inflate this — confirm the tolerance is not flaky.
+ let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)?
+ .sum_all()?
+ .to_scalar()?;
+
+ assert!(error <= 0.0004, "{}", error);
+
+ Ok(())
+ }
+
+ #[test]
+ fn sdpa_vector_softcapping() -> candle::Result<()> {
+ use candle::{DType, Device, Tensor};
+ use std::ops::{Div, Mul};
+
+ // Query seqlen R = 1, so `sdpa` takes the vectorized path; softcapping is enabled (SOFTCAP != 1).
+ const BS: usize = 4;
+ const R: usize = 1;
+ const L: usize = 1;
+ const DK: usize = 64;
+ const H: usize = 3;
+ const SOFTCAP: f64 = 50.;
+ let scale: f64 = f64::from(DK as u32).sqrt().recip(); // 1 / sqrt(DK)
+
+ let device = Device::new_metal(0)?;
+
+ let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?;
+ let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
+ let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
+
+ let ground_truth = { // unfused reference: softmax(tanh((q * scale) k^T / SOFTCAP) * SOFTCAP) v
+ let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?;
+ let att = candle_nn::ops::softmax_last_dim(
+ &att.to_dtype(DType::F32)?
+ .div(SOFTCAP)?
+ .tanh()?
+ .mul(SOFTCAP)?,
+ )?
+ .to_dtype(q.dtype())?;
+ att.matmul(&v.clone())?
+ };
+
+ let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, SOFTCAP as f32)?;
+
+ assert_eq!(ground_truth.shape(), sdpa_output.shape());
+
+ // Sum over all elements of |truth - out| / |truth|. NOTE(review): a reference element near zero can inflate this — confirm the tolerance is not flaky.
+ let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)?
+ .sum_all()?
+ .to_scalar()?;
+
+ assert!(error <= 0.0001, "{}", error);
+
+ Ok(())
+ }
+
+ #[test]
+ fn sdpa_vector_cross() -> candle::Result<()> {
+ use candle::{DType, Device, Tensor};
+
+ // Query seqlen R = 1, so `sdpa` takes the vectorized path. Simulates the cross-attention case where R != L (L = 24).
+ const BS: usize = 4;
+ const R: usize = 1;
+ const L: usize = 24;
+ const DK: usize = 64;
+ const H: usize = 3;
+ let scale: f64 = f64::from(DK as u32).sqrt().recip(); // 1 / sqrt(DK)
+
+ let device = Device::new_metal(0)?;
+
+ let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?;
+ let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
+ let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
+
+ let ground_truth = { // unfused reference: softmax((q * scale) k^T) v
+ let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?;
+ let att = candle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)?
+ .to_dtype(q.dtype())?;
+ att.matmul(&v.clone())?
+ };
+
+ let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?; // softcapping = 1.0, i.e. the plain softmax formula per the `sdpa` docs
+
+ assert_eq!(ground_truth.shape(), sdpa_output.shape());
+
+ // Sum over all elements of |truth - out| / |truth|. NOTE(review): a reference element near zero can inflate this — confirm the tolerance is not flaky.
+ let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)?
+ .sum_all()?
+ .to_scalar()?;
+
+ assert!(error <= 0.0013, "{}", error);
+
+ Ok(())
+ }
+}