summary | refs | log | tree | commit | diff
path: root/candle-examples/examples/llama
diff options
context:
space:
mode:
author: optman <optman@gmail.com> 2024-01-06 18:43:01 +0800
committer: GitHub <noreply@github.com> 2024-01-06 11:43:01 +0100
commit: 84250bf52f58528cf59dca3b82effd9f07a13cc7 (patch)
tree: d8cf70e80954abc8e8f1aa58d0203960e81cba85 /candle-examples/examples/llama
parent: 8d1a57c9a0465b201e4e9e410e2b8fcde37b35f7 (diff)
download: candle-84250bf52f58528cf59dca3b82effd9f07a13cc7.tar.gz
candle-84250bf52f58528cf59dca3b82effd9f07a13cc7.tar.bz2
candle-84250bf52f58528cf59dca3b82effd9f07a13cc7.zip
fix index_pos bug when kv cache is disabled. (#1517)
* fix index_pos bug when kv cache is disabled
* Tweak the fix.
---------
Co-authored-by: laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-examples/examples/llama')
-rw-r--r-- candle-examples/examples/llama/main.rs | 8
1 files changed, 4 insertions, 4 deletions
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index 46f474bb..251c184b 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -165,14 +165,14 @@ fn main() -> Result<()> {
let mut index_pos = 0;
let mut token_generated = 0;
for index in 0..args.sample_len {
- let context_size = if cache.use_kv_cache && index > 0 {
- 1
+ let (context_size, context_index) = if cache.use_kv_cache && index > 0 {
+ (1, index_pos)
} else {
- tokens.len()
+ (tokens.len(), 0)
};
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
- let logits = llama.forward(&input, index_pos)?;
+ let logits = llama.forward(&input, context_index)?;
let logits = logits.squeeze(0)?;
let logits = if args.repeat_penalty == 1. {
logits