summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-examples/examples/whisper/main.rs5
-rw-r--r--candle-examples/examples/whisper/multilingual.rs7
2 files changed, 3 insertions, 9 deletions
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index fc64d458..5dd8ee20 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -146,11 +146,8 @@ impl Decoder {
tokens.push(language_token);
}
match self.task {
- Some(Task::Transcribe) => tokens.push(self.transcribe_token),
+ None | Some(Task::Transcribe) => tokens.push(self.transcribe_token),
Some(Task::Translate) => tokens.push(self.translate_token),
- None => {
- // Nothing in this case, same as the Python implementation.
- }
}
if !self.timestamps {
tokens.push(self.no_timestamps_token);
diff --git a/candle-examples/examples/whisper/multilingual.rs b/candle-examples/examples/whisper/multilingual.rs
index 3587b01a..bc0bae1f 100644
--- a/candle-examples/examples/whisper/multilingual.rs
+++ b/candle-examples/examples/whisper/multilingual.rs
@@ -117,11 +117,8 @@ pub fn detect_language(model: &mut Whisper, tokenizer: &Tokenizer, mel: &Tensor)
let audio_features = model.encoder.forward(&mel, true)?;
let tokens = Tensor::new(&[[sot_token]], device)?;
let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?;
- let logits = model
- .decoder
- .forward(&tokens, &audio_features, true)?
- .i(0)?
- .i(0)?;
+ let ys = model.decoder.forward(&tokens, &audio_features, true)?;
+ let logits = model.decoder.final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
let logits = logits.index_select(&language_token_ids, 0)?;
let probs = candle_nn::ops::softmax(&logits, D::Minus1)?;
let probs = probs.to_vec1::<f32>()?;