summaryrefslogtreecommitdiff
path: root/candle-examples/examples/t5/main.rs
diff options
context:
space:
mode:
author: Laurent Mazare <laurent.mazare@gmail.com> 2023-09-13 08:37:04 +0200
committer: GitHub <noreply@github.com> 2023-09-13 07:37:04 +0100
commit: e4553fb355ffebe6781ea2d35ba0734a310cab9b (patch)
tree: dac344f070aaad989055f51817fecc21eb34e7fb /candle-examples/examples/t5/main.rs
parent: d801e1d564c5a6560680ff085e31dc4322627542 (diff)
downloadcandle-e4553fb355ffebe6781ea2d35ba0734a310cab9b.tar.gz
candle-e4553fb355ffebe6781ea2d35ba0734a310cab9b.tar.bz2
candle-e4553fb355ffebe6781ea2d35ba0734a310cab9b.zip
T5 tweaks (#831)
* Use default values rather than options. * Avoid exposing the device field. * More tweaks.
Diffstat (limited to 'candle-examples/examples/t5/main.rs')
-rw-r--r-- candle-examples/examples/t5/main.rs | 16
1 file changed, 2 insertions(+), 14 deletions(-)
diff --git a/candle-examples/examples/t5/main.rs b/candle-examples/examples/t5/main.rs
index bcba846d..84be0204 100644
--- a/candle-examples/examples/t5/main.rs
+++ b/candle-examples/examples/t5/main.rs
@@ -30,7 +30,7 @@ struct Args {
#[arg(long)]
tracing: bool,
- /// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
+ /// The model repository to use on the HuggingFace hub.
#[arg(long)]
model_id: Option<String>,
@@ -94,22 +94,10 @@ impl Args {
}
fn main() -> Result<()> {
- use tracing_chrome::ChromeLayerBuilder;
- use tracing_subscriber::prelude::*;
-
let args = Args::parse();
- let _guard = if args.tracing {
- println!("tracing...");
- let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
- tracing_subscriber::registry().with(chrome_layer).init();
- Some(guard)
- } else {
- None
- };
let start = std::time::Instant::now();
let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
- let device = &model.device;
let prompt = args.prompt.unwrap_or_else(|| DEFAULT_PROMPT.to_string());
let tokenizer = tokenizer
.with_padding(None)
@@ -120,7 +108,7 @@ fn main() -> Result<()> {
.map_err(E::msg)?
.get_ids()
.to_vec();
- let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
+ let token_ids = Tensor::new(&tokens[..], model.device())?.unsqueeze(0)?;
println!("Loaded and encoded {:?}", start.elapsed());
for idx in 0..args.n {
let start = std::time::Instant::now();