diff options
author | OlivierDehaene <Olivier.dehaene@gmail.com> | 2024-01-05 18:28:55 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-05 18:28:55 +0100 |
commit | 8d1a57c9a0465b201e4e9e410e2b8fcde37b35f7 (patch) | |
tree | 289466c9df7a7f21ec1e574cd6cfd7b957998a08 /candle-flash-attn/src/lib.rs | |
parent | 3a7304cb0dbdf8ceeab8a4f5cf9b8e7ced822e20 (diff) | |
download | candle-8d1a57c9a0465b201e4e9e410e2b8fcde37b35f7.tar.gz candle-8d1a57c9a0465b201e4e9e410e2b8fcde37b35f7.tar.bz2 candle-8d1a57c9a0465b201e4e9e410e2b8fcde37b35f7.zip |
chore: update flash attention kernels (#1518)
* chore: update flash attention kernels
* fmt
* remove unused kernels
* force f32
* correct stride
Diffstat (limited to 'candle-flash-attn/src/lib.rs')
-rw-r--r-- | candle-flash-attn/src/lib.rs | 434 |
1 files changed, 421 insertions, 13 deletions
diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs index 3395bd0d..21a06b5e 100644 --- a/candle-flash-attn/src/lib.rs +++ b/candle-flash-attn/src/lib.rs @@ -3,12 +3,14 @@ mod ffi; use candle::backend::BackendStorage; use candle::cuda_backend::cudarc::driver::DevicePtr; use candle::cuda_backend::WrapErr; -use candle::{CpuStorage, Layout, Result, Shape, Tensor}; +use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor}; use half::{bf16, f16}; pub struct FlashAttn { pub softmax_scale: f32, - pub causal: bool, + pub alibi_slopes: Option<Tensor>, + pub window_size_left: Option<usize>, + pub window_size_right: Option<usize>, } fn round_multiple(x: usize, m: usize) -> usize { @@ -85,6 +87,51 @@ impl FlashAttn { candle::bail!("number of k/v heads {num_heads_k} must divide number of heads in query {num_heads}") } + let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes { + if alibi_slopes.dtype() != DType::F32 { + candle::bail!( + "DType mismatch alibi_slopes {:?}, expected {:?}", + alibi_slopes.dtype(), + DType::F32 + ); + } + + let (alibi_slopes, alibi_slopes_layout) = alibi_slopes.storage_and_layout(); + + if num_heads != alibi_slopes_layout.shape().dims1()? { + candle::bail!( + "shape mismatch alibi_slopes {:?}, expected {:?}", + alibi_slopes_layout.shape(), + (num_heads) + ); + } + + let alibi_slopes = match &*alibi_slopes { + candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?, + _ => candle::bail!("alibi_slopes must be a cuda tensor"), + }; + + let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..); + + *alibi_slopes.device_ptr() as *const core::ffi::c_void + } else { + std::ptr::null() + }; + + // if window_size_left > self.max_seqlen_k or None => -1 + let mut window_size_left = self + .window_size_left + .filter(|v| v <= &seqlen_k) + .map(|v| v as i32) + .unwrap_or(-1); + + // if window_size_right > self.max_seqlen_k or None => -1 + let mut window_size_right = self + .window_size_right + .filter(|v| v <= &seqlen_k) + .map(|v| v as i32) + .unwrap_or(-1); + let head_size = round_multiple(head_size_og, 8); let head_size_rounded = round_multiple(head_size, 32); let seqlen_q_rounded = round_multiple(seqlen_q, 128); @@ -94,9 +141,22 @@ impl FlashAttn { let dst = unsafe { dev.alloc::<T>(elem_count) }.w()?; let softmax_lse = dev.alloc_zeros::<f32>(b_sz * num_heads * seqlen_q).w()?; - let causal = if self.causal { 1 } else { 0 }; let is_bf16 = if is_bf16 { 1 } else { 0 }; + // Causal is the special case where window_size_right == 0 and window_size_left < 0. + // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. + let is_causal = if window_size_left < 0 && window_size_right == 0 { + 1 + } else { + 0 + }; + if window_size_left < 0 && window_size_right >= 0 { + window_size_left = seqlen_k as i32; + } + if window_size_left >= 0 && window_size_right < 0 { + window_size_right = seqlen_k as i32; + } + unsafe { let q_ptr = *q.device_ptr() as *const core::ffi::c_void; let k_ptr = *k.device_ptr() as *const core::ffi::c_void; @@ -109,12 +169,14 @@ impl FlashAttn { v_ptr, dst_ptr, softmax_lse_ptr, + /* alibi_slopes_ptr */ alibi_slopes_ptr, /* cu_seqlens_q_ptr */ std::ptr::null(), /* cu_seqlens_k_ptr */ std::ptr::null(), /* q_batch_stride */ q_stride[0] as u32, /* k_batch_stride */ k_stride[0] as u32, /* v_batch_stride */ v_stride[0] as u32, /* o_batch_stride */ o_stride[0] as u32, + /* alibi_slopes_batch_stride */ 0, /* q_row_stride */ q_stride[q_rank - 3] as u32, /* k_row_stride */ k_stride[k_rank - 3] as u32, /* v_row_stride */ v_stride[v_rank - 3] as u32, @@ -133,8 +195,10 @@ impl FlashAttn { /* seqlen_k */ seqlen_k as u32, /* seqlen_q_rounded */ seqlen_q_rounded as u32, /* seqlen_k_rounded */ seqlen_k_rounded as u32, - /* is_causal */ causal, /* is_bf16 */ is_bf16, + /* is_causal */ is_causal, + /* window_size_left */ window_size_left, + /* window_size_right */ window_size_right, ) } @@ -197,20 +261,137 @@ pub fn flash_attn( softmax_scale: f32, causal: bool, ) -> Result<Tensor> { + let window_size_left = None; + let window_size_right = if causal { Some(0) } else { None }; + let op = FlashAttn { softmax_scale, - causal, + alibi_slopes: None, + window_size_left, + window_size_right, }; q.apply_op3(k, v, op) } -struct FlashAttnVarLen { +/// Flash-attention v2 layer. +/// +/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`. +/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads +/// than q, the number of heads in k and v has to be divisible by the number of heads in q. +/// +/// # Arguments +/// +/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`. +/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`. +/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`. +/// * `window_size_left` - Limit left attention to value tokens. +/// * `window_size_right` - Limit right attention to value tokens. +/// +/// # Causal mask +/// +/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result +/// of `Q @ K^T` +/// +/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`. +pub fn flash_attn_windowed( + q: &Tensor, + k: &Tensor, + v: &Tensor, + softmax_scale: f32, + window_size_left: Option<usize>, + window_size_right: Option<usize>, +) -> Result<Tensor> { + let op = FlashAttn { + softmax_scale, + alibi_slopes: None, + window_size_left, + window_size_right, + }; + q.apply_op3(k, v, op) +} + +/// Flash-attention v2 layer. +/// +/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`. +/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads +/// than q, the number of heads in k and v has to be divisible by the number of heads in q. +/// +/// # Arguments +/// +/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`. +/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`. +/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`. +/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`. +/// +/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`. +pub fn flash_attn_alibi( + q: &Tensor, + k: &Tensor, + v: &Tensor, + alibi_slopes: &Tensor, softmax_scale: f32, causal: bool, - max_seqlen_q: usize, - max_seqlen_k: usize, - seqlens_q: Tensor, - seqlens_k: Tensor, +) -> Result<Tensor> { + let window_size_left = None; + let window_size_right = if causal { Some(0) } else { None }; + + let op = FlashAttn { + softmax_scale, + alibi_slopes: Some(alibi_slopes.clone()), + window_size_left, + window_size_right, + }; + q.apply_op3(k, v, op) +} + +/// Flash-attention v2 layer. +/// +/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`. +/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads +/// than q, the number of heads in k and v has to be divisible by the number of heads in q. +/// +/// # Arguments +/// +/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`. +/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`. +/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`. +/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`. +/// * `window_size_left` - Limit left attention to value tokens. +/// * `window_size_right` - Limit right attention to value tokens. +/// +/// # Causal mask +/// +/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result +/// of `Q @ K^T` +/// +/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`. +pub fn flash_attn_alibi_windowed( + q: &Tensor, + k: &Tensor, + v: &Tensor, + alibi_slopes: &Tensor, + softmax_scale: f32, + window_size_left: Option<usize>, + window_size_right: Option<usize>, +) -> Result<Tensor> { + let op = FlashAttn { + softmax_scale, + alibi_slopes: Some(alibi_slopes.clone()), + window_size_left, + window_size_right, + }; + q.apply_op3(k, v, op) +} + +struct FlashAttnVarLen { + pub softmax_scale: f32, + pub max_seqlen_q: usize, + pub max_seqlen_k: usize, + pub seqlens_q: Tensor, + pub seqlens_k: Tensor, + pub alibi_slopes: Option<Tensor>, + pub window_size_left: Option<usize>, + pub window_size_right: Option<usize>, } impl FlashAttnVarLen { @@ -311,7 +492,54 @@ impl FlashAttnVarLen { if nseqlens_k != nseqlens_q { candle::bail!("seqlens_q and seqlens_k should have the same number of elements {nseqlens_q} <> {nseqlens_k}") } + let batch_size = nseqlens_q - 1; + + let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes { + if alibi_slopes.dtype() != DType::F32 { + candle::bail!( + "DType mismatch alibi_slopes {:?}, expected {:?}", + alibi_slopes.dtype(), + DType::F32 + ); + } + + let (alibi_slopes, alibi_slopes_layout) = alibi_slopes.storage_and_layout(); + + if num_heads != alibi_slopes_layout.shape().dims1()? { + candle::bail!( + "shape mismatch alibi_slopes {:?}, expected {:?}", + alibi_slopes_layout.shape(), + (num_heads) + ); + } + + let alibi_slopes = match &*alibi_slopes { + candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?, + _ => candle::bail!("alibi_slopes must be a cuda tensor"), + }; + + let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..); + + *alibi_slopes.device_ptr() as *const core::ffi::c_void + } else { + std::ptr::null() + }; + + // if window_size_left > self.max_seqlen_k or None => -1 + let mut window_size_left = self + .window_size_left + .filter(|v| v <= &self.max_seqlen_k) + .map(|v| v as i32) + .unwrap_or(-1); + + // if window_size_right > self.max_seqlen_k or None => -1 + let mut window_size_right = self + .window_size_right + .filter(|v| v <= &self.max_seqlen_k) + .map(|v| v as i32) + .unwrap_or(-1); + let head_size = round_multiple(head_size_og, 8); let head_size_rounded = round_multiple(head_size, 32); let seqlen_q_rounded = round_multiple(self.max_seqlen_q, 128); @@ -323,9 +551,22 @@ impl FlashAttnVarLen { .alloc_zeros::<f32>(batch_size * num_heads * self.max_seqlen_q) .w()?; - let causal = if self.causal { 1 } else { 0 }; let is_bf16 = if is_bf16 { 1 } else { 0 }; + // Causal is the special case where window_size_right == 0 and window_size_left < 0. + // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. + let is_causal = if window_size_left < 0 && window_size_right == 0 { + 1 + } else { + 0 + }; + if window_size_left < 0 && window_size_right >= 0 { + window_size_left = self.max_seqlen_k as i32; + } + if window_size_left >= 0 && window_size_right < 0 { + window_size_right = self.max_seqlen_k as i32; + } + unsafe { let q_ptr = *q.device_ptr() as *const core::ffi::c_void; let k_ptr = *k.device_ptr() as *const core::ffi::c_void; @@ -340,12 +581,14 @@ impl FlashAttnVarLen { v_ptr, dst_ptr, softmax_lse_ptr, + /* alibi_slopes_ptr */ alibi_slopes_ptr, /* cu_seqlens_q_ptr */ seqlens_q_ptr, /* cu_seqlens_k_ptr */ seqlens_k_ptr, /* q_batch_stride */ 0, /* k_batch_stride */ 0, /* v_batch_stride */ 0, /* o_batch_stride */ 0, + /* alibi_slopes_batch_stride */ 0, /* q_row_stride */ q_stride[q_rank - 3] as u32, /* k_row_stride */ k_stride[k_rank - 3] as u32, /* v_row_stride */ v_stride[v_rank - 3] as u32, @@ -364,8 +607,10 @@ impl FlashAttnVarLen { /* seqlen_k */ self.max_seqlen_k as u32, /* seqlen_q_rounded */ seqlen_q_rounded as u32, /* seqlen_k_rounded */ seqlen_k_rounded as u32, - /* is_causal */ causal, /* is_bf16 */ is_bf16, + /* is_causal */ is_causal, + /* window_size_left */ window_size_left, + /* window_size_right */ window_size_right, ) } @@ -440,13 +685,176 @@ pub fn flash_attn_varlen( softmax_scale: f32, causal: bool, ) -> Result<Tensor> { + let window_size_left = None; + let window_size_right = if causal { Some(0) } else { None }; + + let op = FlashAttnVarLen { + softmax_scale, + max_seqlen_q, + max_seqlen_k, + seqlens_q: seqlens_q.clone(), + seqlens_k: seqlens_k.clone(), + alibi_slopes: None, + window_size_left, + window_size_right, + }; + q.apply_op3(k, v, op) +} + +#[allow(clippy::too_many_arguments)] +/// Flash-attention v2 layer with variable-length batching. +/// +/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`. +/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads +/// than q, the number of heads in k and v has to be divisible by the number of heads in q. +/// +/// # Arguments +/// +/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`. +/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`. +/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`. +/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q. +/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v. +/// * `max_seqlen_q` - The maximum query sequence length for q in the batch. +/// * `max_seqlen_k` - The maximum query sequence length for k and v in the batch. +/// * `window_size_left` - Limit left attention to value tokens. +/// * `window_size_right` - Limit right attention to value tokens. +/// +/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`, +/// `seqlen_1 + seqlen_2`, etc. +/// +/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`. +/// +/// # Causal mask +/// +/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result +/// of `Q @ K^T` +pub fn flash_attn_varlen_windowed( + q: &Tensor, + k: &Tensor, + v: &Tensor, + seqlens_q: &Tensor, + seqlens_k: &Tensor, + max_seqlen_q: usize, + max_seqlen_k: usize, + softmax_scale: f32, + window_size_left: Option<usize>, + window_size_right: Option<usize>, +) -> Result<Tensor> { + let op = FlashAttnVarLen { + softmax_scale, + max_seqlen_q, + max_seqlen_k, + seqlens_q: seqlens_q.clone(), + seqlens_k: seqlens_k.clone(), + alibi_slopes: None, + window_size_left, + window_size_right, + }; + q.apply_op3(k, v, op) +} + +#[allow(clippy::too_many_arguments)] +/// Flash-attention v2 layer with variable-length batching. +/// +/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`. +/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads +/// than q, the number of heads in k and v has to be divisible by the number of heads in q. +/// +/// # Arguments +/// +/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`. +/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`. +/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`. +/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`. +/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q. +/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v. +/// * `max_seqlen_q` - The maximum query sequence length for q in the batch. +/// * `max_seqlen_k` - The maximum query sequence length for k and v in the batch. +/// +/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`, +/// `seqlen_1 + seqlen_2`, etc. +/// +/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`. +pub fn flash_attn_varlen_alibi( + q: &Tensor, + k: &Tensor, + v: &Tensor, + alibi_slopes: &Tensor, + seqlens_q: &Tensor, + seqlens_k: &Tensor, + max_seqlen_q: usize, + max_seqlen_k: usize, + softmax_scale: f32, + causal: bool, +) -> Result<Tensor> { + let window_size_left = None; + let window_size_right = if causal { Some(0) } else { None }; + + let op = FlashAttnVarLen { + softmax_scale, + max_seqlen_q, + max_seqlen_k, + seqlens_q: seqlens_q.clone(), + seqlens_k: seqlens_k.clone(), + alibi_slopes: Some(alibi_slopes.clone()), + window_size_left, + window_size_right, + }; + q.apply_op3(k, v, op) +} + +#[allow(clippy::too_many_arguments)] +/// Flash-attention v2 layer with variable-length batching. +/// +/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`. +/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads +/// than q, the number of heads in k and v has to be divisible by the number of heads in q. +/// +/// # Arguments +/// +/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`. +/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`. +/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`. +/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`. +/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q. +/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v. +/// * `max_seqlen_q` - The maximum query sequence length for q in the batch. +/// * `max_seqlen_k` - The maximum query sequence length for k and v in the batch. +/// * `window_size_left` - Limit left attention to value tokens. +/// * `window_size_right` - Limit right attention to value tokens. +/// +/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`, +/// `seqlen_1 + seqlen_2`, etc. +/// +/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`. +/// +/// # Causal mask +/// +/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result +/// of `Q @ K^T` +pub fn flash_attn_varlen_alibi_windowed( + q: &Tensor, + k: &Tensor, + v: &Tensor, + alibi_slopes: &Tensor, + seqlens_q: &Tensor, + seqlens_k: &Tensor, + max_seqlen_q: usize, + max_seqlen_k: usize, + softmax_scale: f32, + window_size_left: Option<usize>, + window_size_right: Option<usize>, +) -> Result<Tensor> { let op = FlashAttnVarLen { softmax_scale, - causal, max_seqlen_q, max_seqlen_k, seqlens_q: seqlens_q.clone(), seqlens_k: seqlens_k.clone(), + alibi_slopes: Some(alibi_slopes.clone()), + window_size_left, + window_size_right, }; q.apply_op3(k, v, op) } |