path: root/candle-nn/tests/sdpa.rs
author     Eric Buehler <65165915+EricLBuehler@users.noreply.github.com>   2024-11-05 03:28:00 -0500
committer  GitHub <noreply@github.com>   2024-11-05 09:28:00 +0100
commit     e2b6b367fa852ed30ac532f8d77cd8479c7ed092 (patch)
tree       41321e646a0ee9abef88122b202bd940240ecae6 /candle-nn/tests/sdpa.rs
parent     6454597943599dd6df787a0d5f2446c5724d850a (diff)
download   candle-e2b6b367fa852ed30ac532f8d77cd8479c7ed092.tar.gz
           candle-e2b6b367fa852ed30ac532f8d77cd8479c7ed092.tar.bz2
           candle-e2b6b367fa852ed30ac532f8d77cd8479c7ed092.zip
Add some fast Metal MLX SDPA kernels (#2584)
* Add some fast Metal MLX SDPA kernels (#32)
* Sketch the sdpa kernel
* Add full sdpa kernel,
* Add test
* Add vectorized kernel for decoding
* Update tests
* Add some docs
* Fix sdpa_vector names
* Add softcapping for vectorized sdpa
* Add softcapping for full sdpa
* Add support for head dim 32, 96, 256
* Add support for head dim 32, 96, 256
* Update docs
* Add update notice
* Clippy and format
* Conditional compilation for bf16
* Use it in quantized llama
* Some review comments
* Use set_params!
* Remove unused
* Remove feature
* Fix metal sdpa for v stride
* Remove comma
* Add the dim method to layout and shape.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
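For orientation, a minimal sketch of how the new op is invoked (it mirrors the tests added below; the shapes, device index, and helper name are illustrative assumptions, and a softcapping value of 1.0 leaves the logits uncapped):

    use candle::{Device, Tensor};

    fn sdpa_example() -> candle::Result<()> {
        let device = Device::new_metal(0)?;
        // q: (batch, heads, seq_q, head_dim); k/v: (batch, heads, seq_kv, head_dim)
        let q = Tensor::randn(0f32, 1f32, (1, 8, 1, 64), &device)?;
        let k = Tensor::randn(0f32, 1f32, (1, 8, 24, 64), &device)?;
        let v = Tensor::randn(0f32, 1f32, (1, 8, 24, 64), &device)?;
        let scale = 1f32 / 64f32.sqrt();
        // softcapping = 1.0 disables capping; seq_q == 1 takes the vectorized decode path.
        let out = candle_nn::ops::sdpa(&q, &k, &v, scale, 1.0)?;
        assert_eq!(out.shape(), q.shape());
        Ok(())
    }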
Diffstat (limited to 'candle-nn/tests/sdpa.rs')
-rw-r--r--   candle-nn/tests/sdpa.rs   206
1 files changed, 206 insertions, 0 deletions
diff --git a/candle-nn/tests/sdpa.rs b/candle-nn/tests/sdpa.rs
new file mode 100644
index 00000000..67ad3816
--- /dev/null
+++ b/candle-nn/tests/sdpa.rs
@@ -0,0 +1,206 @@
+#[cfg(feature = "metal")]
+mod metal_sdpa_tests {
+ #[test]
+ fn sdpa_full() -> candle::Result<()> {
+ use candle::{DType, Device, Tensor};
+
+ // Use seqlen > 1 so the full (non-vectorized) kernel is exercised
+ const BS: usize = 4;
+ const R: usize = 4;
+ const L: usize = 4;
+ const DK: usize = 64;
+ const H: usize = 3;
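+ // Standard attention scaling: 1 / sqrt(head_dim).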
+ let scale: f64 = f64::from(DK as u32).sqrt().recip();
+
+ let device = Device::new_metal(0)?;
+
+ let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?;
+ let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
+ let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
+
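+ // Reference result: softmax(scale * q k^T) matmul v, with the softmax computed in f32.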
+ let ground_truth = {
+ let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?;
+ let att = candle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)?
+ .to_dtype(q.dtype())?;
+ att.matmul(&v.clone())?
+ };
+
+ let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?;
+
+ assert_eq!(ground_truth.shape(), sdpa_output.shape());
+
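+ // Error metric: element-wise relative error |ground_truth - sdpa_output| / |ground_truth|, summed over all elements.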
+ let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)?
+ .sum_all()?
+ .to_scalar()?;
+
+ assert!(error <= 0.0005, "{}", error);
+
+ Ok(())
+ }
+
+ #[test]
+ fn sdpa_vector() -> candle::Result<()> {
+ use candle::{DType, Device, Tensor};
+
+ // Allow vectorized, seqlen = 1
+ const BS: usize = 4;
+ const R: usize = 1;
+ const L: usize = 1;
+ const DK: usize = 64;
+ const H: usize = 3;
+ let scale: f64 = f64::from(DK as u32).sqrt().recip();
+
+ let device = Device::new_metal(0)?;
+
+ let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?;
+ let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
+ let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
+
+ let ground_truth = {
+ let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?;
+ let att = candle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)?
+ .to_dtype(q.dtype())?;
+ att.matmul(&v.clone())?
+ };
+
+ let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?;
+
+ assert_eq!(ground_truth.shape(), sdpa_output.shape());
+
+ let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)?
+ .sum_all()?
+ .to_scalar()?;
+
+ assert!(error <= 0.0001, "{}", error);
+
+ Ok(())
+ }
+
+ #[test]
+ fn sdpa_full_softcapping() -> candle::Result<()> {
+ use candle::{DType, Device, Tensor};
+ use std::ops::{Div, Mul};
+
+ // Full (non-vectorized) kernel with softcapping, seqlen > 1
+ const BS: usize = 4;
+ const R: usize = 4;
+ const L: usize = 4;
+ const DK: usize = 64;
+ const H: usize = 3;
+ const SOFTCAP: f64 = 50.;
+ let scale: f64 = f64::from(DK as u32).sqrt().recip();
+
+ let device = Device::new_metal(0)?;
+
+ let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?;
+ let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
+ let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
+
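+ // Softcapping: the attention logits are squashed via SOFTCAP * tanh(logits / SOFTCAP) before the softmax.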
+ let ground_truth = {
+ let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?;
+ let att = candle_nn::ops::softmax_last_dim(
+ &att.to_dtype(DType::F32)?
+ .div(SOFTCAP)?
+ .tanh()?
+ .mul(SOFTCAP)?,
+ )?
+ .to_dtype(q.dtype())?;
+ att.matmul(&v.clone())?
+ };
+
+ let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, SOFTCAP as f32)?;
+
+ assert_eq!(ground_truth.shape(), sdpa_output.shape());
+
+ let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)?
+ .sum_all()?
+ .to_scalar()?;
+
+ assert!(error <= 0.0004, "{}", error);
+
+ Ok(())
+ }
+
+ #[test]
+ fn sdpa_vector_softcapping() -> candle::Result<()> {
+ use candle::{DType, Device, Tensor};
+ use std::ops::{Div, Mul};
+
+ // Allow vectorized, seqlen = 1, with softcapping
+ const BS: usize = 4;
+ const R: usize = 1;
+ const L: usize = 1;
+ const DK: usize = 64;
+ const H: usize = 3;
+ const SOFTCAP: f64 = 50.;
+ let scale: f64 = f64::from(DK as u32).sqrt().recip();
+
+ let device = Device::new_metal(0)?;
+
+ let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?;
+ let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
+ let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
+
+ let ground_truth = {
+ let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?;
+ let att = candle_nn::ops::softmax_last_dim(
+ &att.to_dtype(DType::F32)?
+ .div(SOFTCAP)?
+ .tanh()?
+ .mul(SOFTCAP)?,
+ )?
+ .to_dtype(q.dtype())?;
+ att.matmul(&v.clone())?
+ };
+
+ let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, SOFTCAP as f32)?;
+
+ assert_eq!(ground_truth.shape(), sdpa_output.shape());
+
+ let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)?
+ .sum_all()?
+ .to_scalar()?;
+
+ assert!(error <= 0.0001, "{}", error);
+
+ Ok(())
+ }
+
+ #[test]
+ fn sdpa_vector_cross() -> candle::Result<()> {
+ use candle::{DType, Device, Tensor};
+
+ // Allow vectorized, seqlen = 1. Simulate the cross-attention case where R != L, R = 1
+ const BS: usize = 4;
+ const R: usize = 1;
+ const L: usize = 24;
+ const DK: usize = 64;
+ const H: usize = 3;
+ let scale: f64 = f64::from(DK as u32).sqrt().recip();
+
+ let device = Device::new_metal(0)?;
+
+ let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?;
+ let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
+ let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
+
+ let ground_truth = {
+ let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?;
+ let att = candle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)?
+ .to_dtype(q.dtype())?;
+ att.matmul(&v.clone())?
+ };
+
+ let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?;
+
+ assert_eq!(ground_truth.shape(), sdpa_output.shape());
+
+ let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)?
+ .sum_all()?
+ .to_scalar()?;
+
+ assert!(error <= 0.0013, "{}", error);
+
+ Ok(())
+ }
+}