diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-17 09:00:45 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-17 08:00:45 +0100 |
commit | 1a276b5da79a4bb2305dde7368b800d165599819 (patch) | |
tree | 5270c7d9c0b6e345cfd65c3d74690c3488d65aa5 /candle-examples/examples/t5 | |
parent | 8658df348527cabcd722bfe2e9e48aba3c7f8e96 (diff) | |
download | candle-1a276b5da79a4bb2305dde7368b800d165599819.tar.gz candle-1a276b5da79a4bb2305dde7368b800d165599819.tar.bz2 candle-1a276b5da79a4bb2305dde7368b800d165599819.zip |
Add a KV cache to T5. (#873)
* Add a KV cache to T5.
* Suggest using release mode.
* Use the kv cache in decoding.
* Add a comment.
Diffstat (limited to 'candle-examples/examples/t5')
-rw-r--r-- | candle-examples/examples/t5/README.md | 4 | ||||
-rw-r--r-- | candle-examples/examples/t5/main.rs | 37 |
2 files changed, 19 insertions, 22 deletions
diff --git a/candle-examples/examples/t5/README.md b/candle-examples/examples/t5/README.md index c6ea2125..6a406467 100644 --- a/candle-examples/examples/t5/README.md +++ b/candle-examples/examples/t5/README.md @@ -3,7 +3,7 @@ ## Encoder-decoder example: ```bash -$ cargo run --example t5 -- --model-id "t5-small" --prompt "translate to German: A beautiful candle." --decode +$ cargo run --example t5 --release -- --model-id "t5-small" --prompt "translate to German: A beautiful candle." --decode ... Running on CPU, to run on GPU, build this example with `--features cuda` Eine schöne Kerze. @@ -13,7 +13,7 @@ Running on CPU, to run on GPU, build this example with `--features cuda` ## Sentence embedding example: ```bash -$ cargo run --example t5 -- --model-id "t5-small" --prompt "A beautiful candle." +$ cargo run --example t5 --release -- --model-id "t5-small" --prompt "A beautiful candle." ... [[[ 0.0515, -0.0541, -0.0761, ..., -0.0392, 0.1511, -0.0265], [-0.0974, 0.0998, -0.1659, ..., -0.2450, 0.1738, -0.0164], diff --git a/candle-examples/examples/t5/main.rs b/candle-examples/examples/t5/main.rs index 00291609..c432e004 100644 --- a/candle-examples/examples/t5/main.rs +++ b/candle-examples/examples/t5/main.rs @@ -48,10 +48,6 @@ struct Args { #[arg(long)] prompt: Option<String>, - /// The number of times to run the prompt. - #[arg(long, default_value = "1")] - n: usize, - /// L2 normalization for embeddings. #[arg(long, default_value = "true")] normalize_embeddings: bool, @@ -131,6 +127,7 @@ impl T5ModelBuilder { fn main() -> Result<()> { let args = Args::parse(); let (builder, mut tokenizer) = T5ModelBuilder::load(&args)?; + let device = &builder.device; let tokenizer = tokenizer .with_padding(None) .with_truncation(None) @@ -142,32 +139,32 @@ fn main() -> Result<()> { .map_err(E::msg)? .get_ids() .to_vec(); - let input_token_ids = Tensor::new(&tokens[..], &builder.device)?.unsqueeze(0)?; + let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?; if !args.decode { - let model = builder.build_encoder()?; - for idx in 0..args.n { - let start = std::time::Instant::now(); - let ys = model.forward(&input_token_ids)?; - if idx == 0 { - println!("{ys}"); - } - println!("Took {:?}", start.elapsed()); - } + let mut model = builder.build_encoder()?; + let start = std::time::Instant::now(); + let ys = model.forward(&input_token_ids)?; + println!("{ys}"); + println!("Took {:?}", start.elapsed()); } else { - let model = builder.build_conditional_generation()?; + let mut model = builder.build_conditional_generation()?; let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec(); let mut logits_processor = LogitsProcessor::new(299792458, None, None); let start = std::time::Instant::now(); - for _index in 0.. { + for index in 0.. { if output_token_ids.len() > 512 { break; } - let decoder_token_ids = - Tensor::new(&output_token_ids[..], &builder.device)?.unsqueeze(0)?; + let decoder_token_ids = if index == 0 || !builder.config.use_cache { + Tensor::new(output_token_ids.as_slice(), device)?.unsqueeze(0)? + } else { + let last_token = *output_token_ids.last().unwrap(); + Tensor::new(&[last_token], device)?.unsqueeze(0)? + }; let logits = model.forward(&input_token_ids, &decoder_token_ids)?; let next_token_id = logits_processor.sample(&logits.flatten_to(1)?)?; - if (next_token_id as usize) == builder.config.eos_token_id { + if next_token_id as usize == builder.config.eos_token_id { break; } output_token_ids.push(next_token_id); @@ -186,7 +183,7 @@ fn main() -> Result<()> { } } None => { - let model = builder.build_encoder()?; + let mut model = builder.build_encoder()?; let sentences = [ "The cat sits outside", "A man is playing guitar", |