summaryrefslogtreecommitdiff
path: root/candle-examples/examples
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples')
-rw-r--r--candle-examples/examples/bert/main.rs5
-rw-r--r--candle-examples/examples/bigcode/main.rs5
-rw-r--r--candle-examples/examples/falcon/main.rs6
-rw-r--r--candle-examples/examples/llama/main.rs4
-rw-r--r--candle-examples/examples/whisper/main.rs5
5 files changed, 10 insertions, 15 deletions
diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs
index 574755ed..7f0ae7b1 100644
--- a/candle-examples/examples/bert/main.rs
+++ b/candle-examples/examples/bert/main.rs
@@ -111,7 +111,10 @@ fn main() -> Result<()> {
let device = &model.device;
if let Some(prompt) = args.prompt {
- let tokenizer = tokenizer.with_padding(None).with_truncation(None);
+ let tokenizer = tokenizer
+ .with_padding(None)
+ .with_truncation(None)
+ .map_err(E::msg)?;
let tokens = tokenizer
.encode(prompt, true)
.map_err(E::msg)?
diff --git a/candle-examples/examples/bigcode/main.rs b/candle-examples/examples/bigcode/main.rs
index ac9c63c7..39b1de27 100644
--- a/candle-examples/examples/bigcode/main.rs
+++ b/candle-examples/examples/bigcode/main.rs
@@ -65,10 +65,7 @@ impl TextGeneration {
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
new_tokens.push(next_token);
- let token = self
- .tokenizer
- .decode(vec![next_token], true)
- .map_err(E::msg)?;
+ let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
print!("{token}");
std::io::stdout().flush()?;
}
diff --git a/candle-examples/examples/falcon/main.rs b/candle-examples/examples/falcon/main.rs
index c37d9a96..0df3a001 100644
--- a/candle-examples/examples/falcon/main.rs
+++ b/candle-examples/examples/falcon/main.rs
@@ -72,16 +72,14 @@ impl TextGeneration {
"{} token: {} '{}'",
index + 1,
next_token,
- self.tokenizer
- .decode(vec![next_token], true)
- .map_err(E::msg)?
+ self.tokenizer.decode(&[next_token], true).map_err(E::msg)?
);
}
let dt = start_gen.elapsed();
println!(
"{sample_len} tokens generated ({} token/s)\n----\n{}\n----",
sample_len as f64 / dt.as_secs_f64(),
- self.tokenizer.decode(new_tokens, true).map_err(E::msg)?
+ self.tokenizer.decode(&new_tokens, true).map_err(E::msg)?
);
Ok(())
}
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index 9a62eba5..b1e112fd 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -223,7 +223,7 @@ fn main() -> Result<()> {
"{} token: {} '{}'",
index + 1,
next_token,
- tokenizer.decode(vec![next_token], true).map_err(E::msg)?
+ tokenizer.decode(&[next_token], true).map_err(E::msg)?
);
}
let dt = start_gen.elapsed();
@@ -231,7 +231,7 @@ fn main() -> Result<()> {
"{} tokens generated ({} token/s)\n----\n{}\n----",
args.sample_len,
args.sample_len as f64 / dt.as_secs_f64(),
- tokenizer.decode(new_tokens, true).map_err(E::msg)?
+ tokenizer.decode(&new_tokens, true).map_err(E::msg)?
);
Ok(())
}
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index 99919f8d..5c58c002 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -169,10 +169,7 @@ impl Decoder {
}
sum_logprob += prob.ln();
}
- let text = self
- .tokenizer
- .decode(tokens.clone(), true)
- .map_err(E::msg)?;
+ let text = self.tokenizer.decode(&tokens, true).map_err(E::msg)?;
let avg_logprob = sum_logprob / tokens.len() as f64;
Ok(DecodingResult {