summaryrefslogtreecommitdiff
path: root/candle-transformers/src/models/distilbert.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-transformers/src/models/distilbert.rs')
-rw-r--r--candle-transformers/src/models/distilbert.rs6
1 files changed, 3 insertions, 3 deletions
diff --git a/candle-transformers/src/models/distilbert.rs b/candle-transformers/src/models/distilbert.rs
index ea074c97..f899d772 100644
--- a/candle-transformers/src/models/distilbert.rs
+++ b/candle-transformers/src/models/distilbert.rs
@@ -275,7 +275,7 @@ struct Transformer {
impl Transformer {
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let layers = (0..config.n_layers)
- .map(|index| TransformerBlock::load(vb.pp(&format!("layer.{index}")), config))
+ .map(|index| TransformerBlock::load(vb.pp(format!("layer.{index}")), config))
.collect::<Result<Vec<_>>>()?;
let span = tracing::span!(tracing::Level::TRACE, "encoder");
Ok(Transformer { layers, span })
@@ -311,8 +311,8 @@ impl DistilBertModel {
(Err(err), _) | (_, Err(err)) => {
if let Some(model_type) = &config.model_type {
if let (Ok(embeddings), Ok(encoder)) = (
- Embeddings::load(vb.pp(&format!("{model_type}.embeddings")), config),
- Transformer::load(vb.pp(&format!("{model_type}.transformer")), config),
+ Embeddings::load(vb.pp(format!("{model_type}.embeddings")), config),
+ Transformer::load(vb.pp(format!("{model_type}.transformer")), config),
) {
(embeddings, encoder)
} else {