summaryrefslogtreecommitdiff
path: root/candle-core/examples/llama/main.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-06-29 10:14:12 +0100
committerGitHub <noreply@github.com>2023-06-29 10:14:12 +0100
commitc8fc9da73701e3f5d16695d6a38d55d471c7b82c (patch)
treed517caf2aedb654261a3dccadb95eeb4badc958e /candle-core/examples/llama/main.rs
parenteda46d2df20f1b6186473e8ee7a26c2d36aa05df (diff)
parentc9c468e1aaf0ce071b145f15aba830e9600fd6e6 (diff)
downloadcandle-c8fc9da73701e3f5d16695d6a38d55d471c7b82c.tar.gz
candle-c8fc9da73701e3f5d16695d6a38d55d471c7b82c.tar.bz2
candle-c8fc9da73701e3f5d16695d6a38d55d471c7b82c.zip
Merge pull request #33 from LaurentMazare/cuda-map
Simplify the dtype matchings in the cuda backend
Diffstat (limited to 'candle-core/examples/llama/main.rs')
-rw-r--r--candle-core/examples/llama/main.rs2
1 files changed, 2 insertions, 0 deletions
diff --git a/candle-core/examples/llama/main.rs b/candle-core/examples/llama/main.rs
index eb681f4b..3fc893e3 100644
--- a/candle-core/examples/llama/main.rs
+++ b/candle-core/examples/llama/main.rs
@@ -487,6 +487,7 @@ fn main() -> Result<()> {
let mut rng = thread_rng();
let start_gen = std::time::Instant::now();
for index in 0..args.sample_len {
+ let start_gen = std::time::Instant::now();
let ctxt = &tokens[tokens.len().saturating_sub(CONTEXT_SIZE)..];
let input = Tensor::new(ctxt, &device)?;
let logits = llama.forward(&input, &freqs_cis)?;
@@ -496,6 +497,7 @@ fn main() -> Result<()> {
let next_token = distr.sample(&mut rng) as u32;
tokens.push(next_token);
new_tokens.push(next_token);
+ println!("> {:?}", start_gen.elapsed());
println!(
"{} token: {} '{}'",
index + 1,