Diffstat (limited to 'candle-flash-attn/src/lib.rs')
-rw-r--r--  candle-flash-attn/src/lib.rs  |  434
1 file changed, 421 insertions(+), 13 deletions(-)
diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs
index 3395bd0d..21a06b5e 100644
--- a/candle-flash-attn/src/lib.rs
+++ b/candle-flash-attn/src/lib.rs
@@ -3,12 +3,14 @@ mod ffi;
use candle::backend::BackendStorage;
use candle::cuda_backend::cudarc::driver::DevicePtr;
use candle::cuda_backend::WrapErr;
-use candle::{CpuStorage, Layout, Result, Shape, Tensor};
+use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor};
use half::{bf16, f16};
pub struct FlashAttn {
pub softmax_scale: f32,
- pub causal: bool,
+ pub alibi_slopes: Option<Tensor>,
+ pub window_size_left: Option<usize>,
+ pub window_size_right: Option<usize>,
}
fn round_multiple(x: usize, m: usize) -> usize {
@@ -85,6 +87,51 @@ impl FlashAttn {
candle::bail!("number of k/v heads {num_heads_k} must divide number of heads in query {num_heads}")
}
+ let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes {
+ if alibi_slopes.dtype() != DType::F32 {
+ candle::bail!(
+ "DType mismatch alibi_slopes {:?}, expected {:?}",
+ alibi_slopes.dtype(),
+ DType::F32
+ );
+ }
+
+ let (alibi_slopes, alibi_slopes_layout) = alibi_slopes.storage_and_layout();
+
+ if num_heads != alibi_slopes_layout.shape().dims1()? {
+ candle::bail!(
+ "shape mismatch alibi_slopes {:?}, expected {:?}",
+ alibi_slopes_layout.shape(),
+ (num_heads)
+ );
+ }
+
+ let alibi_slopes = match &*alibi_slopes {
+ candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
+ _ => candle::bail!("alibi_slopes must be a cuda tensor"),
+ };
+
+ let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..);
+
+ *alibi_slopes.device_ptr() as *const core::ffi::c_void
+ } else {
+ std::ptr::null()
+ };
+
+ // if window_size_left > seqlen_k or None => -1
+ let mut window_size_left = self
+ .window_size_left
+ .filter(|v| v <= &seqlen_k)
+ .map(|v| v as i32)
+ .unwrap_or(-1);
+
+ // if window_size_right > seqlen_k or None => -1
+ let mut window_size_right = self
+ .window_size_right
+ .filter(|v| v <= &seqlen_k)
+ .map(|v| v as i32)
+ .unwrap_or(-1);
+
let head_size = round_multiple(head_size_og, 8);
let head_size_rounded = round_multiple(head_size, 32);
let seqlen_q_rounded = round_multiple(seqlen_q, 128);
@@ -94,9 +141,22 @@ impl FlashAttn {
let dst = unsafe { dev.alloc::<T>(elem_count) }.w()?;
let softmax_lse = dev.alloc_zeros::<f32>(b_sz * num_heads * seqlen_q).w()?;
- let causal = if self.causal { 1 } else { 0 };
let is_bf16 = if is_bf16 { 1 } else { 0 };
+ // Causal is the special case where window_size_right == 0 and window_size_left < 0.
+ // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
+ let is_causal = if window_size_left < 0 && window_size_right == 0 {
+ 1
+ } else {
+ 0
+ };
+ if window_size_left < 0 && window_size_right >= 0 {
+ window_size_left = seqlen_k as i32;
+ }
+ if window_size_left >= 0 && window_size_right < 0 {
+ window_size_right = seqlen_k as i32;
+ }
+
unsafe {
let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
let k_ptr = *k.device_ptr() as *const core::ffi::c_void;
@@ -109,12 +169,14 @@ impl FlashAttn {
v_ptr,
dst_ptr,
softmax_lse_ptr,
+ /* alibi_slopes_ptr */ alibi_slopes_ptr,
/* cu_seqlens_q_ptr */ std::ptr::null(),
/* cu_seqlens_k_ptr */ std::ptr::null(),
/* q_batch_stride */ q_stride[0] as u32,
/* k_batch_stride */ k_stride[0] as u32,
/* v_batch_stride */ v_stride[0] as u32,
/* o_batch_stride */ o_stride[0] as u32,
+ /* alibi_slopes_batch_stride */ 0,
/* q_row_stride */ q_stride[q_rank - 3] as u32,
/* k_row_stride */ k_stride[k_rank - 3] as u32,
/* v_row_stride */ v_stride[v_rank - 3] as u32,
@@ -133,8 +195,10 @@ impl FlashAttn {
/* seqlen_k */ seqlen_k as u32,
/* seqlen_q_rounded */ seqlen_q_rounded as u32,
/* seqlen_k_rounded */ seqlen_k_rounded as u32,
- /* is_causal */ causal,
/* is_bf16 */ is_bf16,
+ /* is_causal */ is_causal,
+ /* window_size_left */ window_size_left,
+ /* window_size_right */ window_size_right,
)
}
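The window handling above is easier to follow in isolation. Below is a minimal standalone sketch of how the optional window sizes collapse into the `-1` sentinel, the `is_causal` flag, and the local-attention fallback passed to the kernel; the helper name `normalize_window` is hypothetical and not part of this crate.

```rust
// Hypothetical helper mirroring the window handling in `cuda_fwd` above.
// Returns (window_size_left, window_size_right, is_causal) as handed to the FFI call.
fn normalize_window(
    window_size_left: Option<usize>,
    window_size_right: Option<usize>,
    seqlen_k: usize,
) -> (i32, i32, i32) {
    // `None`, or a value larger than seqlen_k, becomes the -1 "unbounded" sentinel.
    let mut left = window_size_left
        .filter(|v| *v <= seqlen_k)
        .map(|v| v as i32)
        .unwrap_or(-1);
    let mut right = window_size_right
        .filter(|v| *v <= seqlen_k)
        .map(|v| v as i32)
        .unwrap_or(-1);
    // Causal is the special case: unbounded to the left, zero tokens to the right.
    let is_causal = if left < 0 && right == 0 { 1 } else { 0 };
    // For local attention, an unbounded side is widened to the full key length.
    if left < 0 && right >= 0 {
        left = seqlen_k as i32;
    }
    if left >= 0 && right < 0 {
        right = seqlen_k as i32;
    }
    (left, right, is_causal)
}

fn main() {
    assert_eq!(normalize_window(None, Some(0), 4096), (4096, 0, 1)); // causal
    assert_eq!(normalize_window(Some(256), None, 4096), (256, 4096, 0)); // sliding window
    assert_eq!(normalize_window(None, None, 4096), (-1, -1, 0)); // full attention
}
```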
@@ -197,20 +261,137 @@ pub fn flash_attn(
softmax_scale: f32,
causal: bool,
) -> Result<Tensor> {
+ let window_size_left = None;
+ let window_size_right = if causal { Some(0) } else { None };
+
let op = FlashAttn {
softmax_scale,
- causal,
+ alibi_slopes: None,
+ window_size_left,
+ window_size_right,
};
q.apply_op3(k, v, op)
}
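With this change the `causal` flag is just a shorthand for the window arguments. A minimal usage sketch, assuming a CUDA device is available and the crate is built with flash-attn support; shapes and values are illustrative only.

```rust
use candle::{DType, Device, Tensor};
use candle_flash_attn::flash_attn;

fn main() -> candle::Result<()> {
    let dev = Device::new_cuda(0)?;
    // (batch, seq_len, num_heads, head_size); the kernel expects f16 or bf16 inputs.
    let q = Tensor::randn(0f32, 1.0, (2, 128, 8, 64), &dev)?.to_dtype(DType::F16)?;
    let k = Tensor::randn(0f32, 1.0, (2, 128, 8, 64), &dev)?.to_dtype(DType::F16)?;
    let v = Tensor::randn(0f32, 1.0, (2, 128, 8, 64), &dev)?.to_dtype(DType::F16)?;
    let softmax_scale = 1.0 / (64f32).sqrt();
    // causal = true maps to window_size_left = None, window_size_right = Some(0).
    let out = flash_attn(&q, &k, &v, softmax_scale, true)?;
    assert_eq!(out.dims(), &[2, 128, 8, 64]);
    Ok(())
}
```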
-struct FlashAttnVarLen {
+/// Flash-attention v2 layer.
+///
+/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
+/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
+/// than q; the number of heads in q must be divisible by the number of heads in k and v.
+///
+/// # Arguments
+///
+/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`.
+/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
+/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
+/// * `window_size_left` - Limit the attention window to at most this many value tokens to the left of each query position (`None` means unbounded).
+/// * `window_size_right` - Limit the attention window to at most this many value tokens to the right of each query position (`None` means unbounded).
+///
+/// # Causal mask
+///
+/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result
+/// of `Q @ K^T`.
+///
+/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`.
+pub fn flash_attn_windowed(
+ q: &Tensor,
+ k: &Tensor,
+ v: &Tensor,
+ softmax_scale: f32,
+ window_size_left: Option<usize>,
+ window_size_right: Option<usize>,
+) -> Result<Tensor> {
+ let op = FlashAttn {
+ softmax_scale,
+ alibi_slopes: None,
+ window_size_left,
+ window_size_right,
+ };
+ q.apply_op3(k, v, op)
+}
+
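Reusing the tensors from the previous sketch, sliding-window (local) causal attention lets each query see itself plus at most a fixed number of earlier tokens; the window size below is illustrative.

```rust
// Each query attends to at most 64 preceding value tokens and to nothing on its right.
let out = candle_flash_attn::flash_attn_windowed(
    &q,
    &k,
    &v,
    softmax_scale,
    /* window_size_left */ Some(64),
    /* window_size_right */ Some(0),
)?;
```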
+/// Flash-attention v2 layer.
+///
+/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
+/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
+/// than q; the number of heads in q must be divisible by the number of heads in k and v.
+///
+/// # Arguments
+///
+/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`.
+/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
+/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
+/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`.
+///
+/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`.
+pub fn flash_attn_alibi(
+ q: &Tensor,
+ k: &Tensor,
+ v: &Tensor,
+ alibi_slopes: &Tensor,
softmax_scale: f32,
causal: bool,
- max_seqlen_q: usize,
- max_seqlen_k: usize,
- seqlens_q: Tensor,
- seqlens_k: Tensor,
+) -> Result<Tensor> {
+ let window_size_left = None;
+ let window_size_right = if causal { Some(0) } else { None };
+
+ let op = FlashAttn {
+ softmax_scale,
+ alibi_slopes: Some(alibi_slopes.clone()),
+ window_size_left,
+ window_size_right,
+ };
+ q.apply_op3(k, v, op)
+}
+
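ALiBi replaces positional embeddings with a per-head bias proportional to key distance. The slopes tensor is not built by this crate; the sketch below (continuing the earlier one, so `q`, `k`, `v`, `softmax_scale`, and `dev` are assumed in scope) uses one common construction, the geometric schedule from the ALiBi paper for a power-of-two head count.

```rust
// Illustrative helper: geometric ALiBi slopes 2^(-8i/n) for heads i = 1..n.
// The kernel expects an f32 tensor of shape (num_heads_q) on the same CUDA device as q.
fn alibi_slopes(n_heads: usize, device: &candle::Device) -> candle::Result<candle::Tensor> {
    let slopes: Vec<f32> = (1..=n_heads)
        .map(|i| 2f32.powf(-8.0 * i as f32 / n_heads as f32))
        .collect();
    candle::Tensor::new(slopes, device)
}

let slopes = alibi_slopes(8, &dev)?;
let out = candle_flash_attn::flash_attn_alibi(&q, &k, &v, &slopes, softmax_scale, true)?;
```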
+/// Flash-attention v2 layer.
+///
+/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
+/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
+/// than q; the number of heads in q must be divisible by the number of heads in k and v.
+///
+/// # Arguments
+///
+/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`.
+/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
+/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
+/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`.
+/// * `window_size_left` - Limit the attention window to at most this many value tokens to the left of each query position (`None` means unbounded).
+/// * `window_size_right` - Limit the attention window to at most this many value tokens to the right of each query position (`None` means unbounded).
+///
+/// # Causal mask
+///
+/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result
+/// of `Q @ K^T`.
+///
+/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`.
+pub fn flash_attn_alibi_windowed(
+ q: &Tensor,
+ k: &Tensor,
+ v: &Tensor,
+ alibi_slopes: &Tensor,
+ softmax_scale: f32,
+ window_size_left: Option<usize>,
+ window_size_right: Option<usize>,
+) -> Result<Tensor> {
+ let op = FlashAttn {
+ softmax_scale,
+ alibi_slopes: Some(alibi_slopes.clone()),
+ window_size_left,
+ window_size_right,
+ };
+ q.apply_op3(k, v, op)
+}
+
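The two extensions compose: continuing the previous sketches, the call below applies the ALiBi bias inside a causal sliding window.

```rust
let out = candle_flash_attn::flash_attn_alibi_windowed(
    &q,
    &k,
    &v,
    &slopes,
    softmax_scale,
    /* window_size_left */ Some(64),
    /* window_size_right */ Some(0),
)?;
```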
+struct FlashAttnVarLen {
+ pub softmax_scale: f32,
+ pub max_seqlen_q: usize,
+ pub max_seqlen_k: usize,
+ pub seqlens_q: Tensor,
+ pub seqlens_k: Tensor,
+ pub alibi_slopes: Option<Tensor>,
+ pub window_size_left: Option<usize>,
+ pub window_size_right: Option<usize>,
}
impl FlashAttnVarLen {
@@ -311,7 +492,54 @@ impl FlashAttnVarLen {
if nseqlens_k != nseqlens_q {
candle::bail!("seqlens_q and seqlens_k should have the same number of elements {nseqlens_q} <> {nseqlens_k}")
}
+
let batch_size = nseqlens_q - 1;
+
+ let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes {
+ if alibi_slopes.dtype() != DType::F32 {
+ candle::bail!(
+ "DType mismatch alibi_slopes {:?}, expected {:?}",
+ alibi_slopes.dtype(),
+ DType::F32
+ );
+ }
+
+ let (alibi_slopes, alibi_slopes_layout) = alibi_slopes.storage_and_layout();
+
+ if num_heads != alibi_slopes_layout.shape().dims1()? {
+ candle::bail!(
+ "shape mismatch alibi_slopes {:?}, expected {:?}",
+ alibi_slopes_layout.shape(),
+ (num_heads)
+ );
+ }
+
+ let alibi_slopes = match &*alibi_slopes {
+ candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
+ _ => candle::bail!("alibi_slopes must be a cuda tensor"),
+ };
+
+ let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..);
+
+ *alibi_slopes.device_ptr() as *const core::ffi::c_void
+ } else {
+ std::ptr::null()
+ };
+
+ // if window_size_left > self.max_seqlen_k or None => -1
+ let mut window_size_left = self
+ .window_size_left
+ .filter(|v| v <= &self.max_seqlen_k)
+ .map(|v| v as i32)
+ .unwrap_or(-1);
+
+ // if window_size_right > self.max_seqlen_k or None => -1
+ let mut window_size_right = self
+ .window_size_right
+ .filter(|v| v <= &self.max_seqlen_k)
+ .map(|v| v as i32)
+ .unwrap_or(-1);
+
let head_size = round_multiple(head_size_og, 8);
let head_size_rounded = round_multiple(head_size, 32);
let seqlen_q_rounded = round_multiple(self.max_seqlen_q, 128);
@@ -323,9 +551,22 @@ impl FlashAttnVarLen {
.alloc_zeros::<f32>(batch_size * num_heads * self.max_seqlen_q)
.w()?;
- let causal = if self.causal { 1 } else { 0 };
let is_bf16 = if is_bf16 { 1 } else { 0 };
+ // Causal is the special case where window_size_right == 0 and window_size_left < 0.
+ // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
+ let is_causal = if window_size_left < 0 && window_size_right == 0 {
+ 1
+ } else {
+ 0
+ };
+ if window_size_left < 0 && window_size_right >= 0 {
+ window_size_left = self.max_seqlen_k as i32;
+ }
+ if window_size_left >= 0 && window_size_right < 0 {
+ window_size_right = self.max_seqlen_k as i32;
+ }
+
unsafe {
let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
let k_ptr = *k.device_ptr() as *const core::ffi::c_void;
@@ -340,12 +581,14 @@ impl FlashAttnVarLen {
v_ptr,
dst_ptr,
softmax_lse_ptr,
+ /* alibi_slopes_ptr */ alibi_slopes_ptr,
/* cu_seqlens_q_ptr */ seqlens_q_ptr,
/* cu_seqlens_k_ptr */ seqlens_k_ptr,
/* q_batch_stride */ 0,
/* k_batch_stride */ 0,
/* v_batch_stride */ 0,
/* o_batch_stride */ 0,
+ /* alibi_slopes_batch_stride */ 0,
/* q_row_stride */ q_stride[q_rank - 3] as u32,
/* k_row_stride */ k_stride[k_rank - 3] as u32,
/* v_row_stride */ v_stride[v_rank - 3] as u32,
@@ -364,8 +607,10 @@ impl FlashAttnVarLen {
/* seqlen_k */ self.max_seqlen_k as u32,
/* seqlen_q_rounded */ seqlen_q_rounded as u32,
/* seqlen_k_rounded */ seqlen_k_rounded as u32,
- /* is_causal */ causal,
/* is_bf16 */ is_bf16,
+ /* is_causal */ is_causal,
+ /* window_size_left */ window_size_left,
+ /* window_size_right */ window_size_right,
)
}
@@ -440,13 +685,176 @@ pub fn flash_attn_varlen(
softmax_scale: f32,
causal: bool,
) -> Result<Tensor> {
+ let window_size_left = None;
+ let window_size_right = if causal { Some(0) } else { None };
+
+ let op = FlashAttnVarLen {
+ softmax_scale,
+ max_seqlen_q,
+ max_seqlen_k,
+ seqlens_q: seqlens_q.clone(),
+ seqlens_k: seqlens_k.clone(),
+ alibi_slopes: None,
+ window_size_left,
+ window_size_right,
+ };
+ q.apply_op3(k, v, op)
+}
+
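For the variable-length entry points, the batch is packed along the first dimension and described by cumulative sequence lengths. A minimal sketch, assuming a CUDA device and that the cumulative lengths are passed as u32 tensors; two sequences of lengths 3 and 5 are illustrative.

```rust
use candle::{DType, Device, Tensor};

fn main() -> candle::Result<()> {
    let dev = Device::new_cuda(0)?;
    // Two sequences of lengths 3 and 5 packed together: total_q = total_kv = 8.
    let q = Tensor::randn(0f32, 1.0, (8, 8, 64), &dev)?.to_dtype(DType::F16)?;
    let k = Tensor::randn(0f32, 1.0, (8, 8, 64), &dev)?.to_dtype(DType::F16)?;
    let v = Tensor::randn(0f32, 1.0, (8, 8, 64), &dev)?.to_dtype(DType::F16)?;
    // Cumulative lengths: batch_size + 1 entries, here [0, 3, 8].
    let seqlens = Tensor::new(&[0u32, 3u32, 8u32], &dev)?;
    let softmax_scale = 1.0 / (64f32).sqrt();
    let out = candle_flash_attn::flash_attn_varlen(
        &q, &k, &v, &seqlens, &seqlens,
        /* max_seqlen_q */ 5,
        /* max_seqlen_k */ 5,
        softmax_scale,
        /* causal */ true,
    )?;
    assert_eq!(out.dims(), &[8, 8, 64]);
    Ok(())
}
```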
+#[allow(clippy::too_many_arguments)]
+/// Flash-attention v2 layer with variable-length batching.
+///
+/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
+/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
+/// than q; the number of heads in q must be divisible by the number of heads in k and v.
+///
+/// # Arguments
+///
+/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`.
+/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`.
+/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`.
+/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q.
+/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v.
+/// * `max_seqlen_q` - The maximum sequence length for q in the batch.
+/// * `max_seqlen_k` - The maximum sequence length for k and v in the batch.
+/// * `window_size_left` - Limit the attention window to at most this many value tokens to the left of each query position (`None` means unbounded).
+/// * `window_size_right` - Limit the attention window to at most this many value tokens to the right of each query position (`None` means unbounded).
+///
+/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`,
+/// `seqlen_1 + seqlen_2`, etc.
+///
+/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`.
+///
+/// # Causal mask
+///
+/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result
+/// of `Q @ K^T`.
+pub fn flash_attn_varlen_windowed(
+ q: &Tensor,
+ k: &Tensor,
+ v: &Tensor,
+ seqlens_q: &Tensor,
+ seqlens_k: &Tensor,
+ max_seqlen_q: usize,
+ max_seqlen_k: usize,
+ softmax_scale: f32,
+ window_size_left: Option<usize>,
+ window_size_right: Option<usize>,
+) -> Result<Tensor> {
+ let op = FlashAttnVarLen {
+ softmax_scale,
+ max_seqlen_q,
+ max_seqlen_k,
+ seqlens_q: seqlens_q.clone(),
+ seqlens_k: seqlens_k.clone(),
+ alibi_slopes: None,
+ window_size_left,
+ window_size_right,
+ };
+ q.apply_op3(k, v, op)
+}
+
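Continuing the packed-batch sketch, the windowed variant swaps the causal flag for explicit window sizes.

```rust
// A 4-token causal sliding window over the same packed batch; window sizes larger
// than max_seqlen_k fall back to the unbounded (-1) setting, as noted above.
let out = candle_flash_attn::flash_attn_varlen_windowed(
    &q, &k, &v, &seqlens, &seqlens,
    /* max_seqlen_q */ 5,
    /* max_seqlen_k */ 5,
    softmax_scale,
    /* window_size_left */ Some(4),
    /* window_size_right */ Some(0),
)?;
```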
+#[allow(clippy::too_many_arguments)]
+/// Flash-attention v2 layer with variable-length batching.
+///
+/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
+/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
+/// than q; the number of heads in q must be divisible by the number of heads in k and v.
+///
+/// # Arguments
+///
+/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`.
+/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`.
+/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`.
+/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`.
+/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q.
+/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v.
+/// * `max_seqlen_q` - The maximum sequence length for q in the batch.
+/// * `max_seqlen_k` - The maximum sequence length for k and v in the batch.
+///
+/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`,
+/// `seqlen_1 + seqlen_2`, etc.
+///
+/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`.
+pub fn flash_attn_varlen_alibi(
+ q: &Tensor,
+ k: &Tensor,
+ v: &Tensor,
+ alibi_slopes: &Tensor,
+ seqlens_q: &Tensor,
+ seqlens_k: &Tensor,
+ max_seqlen_q: usize,
+ max_seqlen_k: usize,
+ softmax_scale: f32,
+ causal: bool,
+) -> Result<Tensor> {
+ let window_size_left = None;
+ let window_size_right = if causal { Some(0) } else { None };
+
+ let op = FlashAttnVarLen {
+ softmax_scale,
+ max_seqlen_q,
+ max_seqlen_k,
+ seqlens_q: seqlens_q.clone(),
+ seqlens_k: seqlens_k.clone(),
+ alibi_slopes: Some(alibi_slopes.clone()),
+ window_size_left,
+ window_size_right,
+ };
+ q.apply_op3(k, v, op)
+}
+
+#[allow(clippy::too_many_arguments)]
+/// Flash-attention v2 layer with variable-length batching.
+///
+/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
+/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
+/// than q; the number of heads in q must be divisible by the number of heads in k and v.
+///
+/// # Arguments
+///
+/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`.
+/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`.
+/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`.
+/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`.
+/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q.
+/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v.
+/// * `max_seqlen_q` - The maximum sequence length for q in the batch.
+/// * `max_seqlen_k` - The maximum sequence length for k and v in the batch.
+/// * `window_size_left` - Limit the attention window to at most this many value tokens to the left of each query position (`None` means unbounded).
+/// * `window_size_right` - Limit the attention window to at most this many value tokens to the right of each query position (`None` means unbounded).
+///
+/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`,
+/// `seqlen_1 + seqlen_2`, etc.
+///
+/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`.
+///
+/// # Causal mask
+///
+/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result
+/// of `Q @ K^T`.
+pub fn flash_attn_varlen_alibi_windowed(
+ q: &Tensor,
+ k: &Tensor,
+ v: &Tensor,
+ alibi_slopes: &Tensor,
+ seqlens_q: &Tensor,
+ seqlens_k: &Tensor,
+ max_seqlen_q: usize,
+ max_seqlen_k: usize,
+ softmax_scale: f32,
+ window_size_left: Option<usize>,
+ window_size_right: Option<usize>,
+) -> Result<Tensor> {
let op = FlashAttnVarLen {
softmax_scale,
- causal,
max_seqlen_q,
max_seqlen_k,
seqlens_q: seqlens_q.clone(),
seqlens_k: seqlens_k.clone(),
+ alibi_slopes: Some(alibi_slopes.clone()),
+ window_size_left,
+ window_size_right,
};
q.apply_op3(k, v, op)
}
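Finally, the pieces compose on the packed layout as well; the sketch below reuses the hypothetical `alibi_slopes` helper and the packed tensors from the earlier sketches.

```rust
let slopes = alibi_slopes(8, &dev)?; // f32 tensor of shape (num_heads_q)
let out = candle_flash_attn::flash_attn_varlen_alibi_windowed(
    &q, &k, &v, &slopes, &seqlens, &seqlens,
    /* max_seqlen_q */ 5,
    /* max_seqlen_k */ 5,
    softmax_scale,
    /* window_size_left */ Some(4),
    /* window_size_right */ Some(0),
)?;
```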