summary refs log tree commit diff
path: root/candle-examples/examples/t5
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-09-17 09:00:45 +0200
committerGitHub <noreply@github.com>2023-09-17 08:00:45 +0100
commit1a276b5da79a4bb2305dde7368b800d165599819 (patch)
tree5270c7d9c0b6e345cfd65c3d74690c3488d65aa5 /candle-examples/examples/t5
parent8658df348527cabcd722bfe2e9e48aba3c7f8e96 (diff)
downloadcandle-1a276b5da79a4bb2305dde7368b800d165599819.tar.gz
candle-1a276b5da79a4bb2305dde7368b800d165599819.tar.bz2
candle-1a276b5da79a4bb2305dde7368b800d165599819.zip
Add a KV cache to T5. (#873)
* Add a KV cache to T5. * Suggest using release mode. * Use the kv cache in decoding. * Add a comment.
Diffstat (limited to 'candle-examples/examples/t5')
-rw-r--r--candle-examples/examples/t5/README.md4
-rw-r--r--candle-examples/examples/t5/main.rs37
2 files changed, 19 insertions, 22 deletions
diff --git a/candle-examples/examples/t5/README.md b/candle-examples/examples/t5/README.md
index c6ea2125..6a406467 100644
--- a/candle-examples/examples/t5/README.md
+++ b/candle-examples/examples/t5/README.md
@@ -3,7 +3,7 @@
## Encoder-decoder example:
```bash
-$ cargo run --example t5 -- --model-id "t5-small" --prompt "translate to German: A beautiful candle." --decode
+$ cargo run --example t5 --release -- --model-id "t5-small" --prompt "translate to German: A beautiful candle." --decode
...
Running on CPU, to run on GPU, build this example with `--features cuda`
Eine schöne Kerze.
@@ -13,7 +13,7 @@ Running on CPU, to run on GPU, build this example with `--features cuda`
## Sentence embedding example:
```bash
-$ cargo run --example t5 -- --model-id "t5-small" --prompt "A beautiful candle."
+$ cargo run --example t5 --release -- --model-id "t5-small" --prompt "A beautiful candle."
...
[[[ 0.0515, -0.0541, -0.0761, ..., -0.0392, 0.1511, -0.0265],
[-0.0974, 0.0998, -0.1659, ..., -0.2450, 0.1738, -0.0164],
diff --git a/candle-examples/examples/t5/main.rs b/candle-examples/examples/t5/main.rs
index 00291609..c432e004 100644
--- a/candle-examples/examples/t5/main.rs
+++ b/candle-examples/examples/t5/main.rs
@@ -48,10 +48,6 @@ struct Args {
#[arg(long)]
prompt: Option<String>,
- /// The number of times to run the prompt.
- #[arg(long, default_value = "1")]
- n: usize,
-
/// L2 normalization for embeddings.
#[arg(long, default_value = "true")]
normalize_embeddings: bool,
@@ -131,6 +127,7 @@ impl T5ModelBuilder {
fn main() -> Result<()> {
let args = Args::parse();
let (builder, mut tokenizer) = T5ModelBuilder::load(&args)?;
+ let device = &builder.device;
let tokenizer = tokenizer
.with_padding(None)
.with_truncation(None)
@@ -142,32 +139,32 @@ fn main() -> Result<()> {
.map_err(E::msg)?
.get_ids()
.to_vec();
- let input_token_ids = Tensor::new(&tokens[..], &builder.device)?.unsqueeze(0)?;
+ let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
if !args.decode {
- let model = builder.build_encoder()?;
- for idx in 0..args.n {
- let start = std::time::Instant::now();
- let ys = model.forward(&input_token_ids)?;
- if idx == 0 {
- println!("{ys}");
- }
- println!("Took {:?}", start.elapsed());
- }
+ let mut model = builder.build_encoder()?;
+ let start = std::time::Instant::now();
+ let ys = model.forward(&input_token_ids)?;
+ println!("{ys}");
+ println!("Took {:?}", start.elapsed());
} else {
- let model = builder.build_conditional_generation()?;
+ let mut model = builder.build_conditional_generation()?;
let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
let mut logits_processor = LogitsProcessor::new(299792458, None, None);
let start = std::time::Instant::now();
- for _index in 0.. {
+ for index in 0.. {
if output_token_ids.len() > 512 {
break;
}
- let decoder_token_ids =
- Tensor::new(&output_token_ids[..], &builder.device)?.unsqueeze(0)?;
+ let decoder_token_ids = if index == 0 || !builder.config.use_cache {
+ Tensor::new(output_token_ids.as_slice(), device)?.unsqueeze(0)?
+ } else {
+ let last_token = *output_token_ids.last().unwrap();
+ Tensor::new(&[last_token], device)?.unsqueeze(0)?
+ };
let logits = model.forward(&input_token_ids, &decoder_token_ids)?;
let next_token_id = logits_processor.sample(&logits.flatten_to(1)?)?;
- if (next_token_id as usize) == builder.config.eos_token_id {
+ if next_token_id as usize == builder.config.eos_token_id {
break;
}
output_token_ids.push(next_token_id);
@@ -186,7 +183,7 @@ fn main() -> Result<()> {
}
}
None => {
- let model = builder.build_encoder()?;
+ let mut model = builder.build_encoder()?;
let sentences = [
"The cat sits outside",
"A man is playing guitar",