author     Laurent Mazare <laurent.mazare@gmail.com>  2023-07-28 13:13:01 +0100
committer  GitHub <noreply@github.com>                2023-07-28 13:13:01 +0100
commit     3eb2bc6d07f192a5ce73ab6964745275f2c15213 (patch)
tree       e5a682d0e40f3c258f668652082ff7fa45918e32 /candle-examples/examples/bigcode/model.rs
parent     68eab38de6e5cabf17159a5dcf45ec703fbea441 (diff)
Softmax numerical stability. (#267)
* Softmax numerical stability.
* Fix the flash-attn test.
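
The trick behind the fix is the standard max-subtraction: shifting the logits by their row maximum before exponentiating leaves the result unchanged mathematically, but keeps every `exp` argument at or below zero, so nothing overflows to infinity. A minimal standalone sketch of that idea over a plain `f32` slice (illustrative only, not candle's actual implementation):

```rust
/// Numerically stable softmax over a slice: subtracting the max keeps
/// every exp() argument <= 0, so no term overflows to infinity.
fn stable_softmax(xs: &[f32]) -> Vec<f32> {
    let max = xs.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = xs.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

fn main() {
    // With logits this large, a naive exp(x) overflows f32 (exp(1000) = inf)
    // and the division yields NaN; the shifted version stays finite.
    let logits = [1000.0_f32, 999.0, 998.0];
    println!("{:?}", stable_softmax(&logits));
}
```

Running this prints finite probabilities summing to 1, where the unshifted form would produce `NaN`.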
Diffstat (limited to 'candle-examples/examples/bigcode/model.rs')
 candle-examples/examples/bigcode/model.rs | 12 +-----------
 1 file changed, 1 insertion(+), 11 deletions(-)
```diff
diff --git a/candle-examples/examples/bigcode/model.rs b/candle-examples/examples/bigcode/model.rs
index 3f68a5be..12993e2d 100644
--- a/candle-examples/examples/bigcode/model.rs
+++ b/candle-examples/examples/bigcode/model.rs
@@ -30,16 +30,6 @@ fn make_causal_mask(t: usize, device: &Device) -> Result<Tensor> {
     Ok(mask)
 }
 
-// TODO: Use a numerically stable implementation by default.
-fn softmax<D: candle::shape::Dim>(xs: &Tensor, d: D) -> Result<Tensor> {
-    let d = d.to_index(xs.shape(), "log-softmax")?;
-    let max = xs.max_keepdim(d)?;
-    let diff = xs.broadcast_sub(&max)?;
-    let num = diff.exp()?;
-    let den = num.sum_keepdim(d)?;
-    num.broadcast_div(&den)
-}
-
 #[derive(Debug)]
 pub struct Config {
     pub vocab_size: usize,
@@ -192,7 +182,7 @@ impl Attention {
         let mask_value =
             Tensor::new(f32::NEG_INFINITY, query.device())?.broadcast_as(attn_shape)?;
         let attn_weights = attention_mask.where_cond(&attn_weights, &mask_value)?;
-        let attn_weights = softmax(&attn_weights, D::Minus1)?;
+        let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
         let value = value.contiguous()?;
         let attn_output = if self.multi_query {
             attn_weights
```
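
After this change, the example delegates to the shared op, which applies the same stabilization internally. A hedged usage sketch, assuming the `candle` and `candle-nn` crates as of this commit and a CPU device:

```rust
use candle::{D, Device, Result, Tensor};

fn main() -> Result<()> {
    // Logits large enough that an unshifted exp() would overflow f32.
    let logits = Tensor::new(&[[1000.0f32, 999.0, 998.0]], &Device::Cpu)?;
    // candle_nn::ops::softmax subtracts the per-row max before
    // exponentiating, so the probabilities stay finite and sum to 1.
    let probs = candle_nn::ops::softmax(&logits, D::Minus1)?;
    println!("{:?}", probs.to_vec2::<f32>()?);
    Ok(())
}
```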