diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-10 22:37:34 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-10 22:37:34 +0100 |
commit | b46c28a2ac2a88387590c65a2efef028f010b29e (patch) | |
tree | c43a08c46a537c6041f55fbe6cbbb105299aee9e /candle-examples/examples/whisper/model.rs | |
parent | 1aa7fbbc33ecd0ad3ce8698220d9eae434db50ba (diff) | |
download | candle-b46c28a2ac2a88387590c65a2efef028f010b29e.tar.gz candle-b46c28a2ac2a88387590c65a2efef028f010b29e.tar.bz2 candle-b46c28a2ac2a88387590c65a2efef028f010b29e.zip |
VarBuilder path creation (#131)
* Use a struct for the safetensor+routing.
* Group the path and the var-builder together.
* Fix for the empty path case.
Diffstat (limited to 'candle-examples/examples/whisper/model.rs')
-rw-r--r-- | candle-examples/examples/whisper/model.rs | 101 |
1 files changed, 42 insertions, 59 deletions
diff --git a/candle-examples/examples/whisper/model.rs b/candle-examples/examples/whisper/model.rs index d653d0c7..ece8b2d8 100644 --- a/candle-examples/examples/whisper/model.rs +++ b/candle-examples/examples/whisper/model.rs @@ -38,19 +38,19 @@ impl Config { } } -fn embedding(vocab_size: usize, hidden_size: usize, p: &str, vb: &VarBuilder) -> Result<Embedding> { - let embeddings = vb.get((vocab_size, hidden_size), &format!("{p}.weight"))?; +fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> { + let embeddings = vb.get((vocab_size, hidden_size), "weight")?; Ok(Embedding::new(embeddings, hidden_size)) } -fn linear(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Linear> { - let weight = vb.get((size2, size1), &format!("{p}.weight"))?; - let bias = vb.get(size2, &format!("{p}.bias"))?; +fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> { + let weight = vb.get((size2, size1), "weight")?; + let bias = vb.get(size2, "bias")?; Ok(Linear::new(weight, Some(bias))) } -fn linear_no_bias(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Linear> { - let weight = vb.get((size2, size1), &format!("{p}.weight"))?; +fn linear_no_bias(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> { + let weight = vb.get((size2, size1), "weight")?; Ok(Linear::new(weight, None)) } @@ -59,14 +59,10 @@ fn conv1d( out_channels: usize, kernel_size: usize, config: Conv1dConfig, - p: &str, - vb: &VarBuilder, + vb: VarBuilder, ) -> Result<Conv1d> { - let weight = vb.get( - (out_channels, in_channels, kernel_size), - &format!("{p}.weight"), - )?; - let bias = vb.get(out_channels, &format!("{p}.bias"))?; + let weight = vb.get((out_channels, in_channels, kernel_size), "weight")?; + let bias = vb.get(out_channels, "bias")?; Ok(Conv1d::new(weight, Some(bias), config)) } @@ -75,13 +71,9 @@ fn conv1d_no_bias( out_channels: usize, kernel_size: usize, config: Conv1dConfig, - p: &str, - vb: &VarBuilder, + vb: VarBuilder, ) -> Result<Conv1d> { - let weight = vb.get( - (out_channels, in_channels, kernel_size), - &format!("{p}.weight"), - )?; + let weight = vb.get((out_channels, in_channels, kernel_size), "weight")?; Ok(Conv1d::new(weight, None, config)) } @@ -100,9 +92,9 @@ impl Dropout { } } -fn layer_norm(size: usize, p: &str, vb: &VarBuilder) -> Result<LayerNorm> { - let weight = vb.get(size, &format!("{p}.weight"))?; - let bias = vb.get(size, &format!("{p}.bias"))?; +fn layer_norm(size: usize, vb: VarBuilder) -> Result<LayerNorm> { + let weight = vb.get(size, "weight")?; + let bias = vb.get(size, "bias")?; Ok(LayerNorm::new(weight, bias, 1e-5)) } @@ -116,11 +108,11 @@ struct MultiHeadAttention { } impl MultiHeadAttention { - fn load(n_state: usize, n_head: usize, p: &str, vb: &VarBuilder) -> Result<Self> { - let query = linear(n_state, n_state, &format!("{p}.q_proj"), vb)?; - let value = linear(n_state, n_state, &format!("{p}.v_proj"), vb)?; - let key = linear_no_bias(n_state, n_state, &format!("{p}.k_proj"), vb)?; - let out = linear(n_state, n_state, &format!("{p}.out_proj"), vb)?; + fn load(n_state: usize, n_head: usize, vb: VarBuilder) -> Result<Self> { + let query = linear(n_state, n_state, vb.pp("q_proj"))?; + let value = linear(n_state, n_state, vb.pp("v_proj"))?; + let key = linear_no_bias(n_state, n_state, vb.pp("k_proj"))?; + let out = linear(n_state, n_state, vb.pp("out_proj"))?; Ok(Self { query, key, @@ -179,21 +171,20 @@ struct ResidualAttentionBlock { } impl ResidualAttentionBlock { - fn load(n_state: usize, n_head: usize, ca: bool, p: &str, vb: &VarBuilder) -> Result<Self> { - let attn = MultiHeadAttention::load(n_state, n_head, &format!("{p}.self_attn"), vb)?; - let attn_ln = layer_norm(n_state, &format!("{p}.self_attn_layer_norm"), vb)?; + fn load(n_state: usize, n_head: usize, ca: bool, vb: VarBuilder) -> Result<Self> { + let attn = MultiHeadAttention::load(n_state, n_head, vb.pp("self_attn"))?; + let attn_ln = layer_norm(n_state, vb.pp("self_attn_layer_norm"))?; let cross_attn = if ca { - let cross_attn = - MultiHeadAttention::load(n_state, n_head, &format!("{p}.encoder_attn"), vb)?; - let cross_attn_ln = layer_norm(n_state, &format!("{p}.encoder_attn_layer_norm"), vb)?; + let cross_attn = MultiHeadAttention::load(n_state, n_head, vb.pp("encoder_attn"))?; + let cross_attn_ln = layer_norm(n_state, vb.pp("encoder_attn_layer_norm"))?; Some((cross_attn, cross_attn_ln)) } else { None }; let n_mlp = n_state * 4; - let mlp_linear1 = linear(n_state, n_mlp, &format!("{p}.fc1"), vb)?; - let mlp_linear2 = linear(n_mlp, n_state, &format!("{p}.fc2"), vb)?; - let mlp_ln = layer_norm(n_state, &format!("{p}.final_layer_norm"), vb)?; + let mlp_linear1 = linear(n_state, n_mlp, vb.pp("fc1"))?; + let mlp_linear2 = linear(n_mlp, n_state, vb.pp("fc2"))?; + let mlp_ln = layer_norm(n_state, vb.pp("final_layer_norm"))?; Ok(Self { attn, attn_ln, @@ -245,7 +236,7 @@ pub struct AudioEncoder { } impl AudioEncoder { - fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> { + fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { let n_state = cfg.d_model; let n_head = cfg.encoder_attention_heads; let n_ctx = cfg.max_source_positions; @@ -257,22 +248,15 @@ impl AudioEncoder { padding: 1, stride: 2, }; - let conv1 = conv1d( - cfg.num_mel_bins, - n_state, - 3, - cfg1, - &format!("{p}.conv1"), - vb, - )?; - let conv2 = conv1d(n_state, n_state, 3, cfg2, &format!("{p}.conv2"), vb)?; - let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(&vb.device)?; + let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?; + let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?; + let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(vb.device())?; let blocks = (0..cfg.encoder_layers) .map(|i| { - ResidualAttentionBlock::load(n_state, n_head, false, &format!("{p}.layers.{i}"), vb) + ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(&format!("layers.{i}"))) }) .collect::<Result<Vec<_>>>()?; - let ln_post = layer_norm(n_state, &format!("{p}.layer_norm"), vb)?; + let ln_post = layer_norm(n_state, vb.pp("layer_norm"))?; Ok(Self { conv1, conv2, @@ -306,23 +290,22 @@ pub struct TextDecoder { } impl TextDecoder { - fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> { + fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { let n_state = cfg.d_model; let n_head = cfg.decoder_attention_heads; let n_ctx = cfg.max_target_positions; - let token_embedding = embedding(cfg.vocab_size, n_state, &format!("{p}.embed_tokens"), vb)?; - let positional_embedding = - vb.get((n_ctx, n_state), &format!("{p}.embed_positions.weight"))?; + let token_embedding = embedding(cfg.vocab_size, n_state, vb.pp("embed_tokens"))?; + let positional_embedding = vb.get((n_ctx, n_state), "embed_positions.weight")?; let blocks = (0..cfg.decoder_layers) .map(|i| { - ResidualAttentionBlock::load(n_state, n_head, true, &format!("{p}.layers.{i}"), vb) + ResidualAttentionBlock::load(n_state, n_head, true, vb.pp(&format!("layers.{i}"))) }) .collect::<Result<Vec<_>>>()?; - let ln = layer_norm(n_state, &format!("{p}.layer_norm"), vb)?; + let ln = layer_norm(n_state, vb.pp("layer_norm"))?; let mask: Vec<_> = (0..n_ctx) .flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 })) .collect(); - let mask = Tensor::from_vec(mask, (n_ctx, n_ctx), &vb.device)?; + let mask = Tensor::from_vec(mask, (n_ctx, n_ctx), vb.device())?; Ok(Self { token_embedding, @@ -361,8 +344,8 @@ pub struct Whisper { impl Whisper { pub fn load(vb: &VarBuilder, config: Config) -> Result<Self> { - let encoder = AudioEncoder::load("model.encoder", vb, &config)?; - let decoder = TextDecoder::load("model.decoder", vb, &config)?; + let encoder = AudioEncoder::load(vb.pp("model.encoder"), &config)?; + let decoder = TextDecoder::load(vb.pp("model.decoder"), &config)?; Ok(Self { encoder, decoder, |