path: root/candle-examples/examples/t5/main.rs
author    Juarez Bochi <juarez.bochi@grammarly.com> 2023-09-15 13:05:12 -0700
committer GitHub <noreply@github.com> 2023-09-15 22:05:12 +0200
commit    3e49f8fce52c6b8f361bfd37d541a99b5e1f8c63 (patch)
tree      fe8214d4cba3974bcf085bb8e4f758c74aa13136 /candle-examples/examples/t5/main.rs
parent    c2007ac88fb0dd6fa6f82f6624693a0095db2edb (diff)
Implement T5 decoding (#864)
* Load t5 decoder
* Run enc, dec, and lm head, but no cross attn
* Cross-attention over key_value_states
* New arg for decoder input ids
* Add mask, don't forward position biases through decoder
* Update t5 examples
* Clippy + rustfmt
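For context, the decoding path added here drives the new two-argument forward of t5::T5ForConditionalGeneration: the encoder input ids plus a growing list of decoder token ids, sampling one token per step. Below is a minimal sketch of that loop, distilled from the hunks further down; the greedy_decode wrapper and its signature are this sketch's own framing, while the candle calls mirror the diff.

    use anyhow::Result;
    use candle::{Device, Tensor};
    use candle_transformers::{generation::LogitsProcessor, models::t5};

    // Sketch of the decode loop introduced in this commit. The wrapper
    // function is hypothetical; `model`, `config`, and `device` would come
    // from the T5ModelBuilder added below.
    fn greedy_decode(
        model: &t5::T5ForConditionalGeneration,
        config: &t5::Config,
        device: &Device,
        input_token_ids: &Tensor, // encoder input ids, shape (1, seq_len)
    ) -> Result<Vec<u32>> {
        // T5 seeds the decoder with the pad token.
        let mut output_token_ids = vec![config.pad_token_id as u32];
        let mut logits_processor = LogitsProcessor::new(299792458, None, None);
        while output_token_ids.len() <= 512 {
            let decoder_token_ids =
                Tensor::new(&output_token_ids[..], device)?.unsqueeze(0)?;
            // New signature: encoder input ids plus decoder input ids.
            let logits = model.forward(input_token_ids, &decoder_token_ids)?;
            let next_token_id = logits_processor.sample(&logits.flatten_to(1)?)?;
            if next_token_id as usize == config.eos_token_id {
                break;
            }
            output_token_ids.push(next_token_id);
        }
        Ok(output_token_ids)
    }

With the new --decode flag, the example runs this generation path and streams detokenized text to stdout; without the flag, the encoder-only embedding path behaves as before.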
Diffstat (limited to 'candle-examples/examples/t5/main.rs')
-rw-r--r--  candle-examples/examples/t5/main.rs | 104
1 file changed, 85 insertions(+), 19 deletions(-)
diff --git a/candle-examples/examples/t5/main.rs b/candle-examples/examples/t5/main.rs
index 1e182974..00291609 100644
--- a/candle-examples/examples/t5/main.rs
+++ b/candle-examples/examples/t5/main.rs
@@ -3,18 +3,22 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
+use std::io::Write;
+use std::path::PathBuf;
+
use candle_transformers::models::t5;
use anyhow::{anyhow, Error as E, Result};
-use candle::{DType, Tensor};
+use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
+use candle_transformers::generation::LogitsProcessor;
use clap::Parser;
use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
use tokenizers::Tokenizer;
const DTYPE: DType = DType::F32;
-#[derive(Parser, Debug)]
+#[derive(Parser, Debug, Clone)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
@@ -36,7 +40,11 @@ struct Args {
#[arg(long)]
revision: Option<String>,
- /// Compute embeddings for this prompt, otherwise compute sentence similarities.
+ /// Enable decoding.
+ #[arg(long)]
+ decode: bool,
+
+ /// Use this prompt, otherwise compute sentence similarities.
#[arg(long)]
prompt: Option<String>,
@@ -49,12 +57,18 @@ struct Args {
normalize_embeddings: bool,
}
-impl Args {
- fn build_model_and_tokenizer(&self) -> Result<(t5::T5EncoderModel, Tokenizer)> {
- let device = candle_examples::device(self.cpu)?;
+struct T5ModelBuilder {
+ device: Device,
+ config: t5::Config,
+ weights_filename: PathBuf,
+}
+
+impl T5ModelBuilder {
+ pub fn load(args: &Args) -> Result<(Self, Tokenizer)> {
+ let device = candle_examples::device(args.cpu)?;
let default_model = "t5-small".to_string();
let default_revision = "refs/pr/15".to_string();
- let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
+ let (model_id, revision) = match (args.model_id.to_owned(), args.revision.to_owned()) {
(Some(model_id), Some(revision)) => (model_id, revision),
(Some(model_id), None) => (model_id, "main".to_string()),
(None, Some(revision)) => (default_model, revision),
@@ -62,7 +76,7 @@ impl Args {
};
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
- let (config_filename, tokenizer_filename, weights_filename) = if self.offline {
+ let (config_filename, tokenizer_filename, weights_filename) = if args.offline {
let cache = Cache::default().repo(repo);
(
cache
@@ -87,18 +101,36 @@ impl Args {
let config = std::fs::read_to_string(config_filename)?;
let config: t5::Config = serde_json::from_str(&config)?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+ Ok((
+ Self {
+ device,
+ config,
+ weights_filename,
+ },
+ tokenizer,
+ ))
+ }
- let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
+ pub fn build_encoder(&self) -> Result<t5::T5EncoderModel> {
+ let weights =
+ unsafe { candle::safetensors::MmapedFile::new(self.weights_filename.clone())? };
let weights = weights.deserialize()?;
- let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
- let model = t5::T5EncoderModel::load(vb, &config)?;
- Ok((model, tokenizer))
+ let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &self.device);
+ Ok(t5::T5EncoderModel::load(vb, &self.config)?)
+ }
+
+ pub fn build_conditional_generation(&self) -> Result<t5::T5ForConditionalGeneration> {
+ let weights =
+ unsafe { candle::safetensors::MmapedFile::new(self.weights_filename.clone())? };
+ let weights = weights.deserialize()?;
+ let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &self.device);
+ Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
}
}
fn main() -> Result<()> {
let args = Args::parse();
- let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
+ let (builder, mut tokenizer) = T5ModelBuilder::load(&args)?;
let tokenizer = tokenizer
.with_padding(None)
.with_truncation(None)
@@ -110,17 +142,51 @@ fn main() -> Result<()> {
.map_err(E::msg)?
.get_ids()
.to_vec();
- let token_ids = Tensor::new(&tokens[..], model.device())?.unsqueeze(0)?;
- for idx in 0..args.n {
+ let input_token_ids = Tensor::new(&tokens[..], &builder.device)?.unsqueeze(0)?;
+ if !args.decode {
+ let model = builder.build_encoder()?;
+ for idx in 0..args.n {
+ let start = std::time::Instant::now();
+ let ys = model.forward(&input_token_ids)?;
+ if idx == 0 {
+ println!("{ys}");
+ }
+ println!("Took {:?}", start.elapsed());
+ }
+ } else {
+ let model = builder.build_conditional_generation()?;
+ let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
+ let mut logits_processor = LogitsProcessor::new(299792458, None, None);
let start = std::time::Instant::now();
- let ys = model.forward(&token_ids)?;
- if idx == 0 {
- println!("{ys}");
+
+ for _index in 0.. {
+ if output_token_ids.len() > 512 {
+ break;
+ }
+ let decoder_token_ids =
+ Tensor::new(&output_token_ids[..], &builder.device)?.unsqueeze(0)?;
+ let logits = model.forward(&input_token_ids, &decoder_token_ids)?;
+ let next_token_id = logits_processor.sample(&logits.flatten_to(1)?)?;
+ if (next_token_id as usize) == builder.config.eos_token_id {
+ break;
+ }
+ output_token_ids.push(next_token_id);
+ if let Some(text) = tokenizer.id_to_token(next_token_id) {
+ let text = text.replace('▁', " ").replace("<0x0A>", "\n");
+ print!("{text}");
+ std::io::stdout().flush()?;
+ }
}
- println!("Took {:?}", start.elapsed());
+ let dt = start.elapsed();
+ println!(
+ "\n{} tokens generated ({:.2} token/s)\n",
+ tokens.len(),
+ tokens.len() as f64 / dt.as_secs_f64(),
+ );
}
}
None => {
+ let model = builder.build_encoder()?;
let sentences = [
"The cat sits outside",
"A man is playing guitar",