diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-01 21:04:07 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-01 20:04:07 +0100 |
commit | 19042962d5ae3ab17866522a0d2d99e873624441 (patch) | |
tree | 4d3a9f71d1769e326293c0f77354fddb11e8f8df | |
parent | 731e3ffb03fb1f1712202d4c790a88e7c8d9ecb3 (diff) | |
download | candle-19042962d5ae3ab17866522a0d2d99e873624441.tar.gz candle-19042962d5ae3ab17866522a0d2d99e873624441.tar.bz2 candle-19042962d5ae3ab17866522a0d2d99e873624441.zip |
Whisper fix (#711)
* Remove unnecessary file.
* Whisper fix.
-rw-r--r-- | candle-examples/examples/whisper/main.rs | 5 |
-rw-r--r-- | candle-examples/examples/whisper/multilingual.rs | 7 |
2 files changed, 3 insertions, 9 deletions
```diff
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index fc64d458..5dd8ee20 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -146,11 +146,8 @@ impl Decoder {
             tokens.push(language_token);
         }
         match self.task {
-            Some(Task::Transcribe) => tokens.push(self.transcribe_token),
+            None | Some(Task::Transcribe) => tokens.push(self.transcribe_token),
             Some(Task::Translate) => tokens.push(self.translate_token),
-            None => {
-                // Nothing in this case, same as the Python implementation.
-            }
         }
         if !self.timestamps {
             tokens.push(self.no_timestamps_token);
diff --git a/candle-examples/examples/whisper/multilingual.rs b/candle-examples/examples/whisper/multilingual.rs
index 3587b01a..bc0bae1f 100644
--- a/candle-examples/examples/whisper/multilingual.rs
+++ b/candle-examples/examples/whisper/multilingual.rs
@@ -117,11 +117,8 @@ pub fn detect_language(model: &mut Whisper, tokenizer: &Tokenizer, mel: &Tensor)
     let audio_features = model.encoder.forward(&mel, true)?;
     let tokens = Tensor::new(&[[sot_token]], device)?;
     let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?;
-    let logits = model
-        .decoder
-        .forward(&tokens, &audio_features, true)?
-        .i(0)?
-        .i(0)?;
+    let ys = model.decoder.forward(&tokens, &audio_features, true)?;
+    let logits = model.decoder.final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
     let logits = logits.index_select(&language_token_ids, 0)?;
     let probs = candle_nn::ops::softmax(&logits, D::Minus1)?;
     let probs = probs.to_vec1::<f32>()?;
```