summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/examples/llama/main.rs17
1 files changed, 6 insertions, 11 deletions
diff --git a/candle-core/examples/llama/main.rs b/candle-core/examples/llama/main.rs
index 3a025683..2f9daec0 100644
--- a/candle-core/examples/llama/main.rs
+++ b/candle-core/examples/llama/main.rs
@@ -422,9 +422,6 @@ async fn main() -> Result<()> {
} else {
Device::new_cuda(0)?
};
- let api = Api::new()?;
- let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model);
- println!("building the model");
let config = Config::config_7b();
let cache = Cache::new(&device);
let start = std::time::Instant::now();
@@ -435,6 +432,9 @@ async fn main() -> Result<()> {
std::path::Path::new("llama-tokenizer.json").to_path_buf(),
)
} else {
+ let api = Api::new()?;
+ let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model);
+ println!("building the model");
let tokenizer_filename = api.get(&repo, "tokenizer.json").await?;
let mut filenames = vec![];
for rfilename in [
@@ -483,14 +483,9 @@ async fn main() -> Result<()> {
logits_v
.iter()
.enumerate()
- .fold((0, logits_v[0]), |(idx_max, val_max), (idx, val)| {
- if &val_max > val {
- (idx_max, val_max)
- } else {
- (idx, *val)
- }
- })
- .0 as u32
+ .max_by(|(_, u), (_, v)| u.total_cmp(v))
+ .map(|(i, _)| i as u32)
+ .unwrap()
};
tokens.push(next_token);
new_tokens.push(next_token);