Diffstat (limited to 'candle-examples/examples')
 candle-examples/examples/bert/main.rs    | 5 ++++-
 candle-examples/examples/bigcode/main.rs | 5 +----
 candle-examples/examples/falcon/main.rs  | 6 ++----
 candle-examples/examples/llama/main.rs   | 4 ++--
 candle-examples/examples/whisper/main.rs | 5 +----
 5 files changed, 10 insertions(+), 15 deletions(-)
diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs
index 574755ed..7f0ae7b1 100644
--- a/candle-examples/examples/bert/main.rs
+++ b/candle-examples/examples/bert/main.rs
@@ -111,7 +111,10 @@ fn main() -> Result<()> {
     let device = &model.device;
 
     if let Some(prompt) = args.prompt {
-        let tokenizer = tokenizer.with_padding(None).with_truncation(None);
+        let tokenizer = tokenizer
+            .with_padding(None)
+            .with_truncation(None)
+            .map_err(E::msg)?;
         let tokens = tokenizer
             .encode(prompt, true)
             .map_err(E::msg)?
diff --git a/candle-examples/examples/bigcode/main.rs b/candle-examples/examples/bigcode/main.rs
index ac9c63c7..39b1de27 100644
--- a/candle-examples/examples/bigcode/main.rs
+++ b/candle-examples/examples/bigcode/main.rs
@@ -65,10 +65,7 @@ impl TextGeneration {
             let next_token = self.logits_processor.sample(&logits)?;
             tokens.push(next_token);
             new_tokens.push(next_token);
-            let token = self
-                .tokenizer
-                .decode(vec![next_token], true)
-                .map_err(E::msg)?;
+            let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
             print!("{token}");
             std::io::stdout().flush()?;
         }
diff --git a/candle-examples/examples/falcon/main.rs b/candle-examples/examples/falcon/main.rs
index c37d9a96..0df3a001 100644
--- a/candle-examples/examples/falcon/main.rs
+++ b/candle-examples/examples/falcon/main.rs
@@ -72,16 +72,14 @@ impl TextGeneration {
                 "{} token: {} '{}'",
                 index + 1,
                 next_token,
-                self.tokenizer
-                    .decode(vec![next_token], true)
-                    .map_err(E::msg)?
+                self.tokenizer.decode(&[next_token], true).map_err(E::msg)?
             );
         }
         let dt = start_gen.elapsed();
         println!(
             "{sample_len} tokens generated ({} token/s)\n----\n{}\n----",
             sample_len as f64 / dt.as_secs_f64(),
-            self.tokenizer.decode(new_tokens, true).map_err(E::msg)?
+            self.tokenizer.decode(&new_tokens, true).map_err(E::msg)?
         );
         Ok(())
     }
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index 9a62eba5..b1e112fd 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -223,7 +223,7 @@ fn main() -> Result<()> {
             "{} token: {} '{}'",
             index + 1,
             next_token,
-            tokenizer.decode(vec![next_token], true).map_err(E::msg)?
+            tokenizer.decode(&[next_token], true).map_err(E::msg)?
         );
     }
     let dt = start_gen.elapsed();
@@ -231,7 +231,7 @@ fn main() -> Result<()> {
         "{} tokens generated ({} token/s)\n----\n{}\n----",
         args.sample_len,
         args.sample_len as f64 / dt.as_secs_f64(),
-        tokenizer.decode(new_tokens, true).map_err(E::msg)?
+        tokenizer.decode(&new_tokens, true).map_err(E::msg)?
     );
     Ok(())
 }
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index 99919f8d..5c58c002 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -169,10 +169,7 @@ impl Decoder {
             }
             sum_logprob += prob.ln();
         }
-        let text = self
-            .tokenizer
-            .decode(tokens.clone(), true)
-            .map_err(E::msg)?;
+        let text = self.tokenizer.decode(&tokens, true).map_err(E::msg)?;
         let avg_logprob = sum_logprob / tokens.len() as f64;
 
         Ok(DecodingResult {
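
The changes above track the tokenizers crate API: with_truncation(None) now returns a Result that has to be handled, and decode borrows a &[u32] slice instead of consuming an owned Vec<u32>, which is why the vec![...] and .clone() calls disappear. As a rough illustration only (a hypothetical helper, not part of this diff, assuming a tokenizers release with those signatures), the combined pattern looks like:

    use anyhow::{Error as E, Result};
    use tokenizers::Tokenizer;

    // Hypothetical helper showing the call pattern the diff migrates to.
    fn decode_ids(mut tokenizer: Tokenizer, ids: &[u32]) -> Result<String> {
        // with_truncation is fallible, so its error is mapped as in the examples.
        let tokenizer = tokenizer
            .with_padding(None)
            .with_truncation(None)
            .map_err(E::msg)?;
        // decode borrows the ids, so no vec!/clone allocation is needed.
        tokenizer.decode(ids, true).map_err(E::msg)
    }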