summaryrefslogtreecommitdiff
path: root/candle-examples/examples/t5/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/t5/main.rs')
-rw-r--r--candle-examples/examples/t5/main.rs22
1 files changed, 3 insertions, 19 deletions
diff --git a/candle-examples/examples/t5/main.rs b/candle-examples/examples/t5/main.rs
index 6a446615..8ef108b6 100644
--- a/candle-examples/examples/t5/main.rs
+++ b/candle-examples/examples/t5/main.rs
@@ -96,25 +96,9 @@ impl T5ModelBuilder {
let api = api.repo(repo);
let config_filename = api.get("config.json")?;
let tokenizer_filename = api.get("tokenizer.json")?;
- let weights_filename = if model_id == "google/flan-t5-xxl" {
- vec![
- api.get("model-00001-of-00005.safetensors")?,
- api.get("model-00002-of-00005.safetensors")?,
- api.get("model-00003-of-00005.safetensors")?,
- api.get("model-00004-of-00005.safetensors")?,
- api.get("model-00005-of-00005.safetensors")?,
- ]
- } else if model_id == "google/flan-ul2" {
- vec![
- api.get("model-00001-of-00008.safetensors")?,
- api.get("model-00002-of-00008.safetensors")?,
- api.get("model-00003-of-00008.safetensors")?,
- api.get("model-00004-of-00008.safetensors")?,
- api.get("model-00005-of-00008.safetensors")?,
- api.get("model-00006-of-00008.safetensors")?,
- api.get("model-00007-of-00008.safetensors")?,
- api.get("model-00008-of-00008.safetensors")?,
- ]
+ let weights_filename = if model_id == "google/flan-t5-xxl" || model_id == "google/flan-ul2"
+ {
+ candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
} else {
vec![api.get("model.safetensors")?]
};