diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-26 10:13:40 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-26 10:13:40 +0100 |
commit | fa2b64d678ca83e2fbc3dabdecffbc778d5b067d (patch) | |
tree | da0643095bb790867d08dbd81ebdbe56ba681364 /candle-examples/examples/llama | |
parent | e40b150bbee980601f0a37ba4646216ee48bfbfb (diff) | |
download | candle-fa2b64d678ca83e2fbc3dabdecffbc778d5b067d.tar.gz candle-fa2b64d678ca83e2fbc3dabdecffbc778d5b067d.tar.bz2 candle-fa2b64d678ca83e2fbc3dabdecffbc778d5b067d.zip |
Proper flash-attn parameters. (#244)
* Proper flash-attn parameters.
* Set the flash attention parameters.
* Add more validations.
* Setup the o_ flash attn parameters.
* More flash-attn support.
* Set more flash attn parameters.
Diffstat (limited to 'candle-examples/examples/llama')
-rw-r--r-- | candle-examples/examples/llama/model.rs | 16 |
1 files changed, 12 insertions, 4 deletions
```diff
diff --git a/candle-examples/examples/llama/model.rs b/candle-examples/examples/llama/model.rs
index f2f2fe28..0e850b6a 100644
--- a/candle-examples/examples/llama/model.rs
+++ b/candle-examples/examples/llama/model.rs
@@ -146,12 +146,19 @@ struct CausalSelfAttention {
 }
 
 #[cfg(feature = "flash-attn")]
-fn flash_attn(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
-    q.custom_op3(k, v, candle_flash_attn::FlashHdim32Sm80)
+fn flash_attn(softmax_scale: f32, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
+    q.custom_op3(
+        k,
+        v,
+        candle_flash_attn::FlashHdim32Sm80 {
+            softmax_scale,
+            causal: true,
+        },
+    )
 }
 
 #[cfg(not(feature = "flash-attn"))]
-fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor) -> Result<Tensor> {
+fn flash_attn(_: f32, _: &Tensor, _: &Tensor, _: &Tensor) -> Result<Tensor> {
     unimplemented!("compile with '--features flash-attn'")
 }
 
@@ -213,7 +220,8 @@ impl CausalSelfAttention {
         let v = self.repeat_kv(v)?;
 
         let y = if self.use_flash_attn {
-            flash_attn(&q, &k, &v)?
+            let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
+            flash_attn(softmax_scale, &q, &k, &v)?
         } else {
             let in_dtype = q.dtype();
             let q = q.to_dtype(DType::F32)?;
```