Diffstat (limited to 'candle-examples/examples/bert/main.rs')
-rw-r--r--  candle-examples/examples/bert/main.rs | 47
1 file changed, 20 insertions(+), 27 deletions(-)
diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs
index 70592013..fcd2eab9 100644
--- a/candle-examples/examples/bert/main.rs
+++ b/candle-examples/examples/bert/main.rs
@@ -5,11 +5,11 @@ extern crate intel_mkl_src;
 extern crate accelerate_src;
 
 use candle_transformers::models::bert::{BertModel, Config, DTYPE};
-use anyhow::{anyhow, Error as E, Result};
+use anyhow::{Error as E, Result};
 use candle::Tensor;
 use candle_nn::VarBuilder;
 use clap::Parser;
-use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
+use hf_hub::{api::sync::Api, Repo, RepoType};
 use tokenizers::{PaddingParams, Tokenizer};
 
 #[derive(Parser, Debug)]
@@ -19,10 +19,6 @@ struct Args {
     #[arg(long)]
     cpu: bool,
 
-    /// Run offline (you must have the files already cached)
-    #[arg(long)]
-    offline: bool,
-
     /// Enable tracing (generates a trace-timestamp.json file).
     #[arg(long)]
     tracing: bool,
@@ -38,6 +34,10 @@ struct Args {
     #[arg(long)]
     prompt: Option<String>,
 
+    /// Use the pytorch weights rather than the safetensors ones
+    #[arg(long)]
+    use_pth: bool,
+
     /// The number of times to run the prompt.
     #[arg(long, default_value = "1")]
     n: usize,
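Aside: clap's derive API turns a bare bool field carrying #[arg(long)] into an optional flag that defaults to false, so --use_pth takes no value. A minimal standalone sketch of just this mechanism (illustrative only, not part of the commit; assumes clap built with the derive feature):

    use clap::Parser;

    #[derive(Parser, Debug)]
    struct Args {
        /// Use the pytorch weights rather than the safetensors ones
        #[arg(long)]
        use_pth: bool,
    }

    fn main() {
        // Passing `--use_pth` on the command line sets the flag to true;
        // omitting it leaves the default false.
        let args = Args::parse();
        println!("use_pth = {}", args.use_pth);
    }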
@@ -60,34 +60,27 @@ impl Args {
         };
         let repo = Repo::with_revision(model_id, RepoType::Model, revision);
-        let (config_filename, tokenizer_filename, weights_filename) = if self.offline {
-            let cache = Cache::default().repo(repo);
-            (
-                cache
-                    .get("config.json")
-                    .ok_or(anyhow!("Missing config file in cache"))?,
-                cache
-                    .get("tokenizer.json")
-                    .ok_or(anyhow!("Missing tokenizer file in cache"))?,
-                cache
-                    .get("model.safetensors")
-                    .ok_or(anyhow!("Missing weights file in cache"))?,
-            )
-        } else {
+        let (config_filename, tokenizer_filename, weights_filename) = {
             let api = Api::new()?;
             let api = api.repo(repo);
-            (
-                api.get("config.json")?,
-                api.get("tokenizer.json")?,
-                api.get("model.safetensors")?,
-            )
+            let config = api.get("config.json")?;
+            let tokenizer = api.get("tokenizer.json")?;
+            let weights = if self.use_pth {
+                api.get("pytorch_model.bin")?
+            } else {
+                api.get("model.safetensors")?
+            };
+            (config, tokenizer, weights)
         };
         let config = std::fs::read_to_string(config_filename)?;
         let config: Config = serde_json::from_str(&config)?;
         let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
-        let vb =
-            unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? };
+        let vb = if self.use_pth {
+            VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
+        } else {
+            unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
+        };
         let model = BertModel::load(vb, &config)?;
         Ok((model, tokenizer))
     }
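For reference, the post-change loading path extracted into a self-contained sketch. The model id and revision below are placeholder assumptions (the example's real defaults are set outside this hunk); the rest mirrors the diff, including the unsafe mmap for safetensors and VarBuilder::from_pth for the pickle-based .bin weights:

    use anyhow::{Error as E, Result};
    use candle::Device;
    use candle_nn::VarBuilder;
    use candle_transformers::models::bert::{BertModel, Config, DTYPE};
    use hf_hub::{api::sync::Api, Repo, RepoType};
    use tokenizers::Tokenizer;

    fn load_bert(use_pth: bool, device: &Device) -> Result<(BertModel, Tokenizer)> {
        // Assumed model id and revision; the example resolves these from
        // its own CLI arguments outside the hunk shown above.
        let repo = Repo::with_revision(
            "sentence-transformers/all-MiniLM-L6-v2".to_string(),
            RepoType::Model,
            "main".to_string(),
        );
        let api = Api::new()?.repo(repo);
        let config_filename = api.get("config.json")?;
        let tokenizer_filename = api.get("tokenizer.json")?;
        // Pick which weight file to download based on the new flag.
        let weights_filename = if use_pth {
            api.get("pytorch_model.bin")?
        } else {
            api.get("model.safetensors")?
        };
        let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
        let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
        // Safetensors weights are memory-mapped (hence the unsafe block);
        // .bin weights go through the PyTorch pickle loader instead.
        let vb = if use_pth {
            VarBuilder::from_pth(&weights_filename, DTYPE, device)?
        } else {
            unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, device)? }
        };
        let model = BertModel::load(vb, &config)?;
        Ok((model, tokenizer))
    }

    fn main() -> Result<()> {
        let (_model, _tokenizer) = load_bert(false, &Device::Cpu)?;
        Ok(())
    }

Note the design choice the diff encodes: safetensors remains the default, and the pytorch weights are opt-in via --use_pth rather than a separate offline mode.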