diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-13 08:37:04 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-13 07:37:04 +0100 |
commit | e4553fb355ffebe6781ea2d35ba0734a310cab9b (patch) | |
tree | dac344f070aaad989055f51817fecc21eb34e7fb /candle-examples/examples/t5/main.rs | |
parent | d801e1d564c5a6560680ff085e31dc4322627542 (diff) | |
download | candle-e4553fb355ffebe6781ea2d35ba0734a310cab9b.tar.gz candle-e4553fb355ffebe6781ea2d35ba0734a310cab9b.tar.bz2 candle-e4553fb355ffebe6781ea2d35ba0734a310cab9b.zip |
T5 tweaks (#831)
* Use default values rather than options.
* Avoid exposing the device field.
* More tweaks.
Diffstat (limited to 'candle-examples/examples/t5/main.rs')
-rw-r--r-- | candle-examples/examples/t5/main.rs | 16 |
1 file changed, 2 insertions(+), 14 deletions(-)
diff --git a/candle-examples/examples/t5/main.rs b/candle-examples/examples/t5/main.rs index bcba846d..84be0204 100644 --- a/candle-examples/examples/t5/main.rs +++ b/candle-examples/examples/t5/main.rs @@ -30,7 +30,7 @@ struct Args { #[arg(long)] tracing: bool, - /// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending + /// The model repository to use on the HuggingFace hub. #[arg(long)] model_id: Option<String>, @@ -94,22 +94,10 @@ impl Args { } fn main() -> Result<()> { - use tracing_chrome::ChromeLayerBuilder; - use tracing_subscriber::prelude::*; - let args = Args::parse(); - let _guard = if args.tracing { - println!("tracing..."); - let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); - tracing_subscriber::registry().with(chrome_layer).init(); - Some(guard) - } else { - None - }; let start = std::time::Instant::now(); let (model, mut tokenizer) = args.build_model_and_tokenizer()?; - let device = &model.device; let prompt = args.prompt.unwrap_or_else(|| DEFAULT_PROMPT.to_string()); let tokenizer = tokenizer .with_padding(None) @@ -120,7 +108,7 @@ fn main() -> Result<()> { .map_err(E::msg)? .get_ids() .to_vec(); - let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?; + let token_ids = Tensor::new(&tokens[..], model.device())?.unsqueeze(0)?; println!("Loaded and encoded {:?}", start.elapsed()); for idx in 0..args.n { let start = std::time::Instant::now(); |