summaryrefslogtreecommitdiff
path: root/candle-transformers/src/models/bert.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-transformers/src/models/bert.rs')
-rw-r--r--candle-transformers/src/models/bert.rs6
1 files changed, 3 insertions, 3 deletions
diff --git a/candle-transformers/src/models/bert.rs b/candle-transformers/src/models/bert.rs
index 2262aa1a..354048de 100644
--- a/candle-transformers/src/models/bert.rs
+++ b/candle-transformers/src/models/bert.rs
@@ -419,7 +419,7 @@ struct BertEncoder {
impl BertEncoder {
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let layers = (0..config.num_hidden_layers)
- .map(|index| BertLayer::load(vb.pp(&format!("layer.{index}")), config))
+ .map(|index| BertLayer::load(vb.pp(format!("layer.{index}")), config))
.collect::<Result<Vec<_>>>()?;
let span = tracing::span!(tracing::Level::TRACE, "encoder");
Ok(BertEncoder { layers, span })
@@ -454,8 +454,8 @@ impl BertModel {
(Err(err), _) | (_, Err(err)) => {
if let Some(model_type) = &config.model_type {
if let (Ok(embeddings), Ok(encoder)) = (
- BertEmbeddings::load(vb.pp(&format!("{model_type}.embeddings")), config),
- BertEncoder::load(vb.pp(&format!("{model_type}.encoder")), config),
+ BertEmbeddings::load(vb.pp(format!("{model_type}.embeddings")), config),
+ BertEncoder::load(vb.pp(format!("{model_type}.encoder")), config),
) {
(embeddings, encoder)
} else {