summaryrefslogtreecommitdiff
path: root/candle-examples/examples/t5/main.rs
diff options
context:
space:
mode:
authorJuarez Bochi <juarez.bochi@grammarly.com>2023-11-09 12:55:09 -0500
committerGitHub <noreply@github.com>2023-11-09 18:55:09 +0100
commit18d30005c577c029dec611a0bdd0260946468cdb (patch)
treea465377a7671832df8dad41f64f10074d9656b44 /candle-examples/examples/t5/main.rs
parent695838432747b9c9460e74cd3f5086642b6897a9 (diff)
downloadcandle-18d30005c577c029dec611a0bdd0260946468cdb.tar.gz
candle-18d30005c577c029dec611a0bdd0260946468cdb.tar.bz2
candle-18d30005c577c029dec611a0bdd0260946468cdb.zip
Add support to UL2 model family (#1300)
* Add support to UL2 model family * Update docs with UL2 * Create ActivationWithOptionalGating to avoid polluting activations * Also refactor quantized t5 * Remove useless conversion * Revert Activation::NewGelu name change * Remove useless return * Apply rustfmt and clippy recommendations * Reuse t5::ActivationWithOptionalGating in quantized version * (cosmetic change) use a match rather than ifs + avoid early returns. --------- Co-authored-by: Laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-examples/examples/t5/main.rs')
-rw-r--r--candle-examples/examples/t5/main.rs11
1 files changed, 11 insertions, 0 deletions
diff --git a/candle-examples/examples/t5/main.rs b/candle-examples/examples/t5/main.rs
index f1c5a94b..6a446615 100644
--- a/candle-examples/examples/t5/main.rs
+++ b/candle-examples/examples/t5/main.rs
@@ -104,6 +104,17 @@ impl T5ModelBuilder {
api.get("model-00004-of-00005.safetensors")?,
api.get("model-00005-of-00005.safetensors")?,
]
+ } else if model_id == "google/flan-ul2" {
+ vec![
+ api.get("model-00001-of-00008.safetensors")?,
+ api.get("model-00002-of-00008.safetensors")?,
+ api.get("model-00003-of-00008.safetensors")?,
+ api.get("model-00004-of-00008.safetensors")?,
+ api.get("model-00005-of-00008.safetensors")?,
+ api.get("model-00006-of-00008.safetensors")?,
+ api.get("model-00007-of-00008.safetensors")?,
+ api.get("model-00008-of-00008.safetensors")?,
+ ]
} else {
vec![api.get("model.safetensors")?]
};