summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-09-01 21:04:07 +0200
committerGitHub <noreply@github.com>2023-09-01 20:04:07 +0100
commit19042962d5ae3ab17866522a0d2d99e873624441 (patch)
tree4d3a9f71d1769e326293c0f77354fddb11e8f8df
parent731e3ffb03fb1f1712202d4c790a88e7c8d9ecb3 (diff)
downloadcandle-19042962d5ae3ab17866522a0d2d99e873624441.tar.gz
candle-19042962d5ae3ab17866522a0d2d99e873624441.tar.bz2
candle-19042962d5ae3ab17866522a0d2d99e873624441.zip
Whisper fix (#711)
* Remove unnecessary file. * Whisper fix.
-rw-r--r--candle-examples/examples/whisper/main.rs5
-rw-r--r--candle-examples/examples/whisper/multilingual.rs7
2 files changed, 3 insertions, 9 deletions
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index fc64d458..5dd8ee20 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -146,11 +146,8 @@ impl Decoder {
tokens.push(language_token);
}
match self.task {
- Some(Task::Transcribe) => tokens.push(self.transcribe_token),
+ None | Some(Task::Transcribe) => tokens.push(self.transcribe_token),
Some(Task::Translate) => tokens.push(self.translate_token),
- None => {
- // Nothing in this case, same as the Python implementation.
- }
}
if !self.timestamps {
tokens.push(self.no_timestamps_token);
diff --git a/candle-examples/examples/whisper/multilingual.rs b/candle-examples/examples/whisper/multilingual.rs
index 3587b01a..bc0bae1f 100644
--- a/candle-examples/examples/whisper/multilingual.rs
+++ b/candle-examples/examples/whisper/multilingual.rs
@@ -117,11 +117,8 @@ pub fn detect_language(model: &mut Whisper, tokenizer: &Tokenizer, mel: &Tensor)
let audio_features = model.encoder.forward(&mel, true)?;
let tokens = Tensor::new(&[[sot_token]], device)?;
let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?;
- let logits = model
- .decoder
- .forward(&tokens, &audio_features, true)?
- .i(0)?
- .i(0)?;
+ let ys = model.decoder.forward(&tokens, &audio_features, true)?;
+ let logits = model.decoder.final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
let logits = logits.index_select(&language_token_ids, 0)?;
let probs = candle_nn::ops::softmax(&logits, D::Minus1)?;
let probs = probs.to_vec1::<f32>()?;