| Field | Value | Date |
|---|---|---|
| author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-26 20:56:00 +0100 |
| committer | GitHub <noreply@github.com> | 2023-07-26 20:56:00 +0100 |
| commit | 4f92420132d831e5d344f974c263c9f341e50906 (patch) | |
| tree | a43610e4869d0d814248e6cbcdd4242dc85eadf9 /candle-flash-attn | |
| parent | ded197497c48485aa6b00c45318db7cf7e7cdf96 (diff) | |
Add some flash attn test (#253)
* Add some flash-attn test.
* Add the cpu test.
* Fail when the head is not a multiple of 8.
* Polish the flash attention test.
Diffstat (limited to 'candle-flash-attn')
| Mode | File | Lines changed |
|---|---|---|
| -rw-r--r-- | candle-flash-attn/Cargo.toml | 3 |
| -rw-r--r-- | candle-flash-attn/build.rs | 8 |
| -rw-r--r-- | candle-flash-attn/src/lib.rs | 34 |
| -rw-r--r-- | candle-flash-attn/tests/flash_attn_tests.rs | 90 |
4 files changed, 123 insertions, 12 deletions
diff --git a/candle-flash-attn/Cargo.toml b/candle-flash-attn/Cargo.toml
index 9d21cf4a..013da854 100644
--- a/candle-flash-attn/Cargo.toml
+++ b/candle-flash-attn/Cargo.toml
@@ -18,3 +18,6 @@ half = { version = "2.3.1", features = ["num-traits"] }
 anyhow = { version = "1", features = ["backtrace"] }
 num_cpus = "1.15.0"
 rayon = "1.7.0"
+
+[dev-dependencies]
+anyhow = { version = "1", features = ["backtrace"] }
diff --git a/candle-flash-attn/build.rs b/candle-flash-attn/build.rs
index 7a4588a4..d52ab92f 100644
--- a/candle-flash-attn/build.rs
+++ b/candle-flash-attn/build.rs
@@ -6,7 +6,7 @@ use rayon::prelude::*;
 use std::path::PathBuf;
 use std::str::FromStr;
 
-const KERNEL_FILES: [&'static str; 9] = [
+const KERNEL_FILES: [&str; 9] = [
     "flash_api.cu",
     "flash_fwd_hdim128_fp16_sm80.cu",
     "flash_fwd_hdim160_fp16_sm80.cu",
@@ -52,7 +52,11 @@ fn main() -> Result<()> {
     println!("cargo:rerun-if-changed=kernels/static_switch.h");
     let out_dir = PathBuf::from(std::env::var("OUT_DIR").context("OUT_DIR not set")?);
     let build_dir = match std::env::var("CANDLE_FLASH_ATTN_BUILD_DIR") {
-        Err(_) => out_dir.clone(),
+        Err(_) =>
+        {
+            #[allow(clippy::redundant_clone)]
+            out_dir.clone()
+        }
         Ok(build_dir) => PathBuf::from(build_dir),
     };
     set_cuda_include_dir()?;
diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs
index 0123543b..efdefee9 100644
--- a/candle-flash-attn/src/lib.rs
+++ b/candle-flash-attn/src/lib.rs
@@ -6,7 +6,7 @@ use candle::cuda_backend::WrapErr;
 use candle::{CpuStorage, Layout, Result, Shape, Tensor};
 use half::f16;
 
-pub struct FlashHdim32Sm80 {
+pub struct FlashAttn {
     pub softmax_scale: f32,
     pub causal: bool,
 }
@@ -15,7 +15,7 @@ fn round_multiple(x: usize, m: usize) -> usize {
     (x + m - 1) / m * m
 }
 
-impl candle::CustomOp3 for FlashHdim32Sm80 {
+impl candle::CustomOp3 for FlashAttn {
     fn name(&self) -> &'static str {
         "flash-hdim32-sm80"
     }
@@ -87,6 +87,10 @@ impl candle::CustomOp3 for FlashHdim32Sm80 {
         if head_size_og > 256 {
             candle::bail!("only supports head dimension at most 256 (got {head_size_og})")
         }
+        if head_size_og % 8 != 0 {
+            // TODO: Handle head sizes that are not a multiple of 8 via some padding.
+            candle::bail!("only supports head sizes that are a multiple of 8 (got {head_size_og})")
+        }
         if num_heads % num_heads_k != 0 {
             candle::bail!("number of k/v heads {num_heads_k} must divide number of heads in query {num_heads}")
         }
@@ -145,6 +149,19 @@ impl candle::CustomOp3 for FlashHdim32Sm80 {
     }
 }
 
+/// Flash-attention v2 layer using flash-attention.
+///
+/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
+/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
+/// than q, the number of heads in k and v has to be divisible by the number of heads in q.
+///
+/// # Arguments
+///
+/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`.
+/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
+/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
+///
+/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`.
 pub fn flash_attn(
     q: &Tensor,
     k: &Tensor,
@@ -152,12 +169,9 @@ pub fn flash_attn(
     softmax_scale: f32,
     causal: bool,
 ) -> Result<Tensor> {
-    q.custom_op3(
-        k,
-        v,
-        FlashHdim32Sm80 {
-            softmax_scale,
-            causal,
-        },
-    )
+    let op = FlashAttn {
+        softmax_scale,
+        causal,
+    };
+    q.custom_op3(k, v, op)
 }
diff --git a/candle-flash-attn/tests/flash_attn_tests.rs b/candle-flash-attn/tests/flash_attn_tests.rs
new file mode 100644
index 00000000..c6780659
--- /dev/null
+++ b/candle-flash-attn/tests/flash_attn_tests.rs
@@ -0,0 +1,90 @@
+use anyhow::Result;
+use candle::{DType, Device, IndexOp, Tensor, D};
+
+fn to_vec3_round(t: Tensor, digits: i32) -> Result<Vec<Vec<Vec<f32>>>> {
+    let b = 10f32.powi(digits);
+    let t = t.to_vec3::<f32>()?;
+    let t = t
+        .iter()
+        .map(|t| {
+            t.iter()
+                .map(|t| t.iter().map(|t| f32::round(t * b) / b).collect())
+                .collect()
+        })
+        .collect();
+    Ok(t)
+}
+
+fn fa_acausal(q: &Tensor, k: &Tensor, v: &Tensor, softmax_scale: f32) -> Result<Tensor> {
+    let in_dtype = q.dtype();
+    let q = q.to_dtype(DType::F32)?;
+    let k = k.to_dtype(DType::F32)?;
+    let v = v.to_dtype(DType::F32)?;
+    let att = (q.matmul(&k.t()?)? * softmax_scale as f64)?;
+    let att = att.softmax(D::Minus1)?;
+    // Convert to contiguous as matmul doesn't support strided vs for now.
+    let output = att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?;
+    Ok(output)
+}
+
+#[test]
+fn flash_attn_acausal() -> Result<()> {
+    let device = Device::new_cuda(0)?;
+    let q = Tensor::arange(0u32, 48, &device)?
+        .to_dtype(DType::F16)?
+        .reshape((1, 3, 2, 8))?;
+    let k = (&q / 40.)?;
+    let v = (&q / 50.)?;
+    let q = (&q / 30.)?;
+
+    let ys1 = fa_acausal(&q, &k, &v, 0.5)?;
+    let ys1 = ys1.i(0)?.to_dtype(DType::F32)?;
+    let ys2 = {
+        let q = q.transpose(1, 2)?;
+        let k = k.transpose(1, 2)?;
+        let v = v.transpose(1, 2)?;
+        candle_flash_attn::flash_attn(&q, &k, &v, 0.5, false)?.transpose(1, 2)?
+    };
+    let ys2 = ys2.i(0)?.to_dtype(DType::F32)?;
+    let diff = ys1.sub(&ys2)?.abs()?.flatten_all()?.max(0)?;
+
+    assert_eq!(ys1.dims(), &[3, 2, 8]);
+    assert_eq!(
+        to_vec3_round(ys1, 4)?,
+        &[
+            [
+                [0.0837, 0.1038, 0.1238, 0.1438, 0.1637, 0.1837, 0.2037, 0.2238],
+                [0.0922, 0.1122, 0.1322, 0.1522, 0.1721, 0.1921, 0.2122, 0.2322]
+            ],
+            [
+                [0.4204, 0.4404, 0.4604, 0.4805, 0.5005, 0.5205, 0.5405, 0.5605],
+                [0.428, 0.448, 0.468, 0.488, 0.5083, 0.5283, 0.5483, 0.5684]
+            ],
+            [
+                [0.7554, 0.7754, 0.7954, 0.8154, 0.8354, 0.8555, 0.8755, 0.8955],
+                [0.7622, 0.7822, 0.8022, 0.8223, 0.8423, 0.8623, 0.8823, 0.9023]
+            ]
+        ]
+    );
+
+    assert_eq!(ys2.dims(), &[3, 2, 8]);
+    assert_eq!(
+        to_vec3_round(ys2, 4)?,
+        &[
+            [
+                [0.0837, 0.1038, 0.1238, 0.1438, 0.1637, 0.1837, 0.2037, 0.2238],
+                [0.0922, 0.1122, 0.1322, 0.1522, 0.1721, 0.1921, 0.2122, 0.2322]
+            ],
+            [
+                [0.4204, 0.4404, 0.4604, 0.4805, 0.5005, 0.5205, 0.5405, 0.5605],
+                [0.428, 0.448, 0.468, 0.488, 0.5083, 0.5283, 0.5483, 0.5684]
+            ],
+            [
+                [0.7554, 0.7754, 0.7954, 0.8154, 0.8354, 0.8555, 0.8755, 0.8955],
+                [0.7622, 0.7822, 0.8022, 0.8223, 0.8423, 0.8623, 0.8823, 0.9023]
+            ]
+        ]
+    );
+    assert!(diff.to_vec0::<f32>()?.abs() < 1e-5);
+    Ok(())
+}
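For a sense of how the API documented in the new `flash_attn` doc comment is meant to be used, here is a minimal usage sketch. It is not part of this commit: it assumes a CUDA-capable GPU and a binary crate that depends on `candle`, `candle-flash-attn`, and `anyhow`, and it mirrors the tensor construction used in the new test.

```rust
// Minimal usage sketch (assumed setup, not part of this commit).
// Per the new doc comment, the expected layouts are:
//   q: (batch, seq_len_q, num_heads_q, head_size)
//   k: (batch, seq_len_kv, num_heads_kv, head_size)
//   v: (batch, seq_len_kv, num_heads_kv, head_size)
// head_size must be a multiple of 8 and at most 256, and num_heads_q must be
// divisible by num_heads_kv (multi-query / grouped-query attention).
use candle::{DType, Device, Tensor};

fn main() -> anyhow::Result<()> {
    let device = Device::new_cuda(0)?;

    // (batch=1, seq_len=2, num_heads=3, head_size=8), mirroring the new test.
    let q = Tensor::arange(0u32, 48, &device)?
        .to_dtype(DType::F16)?
        .reshape((1, 2, 3, 8))?;
    let k = (&q / 40.)?;
    let v = (&q / 50.)?;

    let softmax_scale = 1f32 / (8f32).sqrt();
    let out = candle_flash_attn::flash_attn(&q, &k, &v, softmax_scale, false)?;
    assert_eq!(out.dims(), &[1, 2, 3, 8]);

    // With this commit, a head size that is not a multiple of 8 is rejected.
    let q_bad = Tensor::arange(0u32, 72, &device)?
        .to_dtype(DType::F16)?
        .reshape((1, 2, 3, 12))?;
    assert!(candle_flash_attn::flash_attn(&q_bad, &q_bad, &q_bad, softmax_scale, false).is_err());
    Ok(())
}
```

Passing `true` for the `causal` flag enables the causal mask used for autoregressive decoding; the test added in this commit exercises the acausal path only.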