diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-09-30 19:31:14 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-09-30 19:31:14 +0200 |
commit | 683ab698def755c24cec9987069d25efcf831fc4 (patch) | |
tree | 84d0bd8ad2f5d7a00f67050c83520326d947b2fe /candle-examples/examples | |
parent | 2f49e1b5349f4e56677ec0d3dc3fe98f9cbb87c7 (diff) | |
download | candle-683ab698def755c24cec9987069d25efcf831fc4.tar.gz candle-683ab698def755c24cec9987069d25efcf831fc4.tar.bz2 candle-683ab698def755c24cec9987069d25efcf831fc4.zip |
Add Pixtral. (#2521)
* Add Pixtral.
* More pixtral vision encoder.
* Sketch a pixtral example.
* Sketch a pixtral example.
* Better image loading.
* Support loading images embedded in safetensor files.
* Clippy fixes.
* Add the llava multimodal adapter.
* Add more of the llava bits.
* Add the pixtral config.
* More pixtral inference.
* Add the text generation bits.
* Get the example to work.
* Bugfix.
* Run some bits of the model in f32.
* Blessed version :)
* Better rope frequency computations.
* README update.
Diffstat (limited to 'candle-examples/examples')
-rw-r--r-- | candle-examples/examples/pixtral/README.md | 28 | ||||
-rw-r--r-- | candle-examples/examples/pixtral/main.rs | 336 |
2 files changed, 364 insertions, 0 deletions
diff --git a/candle-examples/examples/pixtral/README.md b/candle-examples/examples/pixtral/README.md new file mode 100644 index 00000000..77677571 --- /dev/null +++ b/candle-examples/examples/pixtral/README.md @@ -0,0 +1,28 @@ +# pixtral + +Pixtral-12B is a 12B text+vision model. + +[Blog Post](https://mistral.ai/news/pixtral-12b/) - +[HF Model Card](https://huggingface.co/mistralai/Pixtral-12B-2409) - +[HF Community Model Card](https://huggingface.co/mistral-community/pixtral-12b). + +```bash +cargo run --profile=release-with-debug --features cuda --example pixtral -- \ + --image candle-examples/examples/flux/assets/flux-robot.jpg +``` + +``` +Describe the image. + +The image depicts a charming, rustic robot standing on a sandy beach at sunset. +The robot has a vintage, steampunk aesthetic with visible gears and mechanical +parts. It is holding a small lantern in one hand, which emits a warm glow, and +its other arm is extended forward as if reaching out or guiding the way. The +robot's body is adorned with the word "RUST" in bright orange letters, adding to +its rustic theme. + +The background features a dramatic sky filled with clouds, illuminated by the +setting sun, casting a golden hue over the scene. Gentle waves lap against the +shore, creating a serene and picturesque atmosphere. The overall mood of the +image is whimsical and nostalgic, evoking a sense of adventure and tranquility. +``` diff --git a/candle-examples/examples/pixtral/main.rs b/candle-examples/examples/pixtral/main.rs new file mode 100644 index 00000000..8e48b60b --- /dev/null +++ b/candle-examples/examples/pixtral/main.rs @@ -0,0 +1,336 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use anyhow::{Error as E, Result}; +use clap::Parser; + +use candle_transformers::models::pixtral::{vision_model, Config, Model}; + +use candle::{DType, Device, Module, Tensor}; +use candle_examples::token_output_stream::TokenOutputStream; +use candle_nn::VarBuilder; +use candle_transformers::generation::LogitsProcessor; +use hf_hub::{api::sync::Api, Repo, RepoType}; +use tokenizers::Tokenizer; + +struct TextGeneration { + model: Model, + image: Tensor, + device: Device, + tokenizer: TokenOutputStream, + logits_processor: LogitsProcessor, + repeat_penalty: f32, + repeat_last_n: usize, +} + +impl TextGeneration { + #[allow(clippy::too_many_arguments)] + fn new( + model: Model, + image: Tensor, + tokenizer: Tokenizer, + seed: u64, + temp: Option<f64>, + top_p: Option<f64>, + repeat_penalty: f32, + repeat_last_n: usize, + device: &Device, + ) -> Self { + let logits_processor = LogitsProcessor::new(seed, temp, top_p); + Self { + model, + image, + tokenizer: TokenOutputStream::new(tokenizer), + logits_processor, + repeat_penalty, + repeat_last_n, + device: device.clone(), + } + } + + fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> { + use std::io::Write; + self.tokenizer.clear(); + let mut tokens = self + .tokenizer + .tokenizer() + .encode(prompt, true) + .map_err(E::msg)? + .get_ids() + .to_vec(); + let mut generated_tokens = 0usize; + let get_token = |v| match self.tokenizer.get_token(v) { + Some(token) => Ok(token), + None => anyhow::bail!("cannot find the {v} token"), + }; + let bos_token = get_token("<s>")?; + let eos_token = get_token("</s>")?; + let inst_token = get_token("[INST]")?; + let end_inst_token = get_token("[/INST]")?; + let img_break = get_token("[IMG_BREAK]")?; + let img_end = get_token("[IMG_END]")?; + let start_gen = std::time::Instant::now(); + let mut pos = 0; + for index in 0..sample_len { + let logits = if index > 0 { + let context_size = if index > 0 { 1 } else { tokens.len() }; + let start_pos = tokens.len().saturating_sub(context_size); + let ctxt = &tokens[start_pos..]; + let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?; + let logits = self.model.language_model.forward(&input, pos)?; + pos += context_size; + logits + } else { + let (_b, _c, h, w) = self.image.dims4()?; + let h = h / self.model.patch_size; + let w = w / self.model.patch_size; + let image_embeds = self.model.vision_tower.forward(&self.image)?; + let image_embeds = self.model.multi_modal_projector.forward(&image_embeds)?; + println!("generated image embeddings {image_embeds:?}"); + let image_embeds = image_embeds.to_dtype(self.model.dtype)?; + for &t in tokens.iter() { + if let Some(t) = self.tokenizer.next_token(t)? { + print!("{t}") + } + } + std::io::stdout().flush()?; + + let break_embeds = { + let input = Tensor::new(&[img_break], &self.device)?.unsqueeze(0)?; + self.model.language_model.embed_tokens().forward(&input)? + }; + let start_embeds = { + let mut in_tokens = vec![bos_token, inst_token]; + in_tokens.extend_from_slice(tokens.as_slice()); + let input = Tensor::new(in_tokens.as_slice(), &self.device)?.unsqueeze(0)?; + self.model.language_model.embed_tokens().forward(&input)? + }; + let end_embeds = { + let input = + Tensor::new(&[img_end, end_inst_token], &self.device)?.unsqueeze(0)?; + self.model.language_model.embed_tokens().forward(&input)? + }; + let mut input_embeds = vec![start_embeds]; + for h_idx in 0..h { + if h_idx > 0 { + input_embeds.push(break_embeds.clone()) + } + let row = image_embeds.narrow(1, h_idx * w, w)?; + input_embeds.push(row); + } + input_embeds.push(end_embeds); + + let input_embeds = Tensor::cat(&input_embeds, 1)?; + let logits = self + .model + .language_model + .forward_embeds(&input_embeds, None, pos)?; + pos += input_embeds.dim(1)?; + logits + }; + let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?; + let logits = if self.repeat_penalty == 1. { + logits + } else { + let start_at = tokens.len().saturating_sub(self.repeat_last_n); + candle_transformers::utils::apply_repeat_penalty( + &logits, + self.repeat_penalty, + &tokens[start_at..], + )? + }; + + let next_token = self.logits_processor.sample(&logits)?; + tokens.push(next_token); + generated_tokens += 1; + if next_token == eos_token { + break; + } + if let Some(t) = self.tokenizer.next_token(next_token)? { + print!("{t}"); + std::io::stdout().flush()?; + } + } + let dt = start_gen.elapsed(); + if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? { + print!("{rest}"); + } + std::io::stdout().flush()?; + println!( + "\n{generated_tokens} tokens generated ({:.2} token/s)", + generated_tokens as f64 / dt.as_secs_f64(), + ); + Ok(()) + } +} + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, + + /// Enable tracing (generates a trace-timestamp.json file). + #[arg(long)] + tracing: bool, + + #[arg(long, default_value = "Describe the image.\n")] + prompt: String, + + /// The temperature used to generate samples. + #[arg(long)] + temperature: Option<f64>, + + /// Nucleus sampling probability cutoff. + #[arg(long)] + top_p: Option<f64>, + + /// The seed to use when generating random samples. + #[arg(long, default_value_t = 299792458)] + seed: u64, + + /// The length of the sample to generate (in tokens). + #[arg(long, short = 'n', default_value_t = 10000)] + sample_len: usize, + + #[arg(long)] + model_id: Option<String>, + + #[arg(long, default_value = "main")] + revision: String, + + #[arg(long)] + tokenizer_file: Option<String>, + + #[arg(long)] + config_file: Option<String>, + + #[arg(long)] + weight_files: Option<String>, + + /// Penalty to be applied for repeating tokens, 1. means no penalty. + #[arg(long, default_value_t = 1.1)] + repeat_penalty: f32, + + /// The context size to consider for the repeat penalty. + #[arg(long, default_value_t = 64)] + repeat_last_n: usize, + + #[arg(long)] + image: String, + + #[arg(long)] + vision_only: bool, +} + +fn main() -> Result<()> { + use tracing_chrome::ChromeLayerBuilder; + use tracing_subscriber::prelude::*; + + let args = Args::parse(); + let _guard = if args.tracing { + let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); + tracing_subscriber::registry().with(chrome_layer).init(); + Some(guard) + } else { + None + }; + println!( + "avx: {}, neon: {}, simd128: {}, f16c: {}", + candle::utils::with_avx(), + candle::utils::with_neon(), + candle::utils::with_simd128(), + candle::utils::with_f16c() + ); + println!( + "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}", + args.temperature.unwrap_or(0.), + args.repeat_penalty, + args.repeat_last_n + ); + + let start = std::time::Instant::now(); + let api = Api::new()?; + let model_id = match &args.model_id { + Some(model_id) => model_id.to_string(), + None => "mistral-community/pixtral-12b".to_string(), + }; + let repo = api.repo(Repo::with_revision( + model_id, + RepoType::Model, + args.revision, + )); + let tokenizer_filename = match args.tokenizer_file { + Some(file) => std::path::PathBuf::from(file), + None => repo.get("tokenizer.json")?, + }; + let filenames = match args.weight_files { + Some(files) => files + .split(',') + .map(std::path::PathBuf::from) + .collect::<Vec<_>>(), + None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?, + }; + println!("retrieved the files in {:?}", start.elapsed()); + + let device = candle_examples::device(args.cpu)?; + let dtype = if device.supports_bf16() && !args.vision_only { + DType::BF16 + } else { + DType::F32 + }; + let config: Config = match args.config_file { + Some(config_file) => serde_json::from_slice(&std::fs::read(config_file)?)?, + None => { + let config_file = repo.get("config.json")?; + serde_json::from_slice(&std::fs::read(config_file)?)? + } + }; + let image = if args.image.ends_with(".safetensors") { + match candle::safetensors::load(&args.image, &device)?.remove("img") { + None => anyhow::bail!("no img tensor in {}", args.image), + Some(v) => v, + } + } else { + candle_examples::imagenet::load_image_with_std_mean( + &args.image, + 1024, + &[0.48145466, 0.4578275, 0.40821073], + &[0.26862954, 0.261_302_6, 0.275_777_1], + )? + }; + let image = image.to_device(&device)?.unsqueeze(0)?; + println!("loaded image with shape {:?}", image); + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? }; + + if args.vision_only { + let start = std::time::Instant::now(); + let model = vision_model::Model::new(&config.vision_config, vb.pp("vision_tower"))?; + println!("loaded the model in {:?}", start.elapsed()); + let embs = model.forward(&image)?; + println!("EMBS\n{embs}"); + } else { + let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; + let start = std::time::Instant::now(); + let model = Model::new(&config, vb)?; + println!("loaded the model in {:?}", start.elapsed()); + let mut pipeline = TextGeneration::new( + model, + image, + tokenizer, + args.seed, + args.temperature, + args.top_p, + args.repeat_penalty, + args.repeat_last_n, + &device, + ); + pipeline.run(&args.prompt, args.sample_len)?; + } + + Ok(()) +} |