author     Laurent Mazare <laurent.mazare@gmail.com>  2023-07-28 13:13:01 +0100
committer  GitHub <noreply@github.com>                2023-07-28 13:13:01 +0100
commit     3eb2bc6d07f192a5ce73ab6964745275f2c15213 (patch)
tree       e5a682d0e40f3c258f668652082ff7fa45918e32 /candle-examples/examples/musicgen
parent     68eab38de6e5cabf17159a5dcf45ec703fbea441 (diff)
Softmax numerical stability. (#267)
* Softmax numerical stability.
* Fix the flash-attn test.
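Context for the change: the textbook softmax, softmax(x)_i = exp(x_i) / sum_j exp(x_j), overflows in f32 once any logit exceeds roughly 88 (ln(f32::MAX) ~ 88.7), turning the whole attention row into NaN. Subtracting the row maximum before exponentiating is mathematically a no-op, since the common factor exp(-max) cancels in the quotient, but it bounds every argument to exp() at 0. A minimal scalar sketch of the trick, independent of candle's tensor API:

// Minimal sketch of the max-subtraction trick on a plain f32 slice.
// exp(x - max) == exp(x) / exp(max), so the output is unchanged, but
// every argument passed to exp() is <= 0 and cannot overflow.
fn stable_softmax(xs: &[f32]) -> Vec<f32> {
    let max = xs.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = xs.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.into_iter().map(|e| e / sum).collect()
}

For xs = [1000.0, 1000.0] the naive formulation yields [NaN, NaN] (inf / inf), while this version returns the correct [0.5, 0.5].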
Diffstat (limited to 'candle-examples/examples/musicgen')
-rw-r--r--  candle-examples/examples/musicgen/musicgen_model.rs | 2 +-
-rw-r--r--  candle-examples/examples/musicgen/t5_model.rs       | 2 +-

2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/candle-examples/examples/musicgen/musicgen_model.rs b/candle-examples/examples/musicgen/musicgen_model.rs
index 212f6818..01266e63 100644
--- a/candle-examples/examples/musicgen/musicgen_model.rs
+++ b/candle-examples/examples/musicgen/musicgen_model.rs
@@ -187,7 +187,7 @@ impl MusicgenAttention {
         let attn_weights = attn_weights
             .reshape((b_sz, self.num_heads, tgt_len, src_len))?
             .broadcast_add(attention_mask)?;
-        let attn_weights = attn_weights.softmax(D::Minus1)?;
+        let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
         // TODO: layer_head_mask?
         let attn_output = attn_weights
             .matmul(&value_states)?
diff --git a/candle-examples/examples/musicgen/t5_model.rs b/candle-examples/examples/musicgen/t5_model.rs
index 61c0a1bb..ef65df39 100644
--- a/candle-examples/examples/musicgen/t5_model.rs
+++ b/candle-examples/examples/musicgen/t5_model.rs
@@ -223,7 +223,7 @@ impl T5Attention {
             .transpose(1, 2)?;
         let scores = q.matmul(&k.t()?)?;
         // TODO: position_bias_masked
-        let attn_weights = scores.softmax(D::Minus1)?;
+        let attn_weights = candle_nn::ops::softmax(&scores, D::Minus1)?;
         let attn_output = attn_weights.matmul(&v)?;
         let attn_output = self.o.forward(&attn_output)?;
         Ok(attn_output)
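Both hunks route attention scores through the new candle_nn::ops::softmax helper instead of the raw tensor method, so every call site picks up the max-subtraction before exp(). A sketch of how such a helper can be written against candle's public tensor ops (max_keepdim, broadcast_sub, exp, sum_keepdim, broadcast_div); the in-tree implementation may differ in details:

use candle::{Result, Tensor, D};

// Numerically stable softmax along the last dimension. The row-wise
// max is kept as a size-1 dimension so it broadcasts against xs.
fn stable_softmax_last_dim(xs: &Tensor) -> Result<Tensor> {
    let max = xs.max_keepdim(D::Minus1)?;   // per-row maximum
    let shifted = xs.broadcast_sub(&max)?;  // every entry now <= 0
    let num = shifted.exp()?;               // in (0, 1], no overflow
    let den = num.sum_keepdim(D::Minus1)?;  // per-row normalizer
    num.broadcast_div(&den)
}

The stability matters most on reduced-precision paths (such as the flash-attn test fixed alongside this change), where exp() saturates far earlier than in f32.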