diff options
Diffstat (limited to 'candle-examples/examples')
-rw-r--r-- | candle-examples/examples/t5/README.md | 34 | ||||
-rw-r--r-- | candle-examples/examples/t5/main.rs | 104 |
2 files changed, 106 insertions, 32 deletions
diff --git a/candle-examples/examples/t5/README.md b/candle-examples/examples/t5/README.md index 66952395..c6ea2125 100644 --- a/candle-examples/examples/t5/README.md +++ b/candle-examples/examples/t5/README.md @@ -1,17 +1,25 @@ # candle-t5 -Generates embeddings using a T5 model. It doesn't support generation yet. +## Encoder-decoder example: ```bash -$ cargo run --example t5 -- --model-id t5-large --prompt 'how tall is obama' --n 1 -Loaded and encoded 2.014244792s -[[[-0.3174, -0.1462, 0.0065, ..., -0.0579, -0.0581, 0.1387], - [-0.2905, -0.1945, -0.0685, ..., -0.2457, -0.5137, -0.1760], - [-0.0591, -0.0213, -0.0241, ..., -0.0210, 0.0491, -0.0300], - ... - [-0.4333, 0.0027, -0.0609, ..., 0.3069, -0.2252, 0.3306], - [-0.1458, 0.1323, -0.0138, ..., 0.3000, -0.4550, -0.0384], - [ 0.0397, 0.0485, -0.2373, ..., 0.2578, -0.2650, -0.4356]]] -Tensor[[1, 9, 1024], f32] -Took 2.1363425s -```
\ No newline at end of file +$ cargo run --example t5 -- --model-id "t5-small" --prompt "translate to German: A beautiful candle." --decode +... +Running on CPU, to run on GPU, build this example with `--features cuda` + Eine schöne Kerze. +9 tokens generated (2.42 token/s) +``` + +## Sentence embedding example: + +```bash +$ cargo run --example t5 -- --model-id "t5-small" --prompt "A beautiful candle." +... +[[[ 0.0515, -0.0541, -0.0761, ..., -0.0392, 0.1511, -0.0265], + [-0.0974, 0.0998, -0.1659, ..., -0.2450, 0.1738, -0.0164], + [ 0.0624, -0.1024, 0.0430, ..., -0.1388, 0.0564, -0.2962], + [-0.0389, -0.1173, 0.0026, ..., 0.1064, -0.1065, 0.0990], + [ 0.1300, 0.0027, -0.0326, ..., 0.0026, -0.0317, 0.0851]]] +Tensor[[1, 5, 512], f32] +Took 303.766583ms +``` diff --git a/candle-examples/examples/t5/main.rs b/candle-examples/examples/t5/main.rs index 1e182974..00291609 100644 --- a/candle-examples/examples/t5/main.rs +++ b/candle-examples/examples/t5/main.rs @@ -3,18 +3,22 @@ extern crate intel_mkl_src; #[cfg(feature = "accelerate")] extern crate accelerate_src; +use std::io::Write; +use std::path::PathBuf; + use candle_transformers::models::t5; use anyhow::{anyhow, Error as E, Result}; -use candle::{DType, Tensor}; +use candle::{DType, Device, Tensor}; use candle_nn::VarBuilder; +use candle_transformers::generation::LogitsProcessor; use clap::Parser; use hf_hub::{api::sync::Api, Cache, Repo, RepoType}; use tokenizers::Tokenizer; const DTYPE: DType = DType::F32; -#[derive(Parser, Debug)] +#[derive(Parser, Debug, Clone)] #[command(author, version, about, long_about = None)] struct Args { /// Run on CPU rather than on GPU. @@ -36,7 +40,11 @@ struct Args { #[arg(long)] revision: Option<String>, - /// Compute embeddings for this prompt, otherwise compute sentence similarities. + /// Enable decoding. + #[arg(long)] + decode: bool, + + /// Use this prompt, otherwise compute sentence similarities. #[arg(long)] prompt: Option<String>, @@ -49,12 +57,18 @@ struct Args { normalize_embeddings: bool, } -impl Args { - fn build_model_and_tokenizer(&self) -> Result<(t5::T5EncoderModel, Tokenizer)> { - let device = candle_examples::device(self.cpu)?; +struct T5ModelBuilder { + device: Device, + config: t5::Config, + weights_filename: PathBuf, +} + +impl T5ModelBuilder { + pub fn load(args: &Args) -> Result<(Self, Tokenizer)> { + let device = candle_examples::device(args.cpu)?; let default_model = "t5-small".to_string(); let default_revision = "refs/pr/15".to_string(); - let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) { + let (model_id, revision) = match (args.model_id.to_owned(), args.revision.to_owned()) { (Some(model_id), Some(revision)) => (model_id, revision), (Some(model_id), None) => (model_id, "main".to_string()), (None, Some(revision)) => (default_model, revision), @@ -62,7 +76,7 @@ impl Args { }; let repo = Repo::with_revision(model_id, RepoType::Model, revision); - let (config_filename, tokenizer_filename, weights_filename) = if self.offline { + let (config_filename, tokenizer_filename, weights_filename) = if args.offline { let cache = Cache::default().repo(repo); ( cache @@ -87,18 +101,36 @@ impl Args { let config = std::fs::read_to_string(config_filename)?; let config: t5::Config = serde_json::from_str(&config)?; let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; + Ok(( + Self { + device, + config, + weights_filename, + }, + tokenizer, + )) + } - let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? }; + pub fn build_encoder(&self) -> Result<t5::T5EncoderModel> { + let weights = + unsafe { candle::safetensors::MmapedFile::new(self.weights_filename.clone())? }; let weights = weights.deserialize()?; - let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device); - let model = t5::T5EncoderModel::load(vb, &config)?; - Ok((model, tokenizer)) + let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &self.device); + Ok(t5::T5EncoderModel::load(vb, &self.config)?) + } + + pub fn build_conditional_generation(&self) -> Result<t5::T5ForConditionalGeneration> { + let weights = + unsafe { candle::safetensors::MmapedFile::new(self.weights_filename.clone())? }; + let weights = weights.deserialize()?; + let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &self.device); + Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?) } } fn main() -> Result<()> { let args = Args::parse(); - let (model, mut tokenizer) = args.build_model_and_tokenizer()?; + let (builder, mut tokenizer) = T5ModelBuilder::load(&args)?; let tokenizer = tokenizer .with_padding(None) .with_truncation(None) @@ -110,17 +142,51 @@ fn main() -> Result<()> { .map_err(E::msg)? .get_ids() .to_vec(); - let token_ids = Tensor::new(&tokens[..], model.device())?.unsqueeze(0)?; - for idx in 0..args.n { + let input_token_ids = Tensor::new(&tokens[..], &builder.device)?.unsqueeze(0)?; + if !args.decode { + let model = builder.build_encoder()?; + for idx in 0..args.n { + let start = std::time::Instant::now(); + let ys = model.forward(&input_token_ids)?; + if idx == 0 { + println!("{ys}"); + } + println!("Took {:?}", start.elapsed()); + } + } else { + let model = builder.build_conditional_generation()?; + let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec(); + let mut logits_processor = LogitsProcessor::new(299792458, None, None); let start = std::time::Instant::now(); - let ys = model.forward(&token_ids)?; - if idx == 0 { - println!("{ys}"); + + for _index in 0.. { + if output_token_ids.len() > 512 { + break; + } + let decoder_token_ids = + Tensor::new(&output_token_ids[..], &builder.device)?.unsqueeze(0)?; + let logits = model.forward(&input_token_ids, &decoder_token_ids)?; + let next_token_id = logits_processor.sample(&logits.flatten_to(1)?)?; + if (next_token_id as usize) == builder.config.eos_token_id { + break; + } + output_token_ids.push(next_token_id); + if let Some(text) = tokenizer.id_to_token(next_token_id) { + let text = text.replace('▁', " ").replace("<0x0A>", "\n"); + print!("{text}"); + std::io::stdout().flush()?; + } } - println!("Took {:?}", start.elapsed()); + let dt = start.elapsed(); + println!( + "\n{} tokens generated ({:.2} token/s)\n", + tokens.len(), + tokens.len() as f64 / dt.as_secs_f64(), + ); } } None => { + let model = builder.build_encoder()?; let sentences = [ "The cat sits outside", "A man is playing guitar", |