diff options
Diffstat (limited to 'candle-transformers/src/models/distilbert.rs')
-rw-r--r-- | candle-transformers/src/models/distilbert.rs | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/candle-transformers/src/models/distilbert.rs b/candle-transformers/src/models/distilbert.rs index ea074c97..f899d772 100644 --- a/candle-transformers/src/models/distilbert.rs +++ b/candle-transformers/src/models/distilbert.rs @@ -275,7 +275,7 @@ struct Transformer { impl Transformer { fn load(vb: VarBuilder, config: &Config) -> Result<Self> { let layers = (0..config.n_layers) - .map(|index| TransformerBlock::load(vb.pp(&format!("layer.{index}")), config)) + .map(|index| TransformerBlock::load(vb.pp(format!("layer.{index}")), config)) .collect::<Result<Vec<_>>>()?; let span = tracing::span!(tracing::Level::TRACE, "encoder"); Ok(Transformer { layers, span }) @@ -311,8 +311,8 @@ impl DistilBertModel { (Err(err), _) | (_, Err(err)) => { if let Some(model_type) = &config.model_type { if let (Ok(embeddings), Ok(encoder)) = ( - Embeddings::load(vb.pp(&format!("{model_type}.embeddings")), config), - Transformer::load(vb.pp(&format!("{model_type}.transformer")), config), + Embeddings::load(vb.pp(format!("{model_type}.embeddings")), config), + Transformer::load(vb.pp(format!("{model_type}.transformer")), config), ) { (embeddings, encoder) } else { |