diff options
Diffstat (limited to 'candle-examples/examples/bert/main.rs')
-rw-r--r-- | candle-examples/examples/bert/main.rs | 114 |
1 files changed, 42 insertions, 72 deletions
diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs index 3871c752..d0d600ee 100644 --- a/candle-examples/examples/bert/main.rs +++ b/candle-examples/examples/bert/main.rs @@ -109,14 +109,14 @@ impl Config { } } -fn embedding(vocab_size: usize, hidden_size: usize, p: &str, vb: &VarBuilder) -> Result<Embedding> { - let embeddings = vb.get((vocab_size, hidden_size), &format!("{p}.weight"))?; +fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> { + let embeddings = vb.get((vocab_size, hidden_size), "weight")?; Ok(Embedding::new(embeddings, hidden_size)) } -fn linear(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Linear> { - let weight = vb.get((size2, size1), &format!("{p}.weight"))?; - let bias = vb.get(size2, &format!("{p}.bias"))?; +fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> { + let weight = vb.get((size2, size1), "weight")?; + let bias = vb.get(size2, "bias")?; Ok(Linear::new(weight, Some(bias))) } @@ -135,17 +135,11 @@ impl Dropout { } } -fn layer_norm(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<LayerNorm> { - let (weight, bias) = match ( - vb.get(size, &format!("{p}.weight")), - vb.get(size, &format!("{p}.bias")), - ) { +fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> { + let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) { (Ok(weight), Ok(bias)) => (weight, bias), (Err(err), _) | (_, Err(err)) => { - if let (Ok(weight), Ok(bias)) = ( - vb.get(size, &format!("{p}.gamma")), - vb.get(size, &format!("{p}.beta")), - ) { + if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) { (weight, bias) } else { return Err(err.into()); @@ -167,33 +161,29 @@ struct BertEmbeddings { } impl BertEmbeddings { - fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> { + fn load(vb: VarBuilder, config: &Config) -> Result<Self> { let word_embeddings = embedding( config.vocab_size, config.hidden_size, - &format!("{p}.word_embeddings"), - vb, + vb.pp("word_embeddings"), )?; let position_embeddings = embedding( config.max_position_embeddings, config.hidden_size, - &format!("{p}.position_embeddings"), - vb, + vb.pp("position_embeddings"), )?; let token_type_embeddings = embedding( config.type_vocab_size, config.hidden_size, - &format!("{p}.token_type_embeddings"), - vb, + vb.pp("token_type_embeddings"), )?; let layer_norm = layer_norm( config.hidden_size, config.layer_norm_eps, - &format!("{p}.LayerNorm"), - vb, + vb.pp("LayerNorm"), )?; let position_ids: Vec<_> = (0..config.max_position_embeddings as u32).collect(); - let position_ids = Tensor::new(&position_ids[..], &vb.device)?.unsqueeze(0)?; + let position_ids = Tensor::new(&position_ids[..], vb.device())?.unsqueeze(0)?; let token_type_ids = position_ids.zeros_like()?; Ok(Self { word_embeddings, @@ -233,14 +223,14 @@ struct BertSelfAttention { } impl BertSelfAttention { - fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> { + fn load(vb: VarBuilder, config: &Config) -> Result<Self> { let attention_head_size = config.hidden_size / config.num_attention_heads; let all_head_size = config.num_attention_heads * attention_head_size; let dropout = Dropout::new(config.hidden_dropout_prob); let hidden_size = config.hidden_size; - let query = linear(hidden_size, all_head_size, &format!("{p}.query"), vb)?; - let value = linear(hidden_size, all_head_size, &format!("{p}.value"), vb)?; - let key = linear(hidden_size, all_head_size, &format!("{p}.key"), vb)?; + let query = linear(hidden_size, all_head_size, vb.pp("query"))?; + let value = linear(hidden_size, all_head_size, vb.pp("value"))?; + let key = linear(hidden_size, all_head_size, vb.pp("key"))?; Ok(Self { query, key, @@ -289,18 +279,12 @@ struct BertSelfOutput { } impl BertSelfOutput { - fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> { - let dense = linear( - config.hidden_size, - config.hidden_size, - &format!("{p}.dense"), - vb, - )?; + fn load(vb: VarBuilder, config: &Config) -> Result<Self> { + let dense = linear(config.hidden_size, config.hidden_size, vb.pp("dense"))?; let layer_norm = layer_norm( config.hidden_size, config.layer_norm_eps, - &format!("{p}.LayerNorm"), - vb, + vb.pp("LayerNorm"), )?; let dropout = Dropout::new(config.hidden_dropout_prob); Ok(Self { @@ -324,9 +308,9 @@ struct BertAttention { } impl BertAttention { - fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> { - let self_attention = BertSelfAttention::load(&format!("{p}.self"), vb, config)?; - let self_output = BertSelfOutput::load(&format!("{p}.output"), vb, config)?; + fn load(vb: VarBuilder, config: &Config) -> Result<Self> { + let self_attention = BertSelfAttention::load(vb.pp("self"), config)?; + let self_output = BertSelfOutput::load(vb.pp("output"), config)?; Ok(Self { self_attention, self_output, @@ -347,13 +331,8 @@ struct BertIntermediate { } impl BertIntermediate { - fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> { - let dense = linear( - config.hidden_size, - config.intermediate_size, - &format!("{p}.dense"), - vb, - )?; + fn load(vb: VarBuilder, config: &Config) -> Result<Self> { + let dense = linear(config.hidden_size, config.intermediate_size, vb.pp("dense"))?; Ok(Self { dense, intermediate_act: config.hidden_act, @@ -375,18 +354,12 @@ struct BertOutput { } impl BertOutput { - fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> { - let dense = linear( - config.intermediate_size, - config.hidden_size, - &format!("{p}.dense"), - vb, - )?; + fn load(vb: VarBuilder, config: &Config) -> Result<Self> { + let dense = linear(config.intermediate_size, config.hidden_size, vb.pp("dense"))?; let layer_norm = layer_norm( config.hidden_size, config.layer_norm_eps, - &format!("{p}.LayerNorm"), - vb, + vb.pp("LayerNorm"), )?; let dropout = Dropout::new(config.hidden_dropout_prob); Ok(Self { @@ -411,10 +384,10 @@ struct BertLayer { } impl BertLayer { - fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> { - let attention = BertAttention::load(&format!("{p}.attention"), vb, config)?; - let intermediate = BertIntermediate::load(&format!("{p}.intermediate"), vb, config)?; - let output = BertOutput::load(&format!("{p}.output"), vb, config)?; + fn load(vb: VarBuilder, config: &Config) -> Result<Self> { + let attention = BertAttention::load(vb.pp("attention"), config)?; + let intermediate = BertIntermediate::load(vb.pp("intermediate"), config)?; + let output = BertOutput::load(vb.pp("output"), config)?; Ok(Self { attention, intermediate, @@ -441,12 +414,9 @@ struct BertEncoder { } impl BertEncoder { - fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> { + fn load(vb: VarBuilder, config: &Config) -> Result<Self> { let layers = (0..config.num_hidden_layers) - .map(|index| { - let p = format!("{p}.layer.{index}"); - BertLayer::load(&p, vb, config) - }) + .map(|index| BertLayer::load(vb.pp(&format!("layer.{index}")), config)) .collect::<Result<Vec<_>>>()?; Ok(BertEncoder { layers }) } @@ -469,17 +439,17 @@ struct BertModel { } impl BertModel { - fn load(vb: &VarBuilder, config: &Config) -> Result<Self> { + fn load(vb: VarBuilder, config: &Config) -> Result<Self> { let (embeddings, encoder) = match ( - BertEmbeddings::load("embeddings", vb, config), - BertEncoder::load("encoder", vb, config), + BertEmbeddings::load(vb.pp("embeddings"), config), + BertEncoder::load(vb.pp("encoder"), config), ) { (Ok(embeddings), Ok(encoder)) => (embeddings, encoder), (Err(err), _) | (_, Err(err)) => { if let Some(model_type) = &config.model_type { if let (Ok(embeddings), Ok(encoder)) = ( - BertEmbeddings::load(&format!("{model_type}.embeddings"), vb, config), - BertEncoder::load(&format!("{model_type}.encoder"), vb, config), + BertEmbeddings::load(vb.pp(&format!("{model_type}.embeddings")), config), + BertEncoder::load(vb.pp(&format!("{model_type}.encoder")), config), ) { (embeddings, encoder) } else { @@ -493,7 +463,7 @@ impl BertModel { Ok(Self { embeddings, encoder, - device: vb.device.clone(), + device: vb.device().clone(), }) } @@ -576,7 +546,7 @@ impl Args { let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? }; let weights = weights.deserialize()?; let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device); - let model = BertModel::load(&vb, &config)?; + let model = BertModel::load(vb, &config)?; Ok((model, tokenizer)) } } |