author     Laurent Mazare <laurent.mazare@gmail.com>  2023-07-26 10:13:40 +0100
committer  GitHub <noreply@github.com>                2023-07-26 10:13:40 +0100
commit     fa2b64d678ca83e2fbc3dabdecffbc778d5b067d (patch)
tree       da0643095bb790867d08dbd81ebdbe56ba681364 /candle-examples/examples/llama
parent     e40b150bbee980601f0a37ba4646216ee48bfbfb (diff)
Proper flash-attn parameters. (#244)
* Proper flash-attn parameters.
* Set the flash attention parameters.
* Add more validations.
* Setup the o_ flash attn parameters.
* More flash-attn support.
* Set more flash attn parameters.
Diffstat (limited to 'candle-examples/examples/llama')
-rw-r--r--  candle-examples/examples/llama/model.rs | 16
1 file changed, 12 insertions(+), 4 deletions(-)
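
For reference, a minimal sketch of the scale computation this patch wires through to the flash-attn kernel (a hedged illustration rather than verbatim repository code; head_dim follows the llama example's field name, and attention_scale is a hypothetical helper):

// Scaled dot-product attention divides the q.k^T logits by sqrt(head_dim),
// so the scale handed to the flash-attn kernel is 1/sqrt(head_dim); the
// kernel is additionally configured with causal: true (see the diff below).
fn attention_scale(head_dim: usize) -> f32 {
    1f32 / (head_dim as f32).sqrt()
}

fn main() {
    // For a 7B llama configuration, head_dim = 4096 / 32 = 128.
    assert!((attention_scale(128) - 0.088_388_35).abs() < 1e-6);
}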
diff --git a/candle-examples/examples/llama/model.rs b/candle-examples/examples/llama/model.rs
index f2f2fe28..0e850b6a 100644
--- a/candle-examples/examples/llama/model.rs
+++ b/candle-examples/examples/llama/model.rs
@@ -146,12 +146,19 @@ struct CausalSelfAttention {
 }
 
 #[cfg(feature = "flash-attn")]
-fn flash_attn(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
-    q.custom_op3(k, v, candle_flash_attn::FlashHdim32Sm80)
+fn flash_attn(softmax_scale: f32, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
+    q.custom_op3(
+        k,
+        v,
+        candle_flash_attn::FlashHdim32Sm80 {
+            softmax_scale,
+            causal: true,
+        },
+    )
 }
 
 #[cfg(not(feature = "flash-attn"))]
-fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor) -> Result<Tensor> {
+fn flash_attn(_: f32, _: &Tensor, _: &Tensor, _: &Tensor) -> Result<Tensor> {
     unimplemented!("compile with '--features flash-attn'")
 }
@@ -213,7 +220,8 @@ impl CausalSelfAttention {
         let v = self.repeat_kv(v)?;
 
         let y = if self.use_flash_attn {
-            flash_attn(&q, &k, &v)?
+            let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
+            flash_attn(softmax_scale, &q, &k, &v)?
         } else {
             let in_dtype = q.dtype();
             let q = q.to_dtype(DType::F32)?;