diff options
Diffstat (limited to 'candle-examples/examples/t5')
-rw-r--r-- | candle-examples/examples/t5/main.rs | 22 |
1 files changed, 3 insertions, 19 deletions
diff --git a/candle-examples/examples/t5/main.rs b/candle-examples/examples/t5/main.rs index 6a446615..8ef108b6 100644 --- a/candle-examples/examples/t5/main.rs +++ b/candle-examples/examples/t5/main.rs @@ -96,25 +96,9 @@ impl T5ModelBuilder { let api = api.repo(repo); let config_filename = api.get("config.json")?; let tokenizer_filename = api.get("tokenizer.json")?; - let weights_filename = if model_id == "google/flan-t5-xxl" { - vec![ - api.get("model-00001-of-00005.safetensors")?, - api.get("model-00002-of-00005.safetensors")?, - api.get("model-00003-of-00005.safetensors")?, - api.get("model-00004-of-00005.safetensors")?, - api.get("model-00005-of-00005.safetensors")?, - ] - } else if model_id == "google/flan-ul2" { - vec![ - api.get("model-00001-of-00008.safetensors")?, - api.get("model-00002-of-00008.safetensors")?, - api.get("model-00003-of-00008.safetensors")?, - api.get("model-00004-of-00008.safetensors")?, - api.get("model-00005-of-00008.safetensors")?, - api.get("model-00006-of-00008.safetensors")?, - api.get("model-00007-of-00008.safetensors")?, - api.get("model-00008-of-00008.safetensors")?, - ] + let weights_filename = if model_id == "google/flan-t5-xxl" || model_id == "google/flan-ul2" + { + candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")? } else { vec![api.get("model.safetensors")?] }; |