summaryrefslogtreecommitdiff
path: root/candle-examples/examples/whisper/model.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-07-10 22:37:34 +0100
committerGitHub <noreply@github.com>2023-07-10 22:37:34 +0100
commitb46c28a2ac2a88387590c65a2efef028f010b29e (patch)
treec43a08c46a537c6041f55fbe6cbbb105299aee9e /candle-examples/examples/whisper/model.rs
parent1aa7fbbc33ecd0ad3ce8698220d9eae434db50ba (diff)
downloadcandle-b46c28a2ac2a88387590c65a2efef028f010b29e.tar.gz
candle-b46c28a2ac2a88387590c65a2efef028f010b29e.tar.bz2
candle-b46c28a2ac2a88387590c65a2efef028f010b29e.zip
VarBuilder path creation (#131)
* Use a struct for the safetensor+routing. * Group the path and the var-builder together. * Fix for the empty path case.
Diffstat (limited to 'candle-examples/examples/whisper/model.rs')
-rw-r--r--candle-examples/examples/whisper/model.rs101
1 files changed, 42 insertions, 59 deletions
diff --git a/candle-examples/examples/whisper/model.rs b/candle-examples/examples/whisper/model.rs
index d653d0c7..ece8b2d8 100644
--- a/candle-examples/examples/whisper/model.rs
+++ b/candle-examples/examples/whisper/model.rs
@@ -38,19 +38,19 @@ impl Config {
}
}
-fn embedding(vocab_size: usize, hidden_size: usize, p: &str, vb: &VarBuilder) -> Result<Embedding> {
- let embeddings = vb.get((vocab_size, hidden_size), &format!("{p}.weight"))?;
+fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
+ let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
Ok(Embedding::new(embeddings, hidden_size))
}
-fn linear(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Linear> {
- let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
- let bias = vb.get(size2, &format!("{p}.bias"))?;
+fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
+ let weight = vb.get((size2, size1), "weight")?;
+ let bias = vb.get(size2, "bias")?;
Ok(Linear::new(weight, Some(bias)))
}
-fn linear_no_bias(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Linear> {
- let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
+fn linear_no_bias(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
+ let weight = vb.get((size2, size1), "weight")?;
Ok(Linear::new(weight, None))
}
@@ -59,14 +59,10 @@ fn conv1d(
out_channels: usize,
kernel_size: usize,
config: Conv1dConfig,
- p: &str,
- vb: &VarBuilder,
+ vb: VarBuilder,
) -> Result<Conv1d> {
- let weight = vb.get(
- (out_channels, in_channels, kernel_size),
- &format!("{p}.weight"),
- )?;
- let bias = vb.get(out_channels, &format!("{p}.bias"))?;
+ let weight = vb.get((out_channels, in_channels, kernel_size), "weight")?;
+ let bias = vb.get(out_channels, "bias")?;
Ok(Conv1d::new(weight, Some(bias), config))
}
@@ -75,13 +71,9 @@ fn conv1d_no_bias(
out_channels: usize,
kernel_size: usize,
config: Conv1dConfig,
- p: &str,
- vb: &VarBuilder,
+ vb: VarBuilder,
) -> Result<Conv1d> {
- let weight = vb.get(
- (out_channels, in_channels, kernel_size),
- &format!("{p}.weight"),
- )?;
+ let weight = vb.get((out_channels, in_channels, kernel_size), "weight")?;
Ok(Conv1d::new(weight, None, config))
}
@@ -100,9 +92,9 @@ impl Dropout {
}
}
-fn layer_norm(size: usize, p: &str, vb: &VarBuilder) -> Result<LayerNorm> {
- let weight = vb.get(size, &format!("{p}.weight"))?;
- let bias = vb.get(size, &format!("{p}.bias"))?;
+fn layer_norm(size: usize, vb: VarBuilder) -> Result<LayerNorm> {
+ let weight = vb.get(size, "weight")?;
+ let bias = vb.get(size, "bias")?;
Ok(LayerNorm::new(weight, bias, 1e-5))
}
@@ -116,11 +108,11 @@ struct MultiHeadAttention {
}
impl MultiHeadAttention {
- fn load(n_state: usize, n_head: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
- let query = linear(n_state, n_state, &format!("{p}.q_proj"), vb)?;
- let value = linear(n_state, n_state, &format!("{p}.v_proj"), vb)?;
- let key = linear_no_bias(n_state, n_state, &format!("{p}.k_proj"), vb)?;
- let out = linear(n_state, n_state, &format!("{p}.out_proj"), vb)?;
+ fn load(n_state: usize, n_head: usize, vb: VarBuilder) -> Result<Self> {
+ let query = linear(n_state, n_state, vb.pp("q_proj"))?;
+ let value = linear(n_state, n_state, vb.pp("v_proj"))?;
+ let key = linear_no_bias(n_state, n_state, vb.pp("k_proj"))?;
+ let out = linear(n_state, n_state, vb.pp("out_proj"))?;
Ok(Self {
query,
key,
@@ -179,21 +171,20 @@ struct ResidualAttentionBlock {
}
impl ResidualAttentionBlock {
- fn load(n_state: usize, n_head: usize, ca: bool, p: &str, vb: &VarBuilder) -> Result<Self> {
- let attn = MultiHeadAttention::load(n_state, n_head, &format!("{p}.self_attn"), vb)?;
- let attn_ln = layer_norm(n_state, &format!("{p}.self_attn_layer_norm"), vb)?;
+ fn load(n_state: usize, n_head: usize, ca: bool, vb: VarBuilder) -> Result<Self> {
+ let attn = MultiHeadAttention::load(n_state, n_head, vb.pp("self_attn"))?;
+ let attn_ln = layer_norm(n_state, vb.pp("self_attn_layer_norm"))?;
let cross_attn = if ca {
- let cross_attn =
- MultiHeadAttention::load(n_state, n_head, &format!("{p}.encoder_attn"), vb)?;
- let cross_attn_ln = layer_norm(n_state, &format!("{p}.encoder_attn_layer_norm"), vb)?;
+ let cross_attn = MultiHeadAttention::load(n_state, n_head, vb.pp("encoder_attn"))?;
+ let cross_attn_ln = layer_norm(n_state, vb.pp("encoder_attn_layer_norm"))?;
Some((cross_attn, cross_attn_ln))
} else {
None
};
let n_mlp = n_state * 4;
- let mlp_linear1 = linear(n_state, n_mlp, &format!("{p}.fc1"), vb)?;
- let mlp_linear2 = linear(n_mlp, n_state, &format!("{p}.fc2"), vb)?;
- let mlp_ln = layer_norm(n_state, &format!("{p}.final_layer_norm"), vb)?;
+ let mlp_linear1 = linear(n_state, n_mlp, vb.pp("fc1"))?;
+ let mlp_linear2 = linear(n_mlp, n_state, vb.pp("fc2"))?;
+ let mlp_ln = layer_norm(n_state, vb.pp("final_layer_norm"))?;
Ok(Self {
attn,
attn_ln,
@@ -245,7 +236,7 @@ pub struct AudioEncoder {
}
impl AudioEncoder {
- fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
+ fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let n_state = cfg.d_model;
let n_head = cfg.encoder_attention_heads;
let n_ctx = cfg.max_source_positions;
@@ -257,22 +248,15 @@ impl AudioEncoder {
padding: 1,
stride: 2,
};
- let conv1 = conv1d(
- cfg.num_mel_bins,
- n_state,
- 3,
- cfg1,
- &format!("{p}.conv1"),
- vb,
- )?;
- let conv2 = conv1d(n_state, n_state, 3, cfg2, &format!("{p}.conv2"), vb)?;
- let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(&vb.device)?;
+ let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
+ let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;
+ let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(vb.device())?;
let blocks = (0..cfg.encoder_layers)
.map(|i| {
- ResidualAttentionBlock::load(n_state, n_head, false, &format!("{p}.layers.{i}"), vb)
+ ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(&format!("layers.{i}")))
})
.collect::<Result<Vec<_>>>()?;
- let ln_post = layer_norm(n_state, &format!("{p}.layer_norm"), vb)?;
+ let ln_post = layer_norm(n_state, vb.pp("layer_norm"))?;
Ok(Self {
conv1,
conv2,
@@ -306,23 +290,22 @@ pub struct TextDecoder {
}
impl TextDecoder {
- fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
+ fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let n_state = cfg.d_model;
let n_head = cfg.decoder_attention_heads;
let n_ctx = cfg.max_target_positions;
- let token_embedding = embedding(cfg.vocab_size, n_state, &format!("{p}.embed_tokens"), vb)?;
- let positional_embedding =
- vb.get((n_ctx, n_state), &format!("{p}.embed_positions.weight"))?;
+ let token_embedding = embedding(cfg.vocab_size, n_state, vb.pp("embed_tokens"))?;
+ let positional_embedding = vb.get((n_ctx, n_state), "embed_positions.weight")?;
let blocks = (0..cfg.decoder_layers)
.map(|i| {
- ResidualAttentionBlock::load(n_state, n_head, true, &format!("{p}.layers.{i}"), vb)
+ ResidualAttentionBlock::load(n_state, n_head, true, vb.pp(&format!("layers.{i}")))
})
.collect::<Result<Vec<_>>>()?;
- let ln = layer_norm(n_state, &format!("{p}.layer_norm"), vb)?;
+ let ln = layer_norm(n_state, vb.pp("layer_norm"))?;
let mask: Vec<_> = (0..n_ctx)
.flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
.collect();
- let mask = Tensor::from_vec(mask, (n_ctx, n_ctx), &vb.device)?;
+ let mask = Tensor::from_vec(mask, (n_ctx, n_ctx), vb.device())?;
Ok(Self {
token_embedding,
@@ -361,8 +344,8 @@ pub struct Whisper {
impl Whisper {
pub fn load(vb: &VarBuilder, config: Config) -> Result<Self> {
- let encoder = AudioEncoder::load("model.encoder", vb, &config)?;
- let decoder = TextDecoder::load("model.decoder", vb, &config)?;
+ let encoder = AudioEncoder::load(vb.pp("model.encoder"), &config)?;
+ let decoder = TextDecoder::load(vb.pp("model.decoder"), &config)?;
Ok(Self {
encoder,
decoder,