Diffstat (limited to 'candle-examples/examples/llama')
-rw-r--r--  candle-examples/examples/llama/model.rs | 16 ++++++++++++----
1 file changed, 12 insertions(+), 4 deletions(-)
diff --git a/candle-examples/examples/llama/model.rs b/candle-examples/examples/llama/model.rs
index f2f2fe28..0e850b6a 100644
--- a/candle-examples/examples/llama/model.rs
+++ b/candle-examples/examples/llama/model.rs
@@ -146,12 +146,19 @@ struct CausalSelfAttention {
}
#[cfg(feature = "flash-attn")]
-fn flash_attn(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
- q.custom_op3(k, v, candle_flash_attn::FlashHdim32Sm80)
+fn flash_attn(softmax_scale: f32, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
+ q.custom_op3(
+ k,
+ v,
+ candle_flash_attn::FlashHdim32Sm80 {
+ softmax_scale,
+ causal: true,
+ },
+ )
}
#[cfg(not(feature = "flash-attn"))]
-fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor) -> Result<Tensor> {
+fn flash_attn(_: f32, _: &Tensor, _: &Tensor, _: &Tensor) -> Result<Tensor> {
unimplemented!("compile with '--features flash-attn'")
}
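
The hunk above threads the attention softmax scale through the flash_attn wrapper and marks the kernel as causal. As an illustrative sketch only (not part of the patch): the scale is the usual 1/sqrt(head_dim) factor of scaled-dot-product attention. The naive_attention helper below is hypothetical, assumes the candle and candle_nn crates as used elsewhere in candle-examples, and omits the causal mask that the flash path applies via causal: true.

use candle::{Result, Tensor, D};

/// Hypothetical reference helper (not in the patch): plain scaled-dot-product
/// attention, showing what `softmax_scale` stands for. Expects q, k, v shaped
/// (batch, heads, seq, head_dim). The causal mask is intentionally left out.
fn naive_attention(q: &Tensor, k: &Tensor, v: &Tensor, softmax_scale: f32) -> Result<Tensor> {
    // Attention scores, scaled by 1 / sqrt(head_dim).
    let att = q.matmul(&k.t()?)?.affine(softmax_scale as f64, 0.)?;
    let att = candle_nn::ops::softmax(&att, D::Minus1)?;
    att.matmul(v)
}

// A call site would mirror the second hunk of this diff:
// let softmax_scale = 1f32 / (head_dim as f32).sqrt();
// let y = naive_attention(&q, &k, &v, softmax_scale)?;
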
@@ -213,7 +220,8 @@ impl CausalSelfAttention {
let v = self.repeat_kv(v)?;
let y = if self.use_flash_attn {
- flash_attn(&q, &k, &v)?
+ let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
+ flash_attn(softmax_scale, &q, &k, &v)?
} else {
let in_dtype = q.dtype();
let q = q.to_dtype(DType::F32)?;