summary refs log tree commit diff
path: root/candle-transformers/src/models/whisper
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-10-27 17:59:19 +0200
committerGitHub <noreply@github.com>2023-10-27 16:59:19 +0100
commit85bea43e5b088b94612b0fd7ed8f09261dc79d52 (patch)
tree7e1575f5905bf3c413c8c39de1d45e04ba050584 /candle-transformers/src/models/whisper
parentb3181455d5bbebdcc358a48fd4d1e5ed38d78198 (diff)
downloadcandle-85bea43e5b088b94612b0fd7ed8f09261dc79d52.tar.gz
candle-85bea43e5b088b94612b0fd7ed8f09261dc79d52.tar.bz2
candle-85bea43e5b088b94612b0fd7ed8f09261dc79d52.zip
Make the whisper model cloneable (#1200)
* Add a quantized variant of llama2.c
* Clippy fixes.
* Make the whisper model cloneable.
Diffstat (limited to 'candle-transformers/src/models/whisper')
-rw-r--r--candle-transformers/src/models/whisper/model.rs7
-rw-r--r--candle-transformers/src/models/whisper/quantized_model.rs5
2 files changed, 11 insertions, 1 deletions
diff --git a/candle-transformers/src/models/whisper/model.rs b/candle-transformers/src/models/whisper/model.rs
index 2a58afaf..6078944c 100644
--- a/candle-transformers/src/models/whisper/model.rs
+++ b/candle-transformers/src/models/whisper/model.rs
@@ -9,7 +9,7 @@ fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Em
//
// We wrap the `Linear` layer here to add some tracing so that it's easier to profile the resulting
// model.
-#[derive(Debug)]
+#[derive(Debug, Clone)]
pub struct Linear {
inner: candle_nn::Linear,
span: tracing::Span,
@@ -53,6 +53,7 @@ fn layer_norm(size: usize, vb: VarBuilder) -> Result<LayerNorm> {
}
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L62
+#[derive(Debug, Clone)]
struct MultiHeadAttention {
query: Linear,
key: Linear,
@@ -162,6 +163,7 @@ impl MultiHeadAttention {
}
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L111
+#[derive(Debug, Clone)]
struct ResidualAttentionBlock {
attn: MultiHeadAttention,
attn_ln: LayerNorm,
@@ -241,6 +243,7 @@ fn sinusoids(length: usize, channels: usize) -> Result<Tensor> {
}
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L143
+#[derive(Debug, Clone)]
pub struct AudioEncoder {
conv1: Conv1d,
conv2: Conv1d,
@@ -316,6 +319,7 @@ impl AudioEncoder {
}
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L176
+#[derive(Debug, Clone)]
pub struct TextDecoder {
token_embedding: Embedding,
positional_embedding: Tensor,
@@ -380,6 +384,7 @@ impl TextDecoder {
}
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L221
+#[derive(Debug, Clone)]
pub struct Whisper {
pub encoder: AudioEncoder,
pub decoder: TextDecoder,
diff --git a/candle-transformers/src/models/whisper/quantized_model.rs b/candle-transformers/src/models/whisper/quantized_model.rs
index f0aead49..43ea4177 100644
--- a/candle-transformers/src/models/whisper/quantized_model.rs
+++ b/candle-transformers/src/models/whisper/quantized_model.rs
@@ -19,6 +19,7 @@ fn conv1d(
}
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L62
+#[derive(Debug, Clone)]
struct MultiHeadAttention {
query: Linear,
key: Linear,
@@ -128,6 +129,7 @@ impl MultiHeadAttention {
}
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L111
+#[derive(Debug, Clone)]
struct ResidualAttentionBlock {
attn: MultiHeadAttention,
attn_ln: LayerNorm,
@@ -206,6 +208,7 @@ fn sinusoids(length: usize, channels: usize) -> Result<Tensor> {
}
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L143
+#[derive(Debug, Clone)]
pub struct AudioEncoder {
conv1: Conv1d,
conv2: Conv1d,
@@ -281,6 +284,7 @@ impl AudioEncoder {
}
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L176
+#[derive(Debug, Clone)]
pub struct TextDecoder {
token_embedding: Embedding,
positional_embedding: Tensor,
@@ -347,6 +351,7 @@ impl TextDecoder {
}
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L221
+#[derive(Debug, Clone)]
pub struct Whisper {
pub encoder: AudioEncoder,
pub decoder: TextDecoder,