diff options
Diffstat (limited to 'candle-examples/examples/whisper/model.rs')
-rw-r--r-- | candle-examples/examples/whisper/model.rs | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/candle-examples/examples/whisper/model.rs b/candle-examples/examples/whisper/model.rs index 330b2a00..4d80c0c8 100644 --- a/candle-examples/examples/whisper/model.rs +++ b/candle-examples/examples/whisper/model.rs @@ -2,7 +2,7 @@ // back when using RUST_LIB_BACKTRACE=1. use anyhow::Result; use candle::{Device, Tensor}; -use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Linear, VarBuilder}; +use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, Linear, VarBuilder}; use serde::Deserialize; // The names in comments correspond to the original implementation: @@ -154,7 +154,7 @@ impl MultiHeadAttention { let mask = mask.narrow(0, 0, n_ctx)?.narrow(1, 0, n_ctx)?; qk = qk.broadcast_add(&mask)? } - let w = qk.softmax(candle::D::Minus1)?; + let w = softmax(&qk, candle::D::Minus1)?; let wv = w.matmul(&v)?.transpose(1, 2)?.flatten_from(2)?; Ok(wv) } |