author     Michael Feil <63565275+michaelfeil@users.noreply.github.com>   2024-12-31 09:41:23 +0100
committer  GitHub <noreply@github.com>                                    2024-12-31 09:41:23 +0100
commit     a594ef669ca5ed82c1f19d2230b4b3dc9cb46f43 (patch)
tree       8647429f4c0ae7fddbae84a1936819f0c0172514
parent     71cd6d55337b1541f602c1afffa6baf6dd75b09c (diff)
Flash-Attn upgrade / SoftCap Candle-FlashAttn [2/n] (#2689)
* update flash-attn v1
* restore: hdim224
* add 224 flash_fwd_template
* remove whitespace
* softcap is working, including test and api.
* make softcap test case better
---------
Co-authored-by: laurent <laurent.mazare@gmail.com>
-rw-r--r--  candle-flash-attn/kernels/flash_api.cu        16
-rw-r--r--  candle-flash-attn/src/ffi.rs                   2
-rw-r--r--  candle-flash-attn/src/lib.rs                 115
-rw-r--r--  candle-flash-attn/tests/flash_attn_tests.rs   52
4 files changed, 182 insertions, 3 deletions
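The math encoded by the kernel change below: when a positive `softcap` is passed, `run_mha` sets `params.softcap = softmax_scale / softcap` and `params.scale_softmax = softcap`, so the kernel effectively computes `softcap * tanh(softmax_scale * Q @ K^T / softcap)` before the softmax, bounding every attention logit to `(-softcap, softcap)`. A minimal CPU reference of that math in candle, mirroring the `fa_acausal_softcap` test helper from this commit (the function name below is illustrative and not part of the patch):

use candle::{Result, Tensor, D};

// Illustrative reference only: plain attention with Gemma-style logit softcapping,
// assuming q/k/v are f32 tensors of shape (batch, num_heads, seq_len, head_dim),
// i.e. the layout the test helper uses before transposing for the flash kernel.
fn softcap_attention_reference(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    softmax_scale: f32,
    softcap: f32,
) -> Result<Tensor> {
    // Raw attention logits, scaled as usual.
    let att = (q.matmul(&k.t()?)? * softmax_scale as f64)?;
    // Softcap: squash through tanh and rescale, bounding logits to (-softcap, softcap).
    let att = ((att / softcap as f64)?.tanh()? * softcap as f64)?;
    let att = candle_nn::ops::softmax(&att, D::Minus1)?;
    att.matmul(&v.contiguous()?)
}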
diff --git a/candle-flash-attn/kernels/flash_api.cu b/candle-flash-attn/kernels/flash_api.cu
index 4ca41b0a..00933419 100644
--- a/candle-flash-attn/kernels/flash_api.cu
+++ b/candle-flash-attn/kernels/flash_api.cu
@@ -55,7 +55,9 @@ extern "C" void run_mha(
     int is_causal,
 
     int window_size_left,
-    int window_size_right
+    int window_size_right,
+
+    float softcap
 ) {
     Flash_fwd_params params;
     // Reset the parameters
@@ -99,8 +101,16 @@ extern "C" void run_mha(
     params.d_rounded = d_rounded;
 
     // Set the different scale values.
-    params.scale_softmax = softmax_scale;
-    params.scale_softmax_log2 = softmax_scale * M_LOG2E;
+    if (softcap > 0.0) {
+        params.softcap = softmax_scale / softcap;
+        params.scale_softmax = softcap;
+        params.scale_softmax_log2 = softcap * M_LOG2E;
+    } else{
+        // Remove potential NaN
+        params.softcap = 0.0;
+        params.scale_softmax = softmax_scale;
+        params.scale_softmax_log2 = softmax_scale * M_LOG2E;
+    }
 
     params.p_dropout = 1.; // probability to keep
     params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
diff --git a/candle-flash-attn/src/ffi.rs b/candle-flash-attn/src/ffi.rs
index ca65520b..47e54e2a 100644
--- a/candle-flash-attn/src/ffi.rs
+++ b/candle-flash-attn/src/ffi.rs
@@ -45,6 +45,8 @@ extern "C" {
         window_size_left: c_int,
         window_size_right: c_int,
+
+        softcap: f32,
     );
 }
diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs
index f171a986..22a6f1d6 100644
--- a/candle-flash-attn/src/lib.rs
+++ b/candle-flash-attn/src/lib.rs
@@ -11,6 +11,7 @@ pub struct FlashAttn {
     pub alibi_slopes: Option<Tensor>,
     pub window_size_left: Option<usize>,
     pub window_size_right: Option<usize>,
+    pub softcap: Option<f32>,
 }
 
 fn round_multiple(x: usize, m: usize) -> usize {
@@ -201,6 +202,7 @@ impl FlashAttn {
                 /* is_causal */ is_causal,
                 /* window_size_left */ window_size_left,
                 /* window_size_right */ window_size_right,
+                /* softcap */ self.softcap.unwrap_or(0f32),
             )
         }
@@ -271,6 +273,7 @@ pub fn flash_attn(
         alibi_slopes: None,
         window_size_left,
         window_size_right,
+        softcap: None,
     };
     q.apply_op3(k, v, op)
 }
@@ -308,6 +311,7 @@ pub fn flash_attn_windowed(
         alibi_slopes: None,
         window_size_left,
         window_size_right,
+        softcap: None,
     };
     q.apply_op3(k, v, op)
 }
@@ -342,6 +346,7 @@ pub fn flash_attn_alibi(
         alibi_slopes: Some(alibi_slopes.clone()),
         window_size_left,
         window_size_right,
+        softcap: None,
     };
     q.apply_op3(k, v, op)
 }
@@ -381,6 +386,52 @@ pub fn flash_attn_alibi_windowed(
         alibi_slopes: Some(alibi_slopes.clone()),
         window_size_left,
         window_size_right,
+        softcap: None,
+    };
+    q.apply_op3(k, v, op)
+}
+
+/// Flash-attention v2 layer.
+///
+/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
+/// Multi-query and grouped-query attention are supported by using tensors `k` and `v` with fewer heads
+/// than `q`. The number of heads in `k` and `v` must be divisible by the number of heads in `q`.
+///
+/// # Arguments
+///
+/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`.
+/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
+/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
+/// * `alibi_slopes` - Optional alibi slopes tensor with shape `(num_heads_q)`.
+/// * `softmax_scale` - Scaling factor for the softmax operation.
+/// * `window_size_left` - Optional limit on left attention to value tokens.
+/// * `window_size_right` - Optional limit on right attention to value tokens.
+/// * `softcap` - Gemma style softcap the attention logits before the softmax.
+///
+/// # Causal Mask
+///
+/// Setting `window_size_left=None` and `window_size_right=Some(0)` applies a causal mask to the result
+/// of `Q @ K^T`.
+///
+/// # Returns
+///
+/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`.
+pub fn flash_attn_alibi_windowed_softcap(
+    q: &Tensor,
+    k: &Tensor,
+    v: &Tensor,
+    alibi_slopes: Option<&Tensor>,
+    softmax_scale: f32,
+    window_size_left: Option<usize>,
+    window_size_right: Option<usize>,
+    softcap: f32,
+) -> Result<Tensor> {
+    let op = FlashAttn {
+        softmax_scale,
+        alibi_slopes: alibi_slopes.cloned(),
+        window_size_left,
+        window_size_right,
+        softcap: Some(softcap),
     };
     q.apply_op3(k, v, op)
 }
@@ -394,6 +445,7 @@ struct FlashAttnVarLen {
     pub alibi_slopes: Option<Tensor>,
     pub window_size_left: Option<usize>,
     pub window_size_right: Option<usize>,
+    pub softcap: Option<f32>,
 }
 
 impl FlashAttnVarLen {
@@ -613,6 +665,7 @@ impl FlashAttnVarLen {
                 /* is_causal */ is_causal,
                 /* window_size_left */ window_size_left,
                 /* window_size_right */ window_size_right,
+                /* softcap */ self.softcap.unwrap_or(0.0),
             )
         }
@@ -699,6 +752,7 @@ pub fn flash_attn_varlen(
         alibi_slopes: None,
         window_size_left,
         window_size_right,
+        softcap: None,
     };
     q.apply_op3(k, v, op)
 }
@@ -752,6 +806,7 @@ pub fn flash_attn_varlen_windowed(
         alibi_slopes: None,
         window_size_left,
         window_size_right,
+        softcap: None,
     };
     q.apply_op3(k, v, op)
 }
@@ -802,6 +857,7 @@ pub fn flash_attn_varlen_alibi(
         alibi_slopes: Some(alibi_slopes.clone()),
         window_size_left,
         window_size_right,
+        softcap: None,
     };
     q.apply_op3(k, v, op)
 }
@@ -857,6 +913,65 @@ pub fn flash_attn_varlen_alibi_windowed(
         alibi_slopes: Some(alibi_slopes.clone()),
         window_size_left,
         window_size_right,
+        softcap: None,
+    };
+    q.apply_op3(k, v, op)
+}
+
+#[allow(clippy::too_many_arguments)]
+/// Flash-attention v2 layer with variable-length batching.
+///
+/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
+/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
+/// than q, the number of heads in k and v has to be divisible by the number of heads in q.
+///
+/// # Arguments
+///
+/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`.
+/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`.
+/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`.
+/// * `alibi_slopes` - Option, alibi slopes tensor with shape `(num_heads_q)`.
+/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q.
+/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v.
+/// * `max_seqlen_q` - The maximum query sequence length for q in the batch.
+/// * `max_seqlen_k` - The maximum query sequence length for k and v in the batch.
+/// * `window_size_left` - Option, limit left attention to value tokens.
+/// * `window_size_right` - Option, limit right attention to value tokens.
+/// * `softcap` - Gemma style softcap the attention logits before the softmax.
+///
+/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`,
+/// `seqlen_1 + seqlen_2`, etc.
+///
+/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`.
+///
+/// # Causal mask
+///
+/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result
+/// of `Q @ K^T`
+pub fn flash_attn_varlen_alibi_windowed_softcap(
+    q: &Tensor,
+    k: &Tensor,
+    v: &Tensor,
+    alibi_slopes: Option<&Tensor>,
+    seqlens_q: &Tensor,
+    seqlens_k: &Tensor,
+    max_seqlen_q: usize,
+    max_seqlen_k: usize,
+    softmax_scale: f32,
+    window_size_left: Option<usize>,
+    window_size_right: Option<usize>,
+    softcap: f32,
+) -> Result<Tensor> {
+    let op = FlashAttnVarLen {
+        softmax_scale,
+        max_seqlen_q,
+        max_seqlen_k,
+        seqlens_q: seqlens_q.clone(),
+        seqlens_k: seqlens_k.clone(),
+        alibi_slopes: alibi_slopes.cloned(),
+        window_size_left,
+        window_size_right,
+        softcap: Some(softcap),
     };
     q.apply_op3(k, v, op)
 }
diff --git a/candle-flash-attn/tests/flash_attn_tests.rs b/candle-flash-attn/tests/flash_attn_tests.rs
index 250added..e3058611 100644
--- a/candle-flash-attn/tests/flash_attn_tests.rs
+++ b/candle-flash-attn/tests/flash_attn_tests.rs
@@ -27,6 +27,20 @@ fn fa_acausal(q: &Tensor, k: &Tensor, v: &Tensor, softmax_scale: f32) -> Result<
     Ok(output)
 }
 
+fn fa_acausal_softcap(q: &Tensor, k: &Tensor, v: &Tensor, softcap: f32) -> Result<Tensor> {
+    let in_dtype = q.dtype();
+    let q = q.to_dtype(DType::F32)?;
+    let k = k.to_dtype(DType::F32)?;
+    let v = v.to_dtype(DType::F32)?;
+    // let att = (q.matmul(&k.t()?)? * softmax_scale as f64)?;
+    let att = q.matmul(&k.t()?)?;
+    let att = (softcap as f64 * ((att / softcap as f64)?.tanh())?)?;
+    let att = candle_nn::ops::softmax(&att, D::Minus1)?;
+    // Convert to contiguous as matmul doesn't support strided vs for now.
+    let output = att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?;
+    Ok(output)
+}
+
 #[test]
 fn flash_attn_acausal() -> Result<()> {
     let device = Device::new_cuda(0)?;
@@ -90,6 +104,44 @@ fn flash_attn_acausal() -> Result<()> {
 }
 
 #[test]
+fn flash_attn_acausal_softcap() -> Result<()> {
+    let device = Device::new_cuda(0)?;
+    let q = Tensor::arange(0u32, 3 * 5 * 8, &device)?
+        .to_dtype(DType::F16)?
+        .reshape((1, 3, 5, 8))?;
+    let k = (&q / 40.)?;
+    let v = (&q / 50.)?;
+    let q = (&q / 30.)?;
+    let softcap = 5.0f32;
+
+    let ys1 = fa_acausal_softcap(&q, &k, &v, softcap.clone())?;
+    let ys1 = ys1.i(0)?.to_dtype(DType::F32)?;
+    let ys2 = {
+        let q = q.transpose(1, 2)?;
+        let k = k.transpose(1, 2)?;
+        let v = v.transpose(1, 2)?;
+        candle_flash_attn::flash_attn_alibi_windowed_softcap(
+            &q,
+            &k,
+            &v,
+            None,            // alibi_slopes //
+            1.0,             // softmax //
+            None,            // window_size_left //
+            None,            // window_size_right //
+            softcap.clone(), // softcap //
+        )?
+        .transpose(1, 2)?
+    };
+    let ys2 = ys2.i(0)?.to_dtype(DType::F32)?;
+    let diff = ys1.sub(&ys2)?.abs()?.flatten_all()?.max(0)?;
+
+    assert_eq!(ys1.dims(), &[3, 5, 8]);
+    assert_eq!(ys2.dims(), &[3, 5, 8]);
+    assert!(diff.to_vec0::<f32>()?.abs() < 1e-3);
+    Ok(())
+}
+
+#[test]
 fn flash_attn_varlen() -> Result<()> {
     let device = Device::new_cuda(0)?;
     let q = Tensor::arange(0u32, 48, &device)?
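For orientation, a call into the new non-varlen entry point could look like the sketch below. The shapes, the 1/sqrt(head_dim) scale, and the softcap value of 5.0 are illustrative placeholders (the test above uses 5.0 with a scale of 1.0); as with the existing flash-attn wrappers, the tensors must be f16/bf16 and live on a CUDA device.

use candle::{DType, Device, Result, Tensor};

// Hypothetical driver showing the new API surface; names and values are illustrative.
fn run_softcap_attention() -> Result<Tensor> {
    let device = Device::new_cuda(0)?;
    let (batch, seq_len, num_heads, head_dim) = (1, 5, 3, 8);
    // Flash-attn expects (batch, seq_len, num_heads, head_size) tensors in f16/bf16.
    let q = Tensor::randn(0f32, 1f32, (batch, seq_len, num_heads, head_dim), &device)?
        .to_dtype(DType::F16)?;
    let k = Tensor::randn(0f32, 1f32, (batch, seq_len, num_heads, head_dim), &device)?
        .to_dtype(DType::F16)?;
    let v = Tensor::randn(0f32, 1f32, (batch, seq_len, num_heads, head_dim), &device)?
        .to_dtype(DType::F16)?;
    let softmax_scale = 1f32 / (head_dim as f32).sqrt();
    candle_flash_attn::flash_attn_alibi_windowed_softcap(
        &q,
        &k,
        &v,
        None,          // alibi_slopes: not used here
        softmax_scale, // softmax scale
        None,          // window_size_left: no left window
        None,          // window_size_right: None + None => no causal mask
        5.0,           // softcap (placeholder value)
    )
}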