summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-examples/examples/glm4/main.rs17
1 files changed, 10 insertions, 7 deletions
diff --git a/candle-examples/examples/glm4/main.rs b/candle-examples/examples/glm4/main.rs
index ced3841d..a6ba7c72 100644
--- a/candle-examples/examples/glm4/main.rs
+++ b/candle-examples/examples/glm4/main.rs
@@ -109,10 +109,10 @@ impl TextGeneration {
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
- /// Run on CPU rather than on GPU.
#[arg(name = "cache", short, long, default_value = ".")]
- cache_path: String,
+ cache_path: Option<String>,
+ /// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
@@ -178,11 +178,14 @@ fn main() -> anyhow::Result<()> {
);
let start = std::time::Instant::now();
- let api = hf_hub::api::sync::ApiBuilder::from_cache(hf_hub::Cache::new(
- args.cache_path.to_string().into(),
- ))
- .build()
- .map_err(anyhow::Error::msg)?;
+ let api = match args.cache_path.as_ref() {
+ None => hf_hub::api::sync::Api::new()?,
+ Some(path) => {
+ hf_hub::api::sync::ApiBuilder::from_cache(hf_hub::Cache::new(path.to_string().into()))
+ .build()
+ .map_err(anyhow::Error::msg)?
+ }
+ };
let model_id = match args.model_id.as_ref() {
Some(model_id) => model_id.to_string(),