summaryrefslogtreecommitdiff
path: root/candle-flash-attn/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-flash-attn/src/lib.rs')
-rw-r--r--candle-flash-attn/src/lib.rs115
1 files changed, 115 insertions, 0 deletions
diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs
index f171a986..22a6f1d6 100644
--- a/candle-flash-attn/src/lib.rs
+++ b/candle-flash-attn/src/lib.rs
@@ -11,6 +11,7 @@ pub struct FlashAttn {
pub alibi_slopes: Option<Tensor>,
pub window_size_left: Option<usize>,
pub window_size_right: Option<usize>,
+ pub softcap: Option<f32>,
}
fn round_multiple(x: usize, m: usize) -> usize {
@@ -201,6 +202,7 @@ impl FlashAttn {
/* is_causal */ is_causal,
/* window_size_left */ window_size_left,
/* window_size_right */ window_size_right,
+ /* softcap */ self.softcap.unwrap_or(0f32),
)
}
@@ -271,6 +273,7 @@ pub fn flash_attn(
alibi_slopes: None,
window_size_left,
window_size_right,
+ softcap: None,
};
q.apply_op3(k, v, op)
}
@@ -308,6 +311,7 @@ pub fn flash_attn_windowed(
alibi_slopes: None,
window_size_left,
window_size_right,
+ softcap: None,
};
q.apply_op3(k, v, op)
}
@@ -342,6 +346,7 @@ pub fn flash_attn_alibi(
alibi_slopes: Some(alibi_slopes.clone()),
window_size_left,
window_size_right,
+ softcap: None,
};
q.apply_op3(k, v, op)
}
@@ -381,6 +386,52 @@ pub fn flash_attn_alibi_windowed(
alibi_slopes: Some(alibi_slopes.clone()),
window_size_left,
window_size_right,
+ softcap: None,
+ };
+ q.apply_op3(k, v, op)
+}
+
+/// Flash-attention v2 layer.
+///
+/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
+/// Multi-query and grouped-query attention are supported by using tensors `k` and `v` with fewer heads
+/// than `q`. The number of heads in `k` and `v` must be divisible by the number of heads in `q`.
+///
+/// # Arguments
+///
+/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`.
+/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
+/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
+/// * `alibi_slopes` - Optional alibi slopes tensor with shape `(num_heads_q)`.
+/// * `softmax_scale` - Scaling factor for the softmax operation.
+/// * `window_size_left` - Optional limit on left attention to value tokens.
+/// * `window_size_right` - Optional limit on right attention to value tokens.
+/// * `softcap` - Gemma style softcap the attention logits before the softmax.
+///
+/// # Causal Mask
+///
+/// Setting `window_size_left=None` and `window_size_right=Some(0)` applies a causal mask to the result
+/// of `Q @ K^T`.
+///
+/// # Returns
+///
+/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`.
+pub fn flash_attn_alibi_windowed_softcap(
+ q: &Tensor,
+ k: &Tensor,
+ v: &Tensor,
+ alibi_slopes: Option<&Tensor>,
+ softmax_scale: f32,
+ window_size_left: Option<usize>,
+ window_size_right: Option<usize>,
+ softcap: f32,
+) -> Result<Tensor> {
+ let op = FlashAttn {
+ softmax_scale,
+ alibi_slopes: alibi_slopes.cloned(),
+ window_size_left,
+ window_size_right,
+ softcap: Some(softcap),
};
q.apply_op3(k, v, op)
}
@@ -394,6 +445,7 @@ struct FlashAttnVarLen {
pub alibi_slopes: Option<Tensor>,
pub window_size_left: Option<usize>,
pub window_size_right: Option<usize>,
+ pub softcap: Option<f32>,
}
impl FlashAttnVarLen {
@@ -613,6 +665,7 @@ impl FlashAttnVarLen {
/* is_causal */ is_causal,
/* window_size_left */ window_size_left,
/* window_size_right */ window_size_right,
+ /* softcap */ self.softcap.unwrap_or(0.0),
)
}
@@ -699,6 +752,7 @@ pub fn flash_attn_varlen(
alibi_slopes: None,
window_size_left,
window_size_right,
+ softcap: None,
};
q.apply_op3(k, v, op)
}
@@ -752,6 +806,7 @@ pub fn flash_attn_varlen_windowed(
alibi_slopes: None,
window_size_left,
window_size_right,
+ softcap: None,
};
q.apply_op3(k, v, op)
}
@@ -802,6 +857,7 @@ pub fn flash_attn_varlen_alibi(
alibi_slopes: Some(alibi_slopes.clone()),
window_size_left,
window_size_right,
+ softcap: None,
};
q.apply_op3(k, v, op)
}
@@ -857,6 +913,65 @@ pub fn flash_attn_varlen_alibi_windowed(
alibi_slopes: Some(alibi_slopes.clone()),
window_size_left,
window_size_right,
+ softcap: None,
+ };
+ q.apply_op3(k, v, op)
+}
+
+#[allow(clippy::too_many_arguments)]
+/// Flash-attention v2 layer with variable-length batching.
+///
+/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
+/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
+/// than q, the number of heads in k and v has to be divisible by the number of heads in q.
+///
+/// # Arguments
+///
+/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`.
+/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`.
+/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`.
+/// * `alibi_slopes` - Option, alibi slopes tensor with shape `(num_heads_q)`.
+/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q.
+/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v.
+/// * `max_seqlen_q` - The maximum query sequence length for q in the batch.
+/// * `max_seqlen_k` - The maximum query sequence length for k and v in the batch.
+/// * `window_size_left` - Option, limit left attention to value tokens.
+/// * `window_size_right` - Option, limit right attention to value tokens.
+/// * `softcap` - Gemma style softcap the attention logits before the softmax.
+///
+/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`,
+/// `seqlen_1 + seqlen_2`, etc.
+///
+/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`.
+///
+/// # Causal mask
+///
+/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result
+/// of `Q @ K^T`
+pub fn flash_attn_varlen_alibi_windowed_softcap(
+ q: &Tensor,
+ k: &Tensor,
+ v: &Tensor,
+ alibi_slopes: Option<&Tensor>,
+ seqlens_q: &Tensor,
+ seqlens_k: &Tensor,
+ max_seqlen_q: usize,
+ max_seqlen_k: usize,
+ softmax_scale: f32,
+ window_size_left: Option<usize>,
+ window_size_right: Option<usize>,
+ softcap: f32,
+) -> Result<Tensor> {
+ let op = FlashAttnVarLen {
+ softmax_scale,
+ max_seqlen_q,
+ max_seqlen_k,
+ seqlens_q: seqlens_q.clone(),
+ seqlens_k: seqlens_k.clone(),
+ alibi_slopes: alibi_slopes.cloned(),
+ window_size_left,
+ window_size_right,
+ softcap: Some(softcap),
};
q.apply_op3(k, v, op)
}