summaryrefslogtreecommitdiff
path: root/candle-examples/examples/musicgen
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-07-28 13:13:01 +0100
committerGitHub <noreply@github.com>2023-07-28 13:13:01 +0100
commit3eb2bc6d07f192a5ce73ab6964745275f2c15213 (patch)
treee5a682d0e40f3c258f668652082ff7fa45918e32 /candle-examples/examples/musicgen
parent68eab38de6e5cabf17159a5dcf45ec703fbea441 (diff)
downloadcandle-3eb2bc6d07f192a5ce73ab6964745275f2c15213.tar.gz
candle-3eb2bc6d07f192a5ce73ab6964745275f2c15213.tar.bz2
candle-3eb2bc6d07f192a5ce73ab6964745275f2c15213.zip
Softmax numerical stability. (#267)
* Softmax numerical stability. * Fix the flash-attn test.
Diffstat (limited to 'candle-examples/examples/musicgen')
-rw-r--r--candle-examples/examples/musicgen/musicgen_model.rs2
-rw-r--r--candle-examples/examples/musicgen/t5_model.rs2
2 files changed, 2 insertions, 2 deletions
diff --git a/candle-examples/examples/musicgen/musicgen_model.rs b/candle-examples/examples/musicgen/musicgen_model.rs
index 212f6818..01266e63 100644
--- a/candle-examples/examples/musicgen/musicgen_model.rs
+++ b/candle-examples/examples/musicgen/musicgen_model.rs
@@ -187,7 +187,7 @@ impl MusicgenAttention {
let attn_weights = attn_weights
.reshape((b_sz, self.num_heads, tgt_len, src_len))?
.broadcast_add(attention_mask)?;
- let attn_weights = attn_weights.softmax(D::Minus1)?;
+ let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
// TODO: layer_head_mask?
let attn_output = attn_weights
.matmul(&value_states)?
diff --git a/candle-examples/examples/musicgen/t5_model.rs b/candle-examples/examples/musicgen/t5_model.rs
index 61c0a1bb..ef65df39 100644
--- a/candle-examples/examples/musicgen/t5_model.rs
+++ b/candle-examples/examples/musicgen/t5_model.rs
@@ -223,7 +223,7 @@ impl T5Attention {
.transpose(1, 2)?;
let scores = q.matmul(&k.t()?)?;
// TODO: position_bias_masked
- let attn_weights = scores.softmax(D::Minus1)?;
+ let attn_weights = candle_nn::ops::softmax(&scores, D::Minus1)?;
let attn_output = attn_weights.matmul(&v)?;
let attn_output = self.o.forward(&attn_output)?;
Ok(attn_output)