summaryrefslogtreecommitdiff
path: root/candle-examples/examples/llama/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/llama/main.rs')
-rw-r--r--candle-examples/examples/llama/main.rs8
1 files changed, 4 insertions, 4 deletions
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index 582ac3f8..d9d1e21a 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -18,7 +18,7 @@ use clap::Parser;
use candle::{DType, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
-use hf_hub::{api::sync::Api, Repo, RepoType};
+use hf_hub::api::sync::Api;
mod model;
use model::{Config, Llama};
@@ -146,14 +146,14 @@ fn main() -> Result<()> {
}
});
println!("loading the model weights from {model_id}");
- let repo = Repo::new(model_id, RepoType::Model);
- let tokenizer_filename = api.get(&repo, "tokenizer.json")?;
+ let api = api.model(model_id);
+ let tokenizer_filename = api.get("tokenizer.json")?;
let mut filenames = vec![];
for rfilename in [
"model-00001-of-00002.safetensors",
"model-00002-of-00002.safetensors",
] {
- let filename = api.get(&repo, rfilename)?;
+ let filename = api.get(rfilename)?;
filenames.push(filename);
}