author | Eric Buehler <65165915+EricLBuehler@users.noreply.github.com> | 2024-11-05 03:28:00 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-11-05 09:28:00 +0100 |
commit | e2b6b367fa852ed30ac532f8d77cd8479c7ed092 (patch) | |
tree | 41321e646a0ee9abef88122b202bd940240ecae6 /candle-nn/tests/sdpa.rs | |
parent | 6454597943599dd6df787a0d5f2446c5724d850a (diff) | |
download | candle-e2b6b367fa852ed30ac532f8d77cd8479c7ed092.tar.gz candle-e2b6b367fa852ed30ac532f8d77cd8479c7ed092.tar.bz2 candle-e2b6b367fa852ed30ac532f8d77cd8479c7ed092.zip |
Add some fast Metal MLX SDPA kernels (#2584)
* Add some fast Metal MLX SDPA kernels (#32)
* Sketch the sdpa kernel
* Add full sdpa kernel
* Add test
* Add vectorized kernel for decoding
* Update tests
* Add some docs
* Fix sdpa_vector names
* Add softcapping for vectorized sdpa
* Add softcapping for full sdpa (the capped-attention formula is sketched after this commit message)
* Add support for head dim 32, 96, 256
* Update docs
* Add update notice
* Clippy and format
* Conditional compilation for bf16
* Use it in quantized llama
* Some review comments
* Use set_params!
* Remove unused
* Remove feature
* Fix metal sdpa for v stride
* Remove comma
* Add the dim method to layout and shape.
---------
Co-authored-by: Laurent <laurent.mazare@gmail.com>
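For reference, the softcapped attention that the new tests check against (see the ground-truth expressions in sdpa_full_softcapping and sdpa_vector_softcapping below) can be written as:

$$ O = \mathrm{softmax}\!\left(c \cdot \tanh\!\left(\frac{s \, Q K^{\top}}{c}\right)\right) V, \qquad s = \frac{1}{\sqrt{d_k}} $$

Here c is the softcapping value (50.0 in the tests). The uncapped tests pass 1.0 for this argument, which appears to act as the "no capping" setting rather than an actual cap of 1.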
Diffstat (limited to 'candle-nn/tests/sdpa.rs')
-rw-r--r-- | candle-nn/tests/sdpa.rs | 206 |
1 file changed, 206 insertions, 0 deletions
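As a quick orientation before the diff: a minimal caller sketch based on the tests added in this file. The candle_nn::ops::sdpa signature, the (batch, heads, seqlen, head_dim) layout, and the use of 1.0 to disable softcapping are taken from those tests; the device index and the concrete dimensions here are illustrative assumptions, not part of this commit.

// Minimal sketch, not from this commit: shapes and the sdpa call mirror the tests below.
fn sdpa_decode_example() -> candle::Result<()> {
    use candle::{Device, Tensor};

    let device = Device::new_metal(0)?;
    // (batch, heads, seqlen, head_dim); seqlen = 1 is the vectorized decode case.
    let (bs, h, seq, dk): (usize, usize, usize, usize) = (1, 8, 1, 64);
    let q = Tensor::randn(0f32, 1f32, (bs, h, seq, dk), &device)?;
    let k = Tensor::randn(0f32, 1f32, (bs, h, seq, dk), &device)?;
    let v = Tensor::randn(0f32, 1f32, (bs, h, seq, dk), &device)?;
    let scale = (dk as f32).sqrt().recip();
    // 1.0 for the softcapping argument matches what the uncapped tests pass.
    let out = candle_nn::ops::sdpa(&q, &k, &v, scale, 1.0)?;
    assert_eq!(out.shape(), q.shape());
    Ok(())
}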
diff --git a/candle-nn/tests/sdpa.rs b/candle-nn/tests/sdpa.rs
new file mode 100644
index 00000000..67ad3816
--- /dev/null
+++ b/candle-nn/tests/sdpa.rs
@@ -0,0 +1,206 @@
+#[cfg(feature = "metal")]
+mod metal_sdpa_tests {
+    #[test]
+    fn sdpa_full() -> candle::Result<()> {
+        use candle::{DType, Device, Tensor};
+
+        // Full (non-vectorized) kernel: seqlen > 1
+        const BS: usize = 4;
+        const R: usize = 4;
+        const L: usize = 4;
+        const DK: usize = 64;
+        const H: usize = 3;
+        let scale: f64 = f64::from(DK as u32).sqrt().recip();
+
+        let device = Device::new_metal(0)?;
+
+        let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?;
+        let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
+        let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
+
+        let ground_truth = {
+            let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?;
+            let att = candle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)?
+                .to_dtype(q.dtype())?;
+            att.matmul(&v.clone())?
+        };
+
+        let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?;
+
+        assert_eq!(ground_truth.shape(), sdpa_output.shape());
+
+        let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)?
+            .sum_all()?
+            .to_scalar()?;
+
+        assert!(error <= 0.0005, "{}", error);
+
+        Ok(())
+    }
+
+    #[test]
+    fn sdpa_vector() -> candle::Result<()> {
+        use candle::{DType, Device, Tensor};
+
+        // Allow vectorized, seqlen = 1
+        const BS: usize = 4;
+        const R: usize = 1;
+        const L: usize = 1;
+        const DK: usize = 64;
+        const H: usize = 3;
+        let scale: f64 = f64::from(DK as u32).sqrt().recip();
+
+        let device = Device::new_metal(0)?;
+
+        let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?;
+        let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
+        let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
+
+        let ground_truth = {
+            let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?;
+            let att = candle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)?
+                .to_dtype(q.dtype())?;
+            att.matmul(&v.clone())?
+        };
+
+        let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?;
+
+        assert_eq!(ground_truth.shape(), sdpa_output.shape());
+
+        let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)?
+            .sum_all()?
+            .to_scalar()?;
+
+        assert!(error <= 0.0001, "{}", error);
+
+        Ok(())
+    }
+
+    #[test]
+    fn sdpa_full_softcapping() -> candle::Result<()> {
+        use candle::{DType, Device, Tensor};
+        use std::ops::{Div, Mul};
+
+        // Full (non-vectorized) kernel with softcapping: seqlen > 1
+        const BS: usize = 4;
+        const R: usize = 4;
+        const L: usize = 4;
+        const DK: usize = 64;
+        const H: usize = 3;
+        const SOFTCAP: f64 = 50.;
+        let scale: f64 = f64::from(DK as u32).sqrt().recip();
+
+        let device = Device::new_metal(0)?;
+
+        let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?;
+        let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
+        let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
+
+        let ground_truth = {
+            let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?;
+            let att = candle_nn::ops::softmax_last_dim(
+                &att.to_dtype(DType::F32)?
+                    .div(SOFTCAP)?
+                    .tanh()?
+                    .mul(SOFTCAP)?,
+            )?
+            .to_dtype(q.dtype())?;
+            att.matmul(&v.clone())?
+        };
+
+        let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, SOFTCAP as f32)?;
+
+        assert_eq!(ground_truth.shape(), sdpa_output.shape());
+
+        let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)?
+            .sum_all()?
+            .to_scalar()?;
+
+        assert!(error <= 0.0004, "{}", error);
+
+        Ok(())
+    }
+
+    #[test]
+    fn sdpa_vector_softcapping() -> candle::Result<()> {
+        use candle::{DType, Device, Tensor};
+        use std::ops::{Div, Mul};
+
+        // Allow vectorized, seqlen = 1
+        const BS: usize = 4;
+        const R: usize = 1;
+        const L: usize = 1;
+        const DK: usize = 64;
+        const H: usize = 3;
+        const SOFTCAP: f64 = 50.;
+        let scale: f64 = f64::from(DK as u32).sqrt().recip();
+
+        let device = Device::new_metal(0)?;
+
+        let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?;
+        let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
+        let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
+
+        let ground_truth = {
+            let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?;
+            let att = candle_nn::ops::softmax_last_dim(
+                &att.to_dtype(DType::F32)?
+                    .div(SOFTCAP)?
+                    .tanh()?
+                    .mul(SOFTCAP)?,
+            )?
+            .to_dtype(q.dtype())?;
+            att.matmul(&v.clone())?
+        };
+
+        let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, SOFTCAP as f32)?;
+
+        assert_eq!(ground_truth.shape(), sdpa_output.shape());
+
+        let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)?
+            .sum_all()?
+            .to_scalar()?;
+
+        assert!(error <= 0.0001, "{}", error);
+
+        Ok(())
+    }
+
+    #[test]
+    fn sdpa_vector_cross() -> candle::Result<()> {
+        use candle::{DType, Device, Tensor};
+
+        // Allow vectorized, seqlen = 1. Simulate the cross-attention case where R != L, R = 1
+        const BS: usize = 4;
+        const R: usize = 1;
+        const L: usize = 24;
+        const DK: usize = 64;
+        const H: usize = 3;
+        let scale: f64 = f64::from(DK as u32).sqrt().recip();
+
+        let device = Device::new_metal(0)?;
+
+        let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?;
+        let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
+        let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
+
+        let ground_truth = {
+            let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?;
+            let att = candle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)?
+                .to_dtype(q.dtype())?;
+            att.matmul(&v.clone())?
+        };
+
+        let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?;
+
+        assert_eq!(ground_truth.shape(), sdpa_output.shape());
+
+        let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)?
+            .sum_all()?
+            .to_scalar()?;
+
+        assert!(error <= 0.0013, "{}", error);
+
+        Ok(())
+    }
+}
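Note that the whole module is gated behind #[cfg(feature = "metal")], so these tests are compiled and run only when the metal feature is enabled for candle-nn (for example by passing --features metal to cargo test; the exact invocation is not part of this commit).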