summaryrefslogtreecommitdiff
path: root/candle-examples/examples/musicgen
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-09-17 09:00:45 +0200
committerGitHub <noreply@github.com>2023-09-17 08:00:45 +0100
commit1a276b5da79a4bb2305dde7368b800d165599819 (patch)
tree5270c7d9c0b6e345cfd65c3d74690c3488d65aa5 /candle-examples/examples/musicgen
parent8658df348527cabcd722bfe2e9e48aba3c7f8e96 (diff)
downloadcandle-1a276b5da79a4bb2305dde7368b800d165599819.tar.gz
candle-1a276b5da79a4bb2305dde7368b800d165599819.tar.bz2
candle-1a276b5da79a4bb2305dde7368b800d165599819.zip
Add a KV cache to T5. (#873)
* Add a KV cache to T5. * Suggest using release mode. * Use the kv cache in decoding. * Add a comment.
Diffstat (limited to 'candle-examples/examples/musicgen')
-rw-r--r--candle-examples/examples/musicgen/main.rs2
1 files changed, 1 insertions, 1 deletions
diff --git a/candle-examples/examples/musicgen/main.rs b/candle-examples/examples/musicgen/main.rs
index df8c3135..0fae67b5 100644
--- a/candle-examples/examples/musicgen/main.rs
+++ b/candle-examples/examples/musicgen/main.rs
@@ -77,7 +77,7 @@ fn main() -> Result<()> {
let model = model.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![model], DTYPE, &device);
let config = GenConfig::small();
- let model = MusicgenForConditionalGeneration::load(vb, config)?;
+ let mut model = MusicgenForConditionalGeneration::load(vb, config)?;
let tokens = tokenizer
.encode(args.prompt.as_str(), true)