summaryrefslogtreecommitdiff
path: root/candle-examples/examples
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples')
-rw-r--r--candle-examples/examples/mistral/main.rs5
-rw-r--r--candle-examples/examples/mixtral/main.rs24
-rw-r--r--candle-examples/examples/phi/main.rs8
-rw-r--r--candle-examples/examples/t5/main.rs22
-rw-r--r--candle-examples/examples/yi/main.rs16
5 files changed, 10 insertions, 65 deletions
diff --git a/candle-examples/examples/mistral/main.rs b/candle-examples/examples/mistral/main.rs
index 2b31142e..5ed5e5cb 100644
--- a/candle-examples/examples/mistral/main.rs
+++ b/candle-examples/examples/mistral/main.rs
@@ -235,10 +235,7 @@ fn main() -> Result<()> {
if args.quantized {
vec![repo.get("model-q4k.gguf")?]
} else {
- vec![
- repo.get("model-00001-of-00002.safetensors")?,
- repo.get("model-00002-of-00002.safetensors")?,
- ]
+ candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
}
}
};
diff --git a/candle-examples/examples/mixtral/main.rs b/candle-examples/examples/mixtral/main.rs
index fcde03c1..1b1a4b36 100644
--- a/candle-examples/examples/mixtral/main.rs
+++ b/candle-examples/examples/mixtral/main.rs
@@ -209,29 +209,7 @@ fn main() -> Result<()> {
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
- None => {
- vec![
- repo.get("model-00001-of-00019.safetensors")?,
- repo.get("model-00002-of-00019.safetensors")?,
- repo.get("model-00003-of-00019.safetensors")?,
- repo.get("model-00004-of-00019.safetensors")?,
- repo.get("model-00005-of-00019.safetensors")?,
- repo.get("model-00006-of-00019.safetensors")?,
- repo.get("model-00007-of-00019.safetensors")?,
- repo.get("model-00008-of-00019.safetensors")?,
- repo.get("model-00009-of-00019.safetensors")?,
- repo.get("model-00010-of-00019.safetensors")?,
- repo.get("model-00011-of-00019.safetensors")?,
- repo.get("model-00012-of-00019.safetensors")?,
- repo.get("model-00013-of-00019.safetensors")?,
- repo.get("model-00014-of-00019.safetensors")?,
- repo.get("model-00015-of-00019.safetensors")?,
- repo.get("model-00016-of-00019.safetensors")?,
- repo.get("model-00017-of-00019.safetensors")?,
- repo.get("model-00018-of-00019.safetensors")?,
- repo.get("model-00019-of-00019.safetensors")?,
- ]
- }
+ None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs
index 3574b1f2..c529867b 100644
--- a/candle-examples/examples/phi/main.rs
+++ b/candle-examples/examples/phi/main.rs
@@ -278,10 +278,10 @@ fn main() -> Result<()> {
} else {
match args.model {
WhichModel::V1 | WhichModel::V1_5 => vec![repo.get("model.safetensors")?],
- WhichModel::V2 => vec![
- repo.get("model-00001-of-00002.safetensors")?,
- repo.get("model-00002-of-00002.safetensors")?,
- ],
+ WhichModel::V2 => candle_examples::hub_load_safetensors(
+ &repo,
+ "model.safetensors.index.json",
+ )?,
WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2.safetensors")?],
WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B.safetensors")?],
}
diff --git a/candle-examples/examples/t5/main.rs b/candle-examples/examples/t5/main.rs
index 6a446615..8ef108b6 100644
--- a/candle-examples/examples/t5/main.rs
+++ b/candle-examples/examples/t5/main.rs
@@ -96,25 +96,9 @@ impl T5ModelBuilder {
let api = api.repo(repo);
let config_filename = api.get("config.json")?;
let tokenizer_filename = api.get("tokenizer.json")?;
- let weights_filename = if model_id == "google/flan-t5-xxl" {
- vec![
- api.get("model-00001-of-00005.safetensors")?,
- api.get("model-00002-of-00005.safetensors")?,
- api.get("model-00003-of-00005.safetensors")?,
- api.get("model-00004-of-00005.safetensors")?,
- api.get("model-00005-of-00005.safetensors")?,
- ]
- } else if model_id == "google/flan-ul2" {
- vec![
- api.get("model-00001-of-00008.safetensors")?,
- api.get("model-00002-of-00008.safetensors")?,
- api.get("model-00003-of-00008.safetensors")?,
- api.get("model-00004-of-00008.safetensors")?,
- api.get("model-00005-of-00008.safetensors")?,
- api.get("model-00006-of-00008.safetensors")?,
- api.get("model-00007-of-00008.safetensors")?,
- api.get("model-00008-of-00008.safetensors")?,
- ]
+ let weights_filename = if model_id == "google/flan-t5-xxl" || model_id == "google/flan-ul2"
+ {
+ candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
} else {
vec![api.get("model.safetensors")?]
};
diff --git a/candle-examples/examples/yi/main.rs b/candle-examples/examples/yi/main.rs
index a7184db9..e4cbfc6f 100644
--- a/candle-examples/examples/yi/main.rs
+++ b/candle-examples/examples/yi/main.rs
@@ -218,21 +218,7 @@ fn main() -> Result<()> {
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
- None => match args.which {
- Which::L6b => vec![
- repo.get("model-00001-of-00002.safetensors")?,
- repo.get("model-00002-of-00002.safetensors")?,
- ],
- Which::L34b => vec![
- repo.get("model-00001-of-00007.safetensors")?,
- repo.get("model-00002-of-00007.safetensors")?,
- repo.get("model-00003-of-00007.safetensors")?,
- repo.get("model-00004-of-00007.safetensors")?,
- repo.get("model-00005-of-00007.safetensors")?,
- repo.get("model-00006-of-00007.safetensors")?,
- repo.get("model-00007-of-00007.safetensors")?,
- ],
- },
+ None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;