summaryrefslogtreecommitdiff
path: root/candle-transformers/src/models/whisper/model.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-10-02 14:59:53 +0100
committerGitHub <noreply@github.com>2023-10-02 14:59:53 +0100
commite04c789230c609c285991b78c29f1d6eef0d104f (patch)
tree718a61d3838c7ac82b56cb5a202ee4b172465aa4 /candle-transformers/src/models/whisper/model.rs
parent263a1722021cdf24c801422c58887d93ad2e382a (diff)
downloadcandle-e04c789230c609c285991b78c29f1d6eef0d104f.tar.gz
candle-e04c789230c609c285991b78c29f1d6eef0d104f.tar.bz2
candle-e04c789230c609c285991b78c29f1d6eef0d104f.zip
Add a quantized variant of whisper (#1017)
* Add the quantized-whisper model. * Quantized the whisper model. * Adapt the whisper example to handle quantization. * Add the quantized flag. * Load the proper weights.
Diffstat (limited to 'candle-transformers/src/models/whisper/model.rs')
-rw-r--r-- candle-transformers/src/models/whisper/model.rs | 19
1 file changed, 1 insertion, 18 deletions
diff --git a/candle-transformers/src/models/whisper/model.rs b/candle-transformers/src/models/whisper/model.rs
index d2eda796..2a58afaf 100644
--- a/candle-transformers/src/models/whisper/model.rs
+++ b/candle-transformers/src/models/whisper/model.rs
@@ -1,23 +1,6 @@
+use super::Config;
use candle::{Device, IndexOp, Result, Tensor, D};
use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
-use serde::Deserialize;
-
-// The names in comments correspond to the original implementation:
-// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L17
-#[derive(Debug, Clone, PartialEq, Deserialize)]
-pub struct Config {
- pub num_mel_bins: usize, // n_mels
- pub max_source_positions: usize, // n_audio_ctx
- pub d_model: usize, // n_audio_state
- pub encoder_attention_heads: usize, // n_audio_head
- pub encoder_layers: usize, // n_audio_layer
- pub vocab_size: usize, // n_vocab
- pub max_target_positions: usize, // n_text_ctx
- // pub n_text_state: usize,
- pub decoder_attention_heads: usize, // n_text_head
- pub decoder_layers: usize, // n_text_layer
- pub suppress_tokens: Vec<u32>,
-}
fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;