author    Laurent Mazare <laurent.mazare@gmail.com>    2024-12-31 09:21:41 +0100
committer GitHub <noreply@github.com>                  2024-12-31 09:21:41 +0100
commit    d60eba140820326ffc7ec39a8982e91feb462732 (patch)
tree      4c97e5660205b388be3e5121aa2e2857f60344de
parent    e38e2a85dd21cbb07dbca381ac3755f2b7909605 (diff)
Streamline the glm4 example. (#2694)
-rw-r--r--  candle-examples/examples/flux/main.rs       6
-rw-r--r--  candle-examples/examples/glm4/README.org   39
-rw-r--r--  candle-examples/examples/glm4/main.rs     201
3 files changed, 99 insertions(+), 147 deletions(-)
diff --git a/candle-examples/examples/flux/main.rs b/candle-examples/examples/flux/main.rs
index 943db112..12439892 100644
--- a/candle-examples/examples/flux/main.rs
+++ b/candle-examples/examples/flux/main.rs
@@ -250,7 +250,11 @@ fn run(args: Args) -> Result<()> {
};
println!("img\n{img}");
let img = ((img.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(candle::DType::U8)?;
- candle_examples::save_image(&img.i(0)?, "out.jpg")?;
+ let filename = match args.seed {
+ None => "out.jpg".to_string(),
+ Some(s) => format!("out-{s}.jpg"),
+ };
+ candle_examples::save_image(&img.i(0)?, filename)?;
Ok(())
}
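
The flux hunk above makes the saved image name depend on the optional `--seed` argument, so runs with different seeds no longer overwrite each other's output. A minimal standalone sketch of the same pattern follows; the `output_filename` helper is made up for illustration and is not part of the patch.

```rust
// Illustrative sketch only: choose the output file name from an optional seed,
// as the flux hunk above does inline with a match expression.
fn output_filename(seed: Option<u64>) -> String {
    match seed {
        None => "out.jpg".to_string(),
        Some(s) => format!("out-{s}.jpg"),
    }
}

fn main() {
    assert_eq!(output_filename(None), "out.jpg");
    assert_eq!(output_filename(Some(42)), "out-42.jpg");
}
```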
diff --git a/candle-examples/examples/glm4/README.org b/candle-examples/examples/glm4/README.org
index 364f61e8..a584f6c7 100644
--- a/candle-examples/examples/glm4/README.org
+++ b/candle-examples/examples/glm4/README.org
@@ -7,48 +7,25 @@ GLM-4-9B is the open-source version of the latest generation of pre-trained mode
** Running with ~cuda~
#+begin_src shell
- cargo run --example glm4 --release --features cuda
+ cargo run --example glm4 --release --features cuda -- --prompt "Hello world"
#+end_src
** Running with ~cpu~
#+begin_src shell
- cargo run --example glm4 --release -- --cpu
+ cargo run --example glm4 --release -- --cpu --prompt "Hello world"
#+end_src
** Output Example
#+begin_src shell
-cargo run --example glm4 --release --features cuda -- --sample-len 500 --cache .
- Finished release [optimized] target(s) in 0.24s
- Running `/root/candle/target/release/examples/glm4 --sample-len 500 --cache .`
+cargo run --features cuda -r --example glm4 -- --prompt "Hello "
+
avx: true, neon: false, simd128: false, f16c: true
temp: 0.60 repeat-penalty: 1.20 repeat-last-n: 64
-cache path .
-retrieved the files in 6.88963ms
-loaded the model in 6.113752297s
+retrieved the files in 6.454375ms
+loaded the model in 3.652383779s
starting the inference loop
-[Welcome to GLM-4, please enter a prompt]
-Please tell me what FFT is
-266 tokens generated (34.50 token/s)
-Result:
-. The Fast Fourier Transform (FFT) is a method for quickly computing the discrete Fourier transform (DFT); it is widely used in fields such as signal processing, image processing, and data analysis.
-
-Specifically, the FFT is an algorithm that converts time-domain data into frequency-domain data. In digital signal processing we usually need to know the frequency components of a signal, which requires a Fourier transform. The traditional Fourier transform has a high computational cost, while the FFT greatly improves efficiency and makes large-scale DFTs practical.
-
-Here is a simple example of computing an FFT with numpy in Python:
-
-```python
-import numpy as np
-
-# create a time-domain signal
-t = np.linspace(0, 1, num=100)
-f = np.sin(2*np.pi*5*t) + 3*np.cos(2*np.pi*10*t)
-
-# apply the FFT to the signal and compute its magnitude spectrum
-fft_result = np.fft.fftshift(np.abs(np.fft.fft(f)))
-
-```
-
-In this example we first create a time-domain signal f, then apply the FFT to it to obtain the frequency-domain result fft_result.
+Hello 2018, hello new year! I’m so excited to be back and sharing with you all my favorite things from the past month. This is a monthly series where I share what’s been inspiring me lately in hopes that it will inspire you too!
+...
#+end_src
This example will read prompt from stdin
diff --git a/candle-examples/examples/glm4/main.rs b/candle-examples/examples/glm4/main.rs
index 55a27f34..ced3841d 100644
--- a/candle-examples/examples/glm4/main.rs
+++ b/candle-examples/examples/glm4/main.rs
@@ -12,120 +12,97 @@ struct TextGeneration {
device: Device,
tokenizer: Tokenizer,
logits_processor: LogitsProcessor,
- repeat_penalty: f32,
- repeat_last_n: usize,
- verbose_prompt: bool,
+ args: Args,
dtype: DType,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
- fn new(
- model: Model,
- tokenizer: Tokenizer,
- seed: u64,
- temp: Option<f64>,
- top_p: Option<f64>,
- repeat_penalty: f32,
- repeat_last_n: usize,
- verbose_prompt: bool,
- device: &Device,
- dtype: DType,
- ) -> Self {
- let logits_processor = LogitsProcessor::new(seed, temp, top_p);
+ fn new(model: Model, tokenizer: Tokenizer, args: Args, device: &Device, dtype: DType) -> Self {
+ let logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
Self {
model,
tokenizer,
logits_processor,
- repeat_penalty,
- repeat_last_n,
- verbose_prompt,
+ args,
device: device.clone(),
dtype,
}
}
- fn run(&mut self, sample_len: usize) -> anyhow::Result<()> {
- use std::io::BufRead;
- use std::io::BufReader;
+ fn run(&mut self) -> anyhow::Result<()> {
use std::io::Write;
+ let args = &self.args;
println!("starting the inference loop");
- println!("[欢迎使用GLM-4,请输入prompt]");
- let stdin = std::io::stdin();
- let reader = BufReader::new(stdin);
- for line in reader.lines() {
- let line = line.expect("Failed to read line");
-
- let tokens = self.tokenizer.encode(line, true).expect("tokens error");
- if tokens.is_empty() {
- panic!("Empty prompts are not supported in the chatglm model.")
- }
- if self.verbose_prompt {
- for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
- let token = token.replace('▁', " ").replace("<0x0A>", "\n");
- println!("{id:7} -> '{token}'");
- }
+
+ let tokens = self
+ .tokenizer
+ .encode(args.prompt.to_string(), true)
+ .expect("tokens error");
+ if tokens.is_empty() {
+ panic!("Empty prompts are not supported in the chatglm model.")
+ }
+ if args.verbose {
+ for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
+ let token = token.replace('▁', " ").replace("<0x0A>", "\n");
+ println!("{id:7} -> '{token}'");
}
- let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
- Some(token) => *token,
- None => panic!("cannot find the endoftext token"),
+ } else {
+ print!("{}", &args.prompt);
+ std::io::stdout().flush()?;
+ }
+ let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
+ Some(token) => *token,
+ None => panic!("cannot find the endoftext token"),
+ };
+ let mut tokens = tokens.get_ids().to_vec();
+ let mut generated_tokens = 0usize;
+
+ std::io::stdout().flush().expect("output flush error");
+ let start_gen = std::time::Instant::now();
+
+ for index in 0..args.sample_len {
+ let context_size = if index > 0 { 1 } else { tokens.len() };
+ let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
+ let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
+ let logits = self.model.forward(&input)?;
+ let logits = logits.squeeze(0)?.to_dtype(self.dtype)?;
+ let logits = if args.repeat_penalty == 1. {
+ logits
+ } else {
+ let start_at = tokens.len().saturating_sub(args.repeat_last_n);
+ candle_transformers::utils::apply_repeat_penalty(
+ &logits,
+ args.repeat_penalty,
+ &tokens[start_at..],
+ )?
};
- let mut tokens = tokens.get_ids().to_vec();
- let mut generated_tokens = 0usize;
-
- std::io::stdout().flush().expect("output flush error");
- let start_gen = std::time::Instant::now();
-
- let mut count = 0;
- let mut result = vec![];
- for index in 0..sample_len {
- count += 1;
- let context_size = if index > 0 { 1 } else { tokens.len() };
- let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
- let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
- let logits = self.model.forward(&input)?;
- let logits = logits.squeeze(0)?.to_dtype(self.dtype)?;
- let logits = if self.repeat_penalty == 1. {
- logits
- } else {
- let start_at = tokens.len().saturating_sub(self.repeat_last_n);
- candle_transformers::utils::apply_repeat_penalty(
- &logits,
- self.repeat_penalty,
- &tokens[start_at..],
- )?
- };
-
- let next_token = self.logits_processor.sample(&logits)?;
- tokens.push(next_token);
- generated_tokens += 1;
- if next_token == eos_token {
- break;
- }
- let token = self
- .tokenizer
- .decode(&[next_token], true)
- .expect("Token error");
- if self.verbose_prompt {
- println!(
- "[Count: {}] [Raw Token: {}] [Decode Token: {}]",
- count, next_token, token
- );
- }
- result.push(token);
- std::io::stdout().flush()?;
+
+ let next_token = self.logits_processor.sample(&logits)?;
+ tokens.push(next_token);
+ generated_tokens += 1;
+ if next_token == eos_token {
+ break;
}
- let dt = start_gen.elapsed();
- println!(
- "\n{generated_tokens} tokens generated ({:.2} token/s)",
- generated_tokens as f64 / dt.as_secs_f64(),
- );
- println!("Result:");
- for tokens in result {
- print!("{tokens}");
+ let token = self
+ .tokenizer
+ .decode(&[next_token], true)
+ .expect("token decode error");
+ if args.verbose {
+ println!(
+ "[Count: {}] [Raw Token: {}] [Decode Token: {}]",
+ generated_tokens, next_token, token
+ );
+ } else {
+ print!("{token}");
+ std::io::stdout().flush()?;
}
- self.model.reset_kv_cache(); // clean the cache
}
+ let dt = start_gen.elapsed();
+ println!(
+ "\n{generated_tokens} tokens generated ({:.2} token/s)",
+ generated_tokens as f64 / dt.as_secs_f64(),
+ );
Ok(())
}
}
@@ -141,7 +118,11 @@ struct Args {
/// Display the token for the specified prompt.
#[arg(long)]
- verbose_prompt: bool,
+ prompt: String,
+
+ /// Display the tokens for the specified prompt and outputs.
+ #[arg(long)]
+ verbose: bool,
/// The temperature used to generate samples.
#[arg(long)]
@@ -197,28 +178,29 @@ fn main() -> anyhow::Result<()> {
);
let start = std::time::Instant::now();
- println!("cache path {}", args.cache_path);
- let api = hf_hub::api::sync::ApiBuilder::from_cache(hf_hub::Cache::new(args.cache_path.into()))
- .build()
- .map_err(anyhow::Error::msg)?;
+ let api = hf_hub::api::sync::ApiBuilder::from_cache(hf_hub::Cache::new(
+ args.cache_path.to_string().into(),
+ ))
+ .build()
+ .map_err(anyhow::Error::msg)?;
- let model_id = match args.model_id {
+ let model_id = match args.model_id.as_ref() {
Some(model_id) => model_id.to_string(),
None => "THUDM/glm-4-9b".to_string(),
};
- let revision = match args.revision {
+ let revision = match args.revision.as_ref() {
Some(rev) => rev.to_string(),
None => "main".to_string(),
};
let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
- let tokenizer_filename = match args.tokenizer {
+ let tokenizer_filename = match args.tokenizer.as_ref() {
Some(file) => std::path::PathBuf::from(file),
None => api
.model("THUDM/codegeex4-all-9b".to_string())
.get("tokenizer.json")
.map_err(anyhow::Error::msg)?,
};
- let filenames = match args.weight_file {
+ let filenames = match args.weight_file.as_ref() {
Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
};
@@ -238,18 +220,7 @@ fn main() -> anyhow::Result<()> {
println!("loaded the model in {:?}", start.elapsed());
- let mut pipeline = TextGeneration::new(
- model,
- tokenizer,
- args.seed,
- args.temperature,
- args.top_p,
- args.repeat_penalty,
- args.repeat_last_n,
- args.verbose_prompt,
- &device,
- dtype,
- );
- pipeline.run(args.sample_len)?;
+ let mut pipeline = TextGeneration::new(model, tokenizer, args, &device, dtype);
+ pipeline.run()?;
Ok(())
}
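
The main.rs changes above replace the long `TextGeneration::new` parameter list and the stdin loop with a pipeline that stores the whole clap-derived `Args` and reads the prompt, sample length, and verbosity through it. Below is a minimal sketch of that design choice in isolation; the `Args` fields mirror the ones visible in the diff, but the default sample length and the placeholder bodies are assumptions, and the real example additionally wires in the model, tokenizer, device, and sampling logic.

```rust
// Minimal sketch of the "pass the whole Args struct" pattern used in the
// refactor above. This is not the candle example itself: model, tokenizer,
// and sampling are omitted and replaced with placeholder prints.
use clap::Parser;

#[derive(Parser, Debug)]
struct Args {
    /// The prompt to generate from.
    #[arg(long)]
    prompt: String,

    /// Display the tokens for the specified prompt and outputs.
    #[arg(long)]
    verbose: bool,

    /// The number of tokens to generate (default chosen for this sketch only).
    #[arg(long, default_value_t = 128)]
    sample_len: usize,
}

struct TextGeneration {
    args: Args,
}

impl TextGeneration {
    fn new(args: Args) -> Self {
        Self { args }
    }

    fn run(&mut self) -> anyhow::Result<()> {
        // Configuration is read through self.args instead of being copied into
        // separate struct fields, which keeps the constructor signature short.
        let args = &self.args;
        if args.verbose {
            println!("verbose mode: would print every raw/decoded token");
        }
        println!("prompt: {}", args.prompt);
        println!("sample_len: {}", args.sample_len);
        Ok(())
    }
}

fn main() -> anyhow::Result<()> {
    let args = Args::parse();
    let mut pipeline = TextGeneration::new(args, /* device, dtype omitted */);
    pipeline.run()?;
    Ok(())
}
```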