diff options
Diffstat (limited to 'candle-examples/examples/llama')
-rw-r--r-- | candle-examples/examples/llama/model.rs | 16 |
1 files changed, 12 insertions, 4 deletions
diff --git a/candle-examples/examples/llama/model.rs b/candle-examples/examples/llama/model.rs index f2f2fe28..0e850b6a 100644 --- a/candle-examples/examples/llama/model.rs +++ b/candle-examples/examples/llama/model.rs @@ -146,12 +146,19 @@ struct CausalSelfAttention { } #[cfg(feature = "flash-attn")] -fn flash_attn(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> { - q.custom_op3(k, v, candle_flash_attn::FlashHdim32Sm80) +fn flash_attn(softmax_scale: f32, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> { + q.custom_op3( + k, + v, + candle_flash_attn::FlashHdim32Sm80 { + softmax_scale, + causal: true, + }, + ) } #[cfg(not(feature = "flash-attn"))] -fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor) -> Result<Tensor> { +fn flash_attn(_: f32, _: &Tensor, _: &Tensor, _: &Tensor) -> Result<Tensor> { unimplemented!("compile with '--features flash-attn'") } @@ -213,7 +220,8 @@ impl CausalSelfAttention { let v = self.repeat_kv(v)?; let y = if self.use_flash_attn { - flash_attn(&q, &k, &v)? + let softmax_scale = 1f32 / (self.head_dim as f32).sqrt(); + flash_attn(softmax_scale, &q, &k, &v)? } else { let in_dtype = q.dtype(); let q = q.to_dtype(DType::F32)?; |