author     Laurent Mazare <laurent.mazare@gmail.com>    2023-10-02 14:59:53 +0100
committer  GitHub <noreply@github.com>                  2023-10-02 14:59:53 +0100
commit     e04c789230c609c285991b78c29f1d6eef0d104f (patch)
tree       718a61d3838c7ac82b56cb5a202ee4b172465aa4 /candle-transformers/src/models/whisper/model.rs
parent     263a1722021cdf24c801422c58887d93ad2e382a (diff)
Add a quantized variant of whisper (#1017)
* Add the quantized-whisper model.
* Quantize the whisper model.
* Adapt the whisper example to handle quantization.
* Add the quantized flag.
* Load the proper weights.
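
The bullets above describe the example-side wiring rather than the model code itself. Below is a minimal, illustrative sketch of that flag-based dispatch, using stand-in `Whisper` and `QuantizedWhisper` types; it is not the actual candle example code, and the loading logic is elided.

```rust
// Sketch only: `Whisper` and `QuantizedWhisper` are placeholder types, not
// the real candle-transformers models; weight loading is omitted.
struct Whisper;          // stand-in for the f32 model
struct QuantizedWhisper; // stand-in for the quantized model

// Wrapping both variants in one enum lets the rest of the example
// (tokenization, decoding loop, ...) stay agnostic about quantization.
enum Model {
    Normal(Whisper),
    Quantized(QuantizedWhisper),
}

impl Model {
    // Select a variant based on a `--quantized`-style CLI flag.
    fn load(quantized: bool) -> Model {
        if quantized {
            Model::Quantized(QuantizedWhisper)
        } else {
            Model::Normal(Whisper)
        }
    }
}

fn main() {
    let model = Model::load(std::env::args().any(|a| a == "--quantized"));
    match model {
        Model::Normal(_) => println!("loaded the f32 whisper weights"),
        Model::Quantized(_) => println!("loaded the quantized whisper weights"),
    }
}
```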
Diffstat (limited to 'candle-transformers/src/models/whisper/model.rs')
-rw-r--r--   candle-transformers/src/models/whisper/model.rs   19
1 file changed, 1 insertion(+), 18 deletions(-)
diff --git a/candle-transformers/src/models/whisper/model.rs b/candle-transformers/src/models/whisper/model.rs
index d2eda796..2a58afaf 100644
--- a/candle-transformers/src/models/whisper/model.rs
+++ b/candle-transformers/src/models/whisper/model.rs
@@ -1,23 +1,6 @@
+use super::Config;
 use candle::{Device, IndexOp, Result, Tensor, D};
 use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
-use serde::Deserialize;
-
-// The names in comments correspond to the original implementation:
-// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L17
-#[derive(Debug, Clone, PartialEq, Deserialize)]
-pub struct Config {
-    pub num_mel_bins: usize,            // n_mels
-    pub max_source_positions: usize,    // n_audio_ctx
-    pub d_model: usize,                 // n_audio_state
-    pub encoder_attention_heads: usize, // n_audio_head
-    pub encoder_layers: usize,          // n_audio_layer
-    pub vocab_size: usize,              // n_vocab
-    pub max_target_positions: usize,    // n_text_ctx
-    // pub n_text_state: usize,
-    pub decoder_attention_heads: usize, // n_text_head
-    pub decoder_layers: usize,          // n_text_layer
-    pub suppress_tokens: Vec<u32>,
-}
 
 fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
     let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
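
For context, the `Config` struct removed above is not deleted from the crate: the added `use super::Config;` shows it being imported from the parent `whisper` module instead, so the regular and quantized models can share one configuration type. A sketch of that shared definition follows, assuming it now lives in something like `candle-transformers/src/models/whisper/mod.rs`; the new location is not part of this diff.

```rust
// Assumed location: candle-transformers/src/models/whisper/mod.rs.
// The struct body is copied from the lines removed from model.rs in this
// diff; only its new home is an assumption. Both the f32 model and the new
// quantized variant can then reach it via `use super::Config;`.
use serde::Deserialize;

// The names in comments correspond to the original implementation:
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L17
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
    pub num_mel_bins: usize,            // n_mels
    pub max_source_positions: usize,    // n_audio_ctx
    pub d_model: usize,                 // n_audio_state
    pub encoder_attention_heads: usize, // n_audio_head
    pub encoder_layers: usize,          // n_audio_layer
    pub vocab_size: usize,              // n_vocab
    pub max_target_positions: usize,    // n_text_ctx
    // pub n_text_state: usize,
    pub decoder_attention_heads: usize, // n_text_head
    pub decoder_layers: usize,          // n_text_layer
    pub suppress_tokens: Vec<u32>,
}
```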