summaryrefslogtreecommitdiff
path: root/candle-examples/examples/bigcode/model.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-07-28 13:13:01 +0100
committerGitHub <noreply@github.com>2023-07-28 13:13:01 +0100
commit3eb2bc6d07f192a5ce73ab6964745275f2c15213 (patch)
treee5a682d0e40f3c258f668652082ff7fa45918e32 /candle-examples/examples/bigcode/model.rs
parent68eab38de6e5cabf17159a5dcf45ec703fbea441 (diff)
downloadcandle-3eb2bc6d07f192a5ce73ab6964745275f2c15213.tar.gz
candle-3eb2bc6d07f192a5ce73ab6964745275f2c15213.tar.bz2
candle-3eb2bc6d07f192a5ce73ab6964745275f2c15213.zip
Softmax numerical stability. (#267)
* Softmax numerical stability. * Fix the flash-attn test.
Diffstat (limited to 'candle-examples/examples/bigcode/model.rs')
-rw-r--r--candle-examples/examples/bigcode/model.rs12
1 files changed, 1 insertions, 11 deletions
diff --git a/candle-examples/examples/bigcode/model.rs b/candle-examples/examples/bigcode/model.rs
index 3f68a5be..12993e2d 100644
--- a/candle-examples/examples/bigcode/model.rs
+++ b/candle-examples/examples/bigcode/model.rs
@@ -30,16 +30,6 @@ fn make_causal_mask(t: usize, device: &Device) -> Result<Tensor> {
Ok(mask)
}
-// TODO: Use a numerically stable implementation by default.
-fn softmax<D: candle::shape::Dim>(xs: &Tensor, d: D) -> Result<Tensor> {
- let d = d.to_index(xs.shape(), "log-softmax")?;
- let max = xs.max_keepdim(d)?;
- let diff = xs.broadcast_sub(&max)?;
- let num = diff.exp()?;
- let den = num.sum_keepdim(d)?;
- num.broadcast_div(&den)
-}
-
#[derive(Debug)]
pub struct Config {
pub vocab_size: usize,
@@ -192,7 +182,7 @@ impl Attention {
let mask_value =
Tensor::new(f32::NEG_INFINITY, query.device())?.broadcast_as(attn_shape)?;
let attn_weights = attention_mask.where_cond(&attn_weights, &mask_value)?;
- let attn_weights = softmax(&attn_weights, D::Minus1)?;
+ let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
let value = value.contiguous()?;
let attn_output = if self.multi_query {
attn_weights