diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-10-27 17:59:19 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-27 16:59:19 +0100 |
commit | 85bea43e5b088b94612b0fd7ed8f09261dc79d52 (patch) | |
tree | 7e1575f5905bf3c413c8c39de1d45e04ba050584 /candle-transformers/src/models/whisper | |
parent | b3181455d5bbebdcc358a48fd4d1e5ed38d78198 (diff) | |
download | candle-85bea43e5b088b94612b0fd7ed8f09261dc79d52.tar.gz candle-85bea43e5b088b94612b0fd7ed8f09261dc79d52.tar.bz2 candle-85bea43e5b088b94612b0fd7ed8f09261dc79d52.zip |
Make the whisper model cloneable (#1200)
* Add a quantized variant of llama2.c
* Clippy fixes.
* Make the whisper model cloneable.
Diffstat (limited to 'candle-transformers/src/models/whisper')
-rw-r--r-- | candle-transformers/src/models/whisper/model.rs | 7 | ||||
-rw-r--r-- | candle-transformers/src/models/whisper/quantized_model.rs | 5 |
2 files changed, 11 insertions, 1 deletions
diff --git a/candle-transformers/src/models/whisper/model.rs b/candle-transformers/src/models/whisper/model.rs index 2a58afaf..6078944c 100644 --- a/candle-transformers/src/models/whisper/model.rs +++ b/candle-transformers/src/models/whisper/model.rs @@ -9,7 +9,7 @@ fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Em // // We wrap the `Linear` layer here to add some tracing so that it's easier to profile the resulting // model. -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Linear { inner: candle_nn::Linear, span: tracing::Span, @@ -53,6 +53,7 @@ fn layer_norm(size: usize, vb: VarBuilder) -> Result<LayerNorm> { } // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L62 +#[derive(Debug, Clone)] struct MultiHeadAttention { query: Linear, key: Linear, @@ -162,6 +163,7 @@ impl MultiHeadAttention { } // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L111 +#[derive(Debug, Clone)] struct ResidualAttentionBlock { attn: MultiHeadAttention, attn_ln: LayerNorm, @@ -241,6 +243,7 @@ fn sinusoids(length: usize, channels: usize) -> Result<Tensor> { } // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L143 +#[derive(Debug, Clone)] pub struct AudioEncoder { conv1: Conv1d, conv2: Conv1d, @@ -316,6 +319,7 @@ impl AudioEncoder { } // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L176 +#[derive(Debug, Clone)] pub struct TextDecoder { token_embedding: Embedding, positional_embedding: Tensor, @@ -380,6 +384,7 @@ impl TextDecoder { } // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L221 +#[derive(Debug, Clone)] pub struct Whisper { pub encoder: AudioEncoder, pub decoder: TextDecoder, diff --git a/candle-transformers/src/models/whisper/quantized_model.rs b/candle-transformers/src/models/whisper/quantized_model.rs index f0aead49..43ea4177 100644 --- a/candle-transformers/src/models/whisper/quantized_model.rs +++ b/candle-transformers/src/models/whisper/quantized_model.rs @@ -19,6 +19,7 @@ fn conv1d( } // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L62 +#[derive(Debug, Clone)] struct MultiHeadAttention { query: Linear, key: Linear, @@ -128,6 +129,7 @@ impl MultiHeadAttention { } // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L111 +#[derive(Debug, Clone)] struct ResidualAttentionBlock { attn: MultiHeadAttention, attn_ln: LayerNorm, @@ -206,6 +208,7 @@ fn sinusoids(length: usize, channels: usize) -> Result<Tensor> { } // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L143 +#[derive(Debug, Clone)] pub struct AudioEncoder { conv1: Conv1d, conv2: Conv1d, @@ -281,6 +284,7 @@ impl AudioEncoder { } // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L176 +#[derive(Debug, Clone)] pub struct TextDecoder { token_embedding: Embedding, positional_embedding: Tensor, @@ -347,6 +351,7 @@ impl TextDecoder { } // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L221 +#[derive(Debug, Clone)] pub struct Whisper { pub encoder: AudioEncoder, pub decoder: TextDecoder, |