-rw-r--r--  Cargo.toml                               |  2 +-
-rw-r--r--  candle-examples/examples/bert/main.rs    |  7 ++++---
-rw-r--r--  candle-examples/examples/falcon/main.rs  | 10 +++++++---
-rw-r--r--  candle-examples/examples/llama/main.rs   |  8 ++++----
-rw-r--r--  candle-examples/examples/whisper/main.rs | 19 +++++++------------
5 files changed, 23 insertions(+), 23 deletions(-)
diff --git a/Cargo.toml b/Cargo.toml
index 0dec835b..05c6240b 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -22,7 +22,7 @@ clap = { version = "4.2.4", features = ["derive"] }
 cudarc = { git = "https://github.com/LaurentMazare/cudarc.git", branch = "cublas-bf16", features = ["f16"] }
 # TODO: Switch back to the official gemm implementation if we manage to upstream the changes.
 gemm = { git = "https://github.com/LaurentMazare/gemm.git" }
-hf-hub = "0.1.3"
+hf-hub = "0.2.0"
 half = { version = "2.3.1", features = ["num-traits", "rand_distr"] }
 intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
 libc = { version = "0.2.147" }
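Context for the bump: hf-hub 0.2.0 moves file downloads from `Api::get(&repo, filename)` onto a per-repository handle returned by `Api::repo`, `Api::model`, or `Api::dataset`, which is what every example change below adapts to. A minimal sketch of the new call shape (the model id here is a placeholder, not one used in this patch):

```rust
use hf_hub::api::sync::Api;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let api = Api::new()?;
    // 0.2.0 style: build a handle scoped to one repo, then fetch files from it.
    let repo = api.model("bert-base-uncased".to_string());
    // Downloads on first use, then serves the cached copy; returns a local PathBuf.
    let config = repo.get("config.json")?;
    println!("config cached at {}", config.display());
    Ok(())
}
```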
diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs
index 6672ad09..79c78968 100644
--- a/candle-examples/examples/bert/main.rs
+++ b/candle-examples/examples/bert/main.rs
@@ -69,10 +69,11 @@ impl Args {
             )
         } else {
             let api = Api::new()?;
+            let api = api.repo(repo);
             (
-                api.get(&repo, "config.json")?,
-                api.get(&repo, "tokenizer.json")?,
-                api.get(&repo, "model.safetensors")?,
+                api.get("config.json")?,
+                api.get("tokenizer.json")?,
+                api.get("model.safetensors")?,
             )
         };
         let config = std::fs::read_to_string(config_filename)?;
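The added line shadows `api` with the `ApiRepo` produced by `Api::repo`, so the three `get` calls no longer pass `&repo`. A sketch of the same pattern in isolation (the repo id in `main` is a placeholder):

```rust
use hf_hub::{api::sync::Api, Repo, RepoType};
use std::path::PathBuf;

fn fetch_bert_files(repo: Repo) -> Result<(PathBuf, PathBuf, PathBuf), Box<dyn std::error::Error>> {
    let api = Api::new()?;
    let api = api.repo(repo); // shadow the Api with an ApiRepo bound to this repository
    Ok((
        api.get("config.json")?,
        api.get("tokenizer.json")?,
        api.get("model.safetensors")?,
    ))
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let paths = fetch_bert_files(Repo::new("bert-base-uncased".to_string(), RepoType::Model))?;
    println!("{paths:?}");
    Ok(())
}
```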
diff --git a/candle-examples/examples/falcon/main.rs b/candle-examples/examples/falcon/main.rs
index 3a284c86..a01191a5 100644
--- a/candle-examples/examples/falcon/main.rs
+++ b/candle-examples/examples/falcon/main.rs
@@ -123,14 +123,18 @@ fn main() -> Result<()> {
     let device = candle_examples::device(args.cpu)?;
     let start = std::time::Instant::now();
     let api = Api::new()?;
-    let repo = Repo::with_revision(args.model_id, RepoType::Model, args.revision);
-    let tokenizer_filename = api.get(&repo, "tokenizer.json")?;
+    let repo = api.repo(Repo::with_revision(
+        args.model_id,
+        RepoType::Model,
+        args.revision,
+    ));
+    let tokenizer_filename = repo.get("tokenizer.json")?;
     let mut filenames = vec![];
     for rfilename in [
         "model-00001-of-00002.safetensors",
         "model-00002-of-00002.safetensors",
     ] {
-        let filename = api.get(&repo, rfilename)?;
+        let filename = repo.get(rfilename)?;
         filenames.push(filename);
     }
     println!("retrieved the files in {:?}", start.elapsed());
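Because the `ApiRepo` here wraps `Repo::with_revision`, every `repo.get` resolves against the same pinned revision rather than whatever the default branch currently points to. A sketch of the same download loop, with the id and revision left as caller-supplied values:

```rust
use hf_hub::{api::sync::Api, Repo, RepoType};
use std::path::PathBuf;

// model_id and revision are illustrative inputs, e.g. a repo id and a commit/branch name.
fn fetch_sharded_weights(model_id: String, revision: String) -> Result<Vec<PathBuf>, Box<dyn std::error::Error>> {
    let api = Api::new()?;
    // Every get() below resolves against the same pinned revision.
    let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
    let mut filenames = vec![];
    for rfilename in [
        "model-00001-of-00002.safetensors",
        "model-00002-of-00002.safetensors",
    ] {
        filenames.push(repo.get(rfilename)?);
    }
    Ok(filenames)
}
```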
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index 582ac3f8..d9d1e21a 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -18,7 +18,7 @@ use clap::Parser;
 use candle::{DType, Tensor};
 use candle_nn::VarBuilder;
 use candle_transformers::generation::LogitsProcessor;
-use hf_hub::{api::sync::Api, Repo, RepoType};
+use hf_hub::api::sync::Api;
 
 mod model;
 use model::{Config, Llama};
@@ -146,14 +146,14 @@ fn main() -> Result<()> {
         }
     });
     println!("loading the model weights from {model_id}");
-    let repo = Repo::new(model_id, RepoType::Model);
-    let tokenizer_filename = api.get(&repo, "tokenizer.json")?;
+    let api = api.model(model_id);
+    let tokenizer_filename = api.get("tokenizer.json")?;
     let mut filenames = vec![];
     for rfilename in [
         "model-00001-of-00002.safetensors",
         "model-00002-of-00002.safetensors",
     ] {
-        let filename = api.get(&repo, rfilename)?;
+        let filename = api.get(rfilename)?;
         filenames.push(filename);
     }
 
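`Api::model(id)` used above is shorthand for `api.repo(Repo::new(id, RepoType::Model))` at the default revision, which is why the first hunk can drop the `Repo` and `RepoType` imports entirely. A minimal sketch:

```rust
use hf_hub::api::sync::Api;
use std::path::PathBuf;

fn fetch_tokenizer(model_id: String) -> Result<PathBuf, Box<dyn std::error::Error>> {
    let api = Api::new()?;
    let api = api.model(model_id); // ApiRepo for a model repo at the default revision
    Ok(api.get("tokenizer.json")?)
}
```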
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index 079424e3..c03779e7 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -282,28 +282,23 @@ fn main() -> Result<()> {
             std::path::PathBuf::from(args.input.expect("You didn't specify a file to read from yet, are using a local model, please add `--input example.wav` to read some audio file")),
         )
     } else {
-        let repo = Repo::with_revision(model_id, RepoType::Model, revision);
         let api = Api::new()?;
+        let dataset = api.dataset("Narsil/candle-examples".to_string());
+        let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
         let sample = if let Some(input) = args.input {
             if let Some(sample) = input.strip_prefix("sample:") {
-                api.get(
-                    &Repo::new("Narsil/candle-examples".to_string(), RepoType::Dataset),
-                    &format!("samples_{sample}.wav"),
-                )?
+                dataset.get(&format!("samples_{sample}.wav"))?
             } else {
                 std::path::PathBuf::from(input)
             }
         } else {
             println!("No audio file submitted: Downloading https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav");
-            api.get(
-                &Repo::new("Narsil/candle-examples".to_string(), RepoType::Dataset),
-                "samples_jfk.wav",
-            )?
+            dataset.get("samples_jfk.wav")?
         };
         (
-            api.get(&repo, "config.json")?,
-            api.get(&repo, "tokenizer.json")?,
-            api.get(&repo, "model.safetensors")?,
+            repo.get("config.json")?,
+            repo.get("tokenizer.json")?,
+            repo.get("model.safetensors")?,
             sample,
         )
     };
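With one `ApiRepo` for the model and a second handle from `Api::dataset` for the sample audio, each `get` call is now unambiguous about which repository it hits. A sketch of the dataset side, reusing the dataset id and file names from the hunk above:

```rust
use hf_hub::api::sync::Api;
use std::path::PathBuf;

fn fetch_sample(name: Option<&str>) -> Result<PathBuf, Box<dyn std::error::Error>> {
    let api = Api::new()?;
    // Handle scoped to a dataset repo rather than a model repo.
    let dataset = api.dataset("Narsil/candle-examples".to_string());
    match name {
        // e.g. name = Some("jfk") fetches samples_jfk.wav
        Some(sample) => Ok(dataset.get(&format!("samples_{sample}.wav"))?),
        None => Ok(dataset.get("samples_jfk.wav")?),
    }
}
```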