diff options
author | Akshay Ballal <61191840+akshayballal95@users.noreply.github.com> | 2024-10-01 11:48:39 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-10-01 11:48:39 +0200 |
commit | 888d886dd8d5cac2558064060c59a4b51b6aa530 (patch) | |
tree | 7bf0848bc3211453b7e07b26edf5c108e45dc7cf /candle-examples/examples | |
parent | 6110ad8d4ff8272bdd10687eae4edee59a07b517 (diff) | |
download | candle-888d886dd8d5cac2558064060c59a4b51b6aa530.tar.gz candle-888d886dd8d5cac2558064060c59a4b51b6aa530.tar.bz2 candle-888d886dd8d5cac2558064060c59a4b51b6aa530.zip |
Add ColPali (#2524)
* add colpali
* cleanup
* fix clippy
Diffstat (limited to 'candle-examples/examples')
-rw-r--r-- | candle-examples/examples/colpali/README.md | 18 | ||||
-rw-r--r-- | candle-examples/examples/colpali/main.rs | 268 |
2 files changed, 286 insertions, 0 deletions
diff --git a/candle-examples/examples/colpali/README.md b/candle-examples/examples/colpali/README.md new file mode 100644 index 00000000..e6a55798 --- /dev/null +++ b/candle-examples/examples/colpali/README.md @@ -0,0 +1,18 @@ +# Colpali + +[HuggingFace Model Card](https://huggingface.co/vidore/colpali-v1.2-merged) + +``` +wget https://arxiv.org/pdf/1706.03762.pdf +cargo run --features cuda,pdf2image --release --example colpali -- --prompt "What is Positional Encoding" --pdf "1706.03762.pdf" +``` + +``` +Prompt: what is position encoding? +top 3 page numbers that contain similarity to the prompt +----------------------------------- +Page: 6 +Page: 11 +Page: 15 +----------------------------------- +```
\ No newline at end of file diff --git a/candle-examples/examples/colpali/main.rs b/candle-examples/examples/colpali/main.rs new file mode 100644 index 00000000..2a1cc96b --- /dev/null +++ b/candle-examples/examples/colpali/main.rs @@ -0,0 +1,268 @@ +use anyhow::{Error as E, Result}; +use candle::{DType, Device, Tensor}; +use candle_nn::VarBuilder; +use candle_transformers::models::colpali::Model; +use candle_transformers::models::{colpali, paligemma}; +use clap::Parser; +use hf_hub::{api::sync::Api, Repo, RepoType}; +use image::DynamicImage; +use pdf2image::{RenderOptionsBuilder, PDF}; +use tokenizers::Tokenizer; + +struct PageRetriever { + model: Model, + config: paligemma::Config, + pdf: PDF, + device: Device, + tokenizer: Tokenizer, + range: pdf2image::Pages, + batch_size: usize, + top_k: usize, +} + +impl PageRetriever { + fn new( + model: Model, + config: paligemma::Config, + pdf: PDF, + tokenizer: Tokenizer, + device: &Device, + range: Option<pdf2image::Pages>, + batch_size: usize, + top_k: usize, + ) -> Self { + let page_count = pdf.page_count(); + Self { + model, + config, + pdf, + device: device.clone(), + tokenizer, + range: range.unwrap_or_else(|| pdf2image::Pages::Range(1..=page_count)), + batch_size, + top_k, + } + } + + fn get_images_from_pdf(&self) -> Result<Vec<DynamicImage>> { + let pages = self + .pdf + .render(self.range.clone(), RenderOptionsBuilder::default().build()?)?; + Ok(pages) + } + + fn tokenize_batch(&self, prompts: Vec<&str>) -> Result<Tensor> { + let tokens = self.tokenizer.encode_batch(prompts, true).map_err(E::msg)?; + let token_ids = tokens + .iter() + .map(|tokens| { + let tokens = tokens.get_ids().to_vec(); + Tensor::new(tokens.as_slice(), &self.device) + }) + .collect::<candle::Result<Vec<_>>>()?; + let input = Tensor::stack(&token_ids, 0)?; + Ok(input) + } + + fn images_to_tensor( + &self, + pages: &[DynamicImage], + image_size: usize, + ) -> anyhow::Result<Tensor> { + let mut images = vec![]; + for page in pages.iter() { + let img = page.resize_to_fill( + image_size as u32, + image_size as u32, + image::imageops::FilterType::Triangle, + ); + let img = img.to_rgb8(); + let img = img.into_raw(); + let img = Tensor::from_vec(img, (image_size, image_size, 3), &Device::Cpu)? + .permute((2, 0, 1))? + .to_dtype(DType::F32)? + .affine(2. / 255., -1.)?; + images.push(img); + } + let images = Tensor::stack(&images, 0)?; + Ok(images) + } + + fn retrieve(&mut self, prompt: &str) -> Result<Vec<usize>> { + let dtype = if self.device.is_cuda() { + DType::BF16 + } else { + DType::F32 + }; + + let dummy_prompt: &str = "Describe the image"; + + let input = self.tokenize_batch(vec![prompt])?; + let dummy_input = self.tokenize_batch(vec![dummy_prompt])?; + + let pages = self.get_images_from_pdf()?; + let mut all_scores = Vec::new(); + for batch in pages.chunks(self.batch_size) { + let page_images = self + .images_to_tensor(batch, self.config.vision_config.image_size)? + .to_device(&self.device)? + .to_dtype(dtype)?; + let dummy_input = dummy_input.repeat((page_images.dims()[0], 0))?; + + let image_embeddings = self.model.forward_images(&page_images, &dummy_input)?; + let text_embeddings = self.model.forward_text(&input)?; + + let scores = text_embeddings + .unsqueeze(1)? + .broadcast_matmul(&image_embeddings.unsqueeze(0)?.transpose(3, 2)?)? + .max(3)? + .sum(2)?; + let batch_scores: Vec<f32> = scores + .to_dtype(DType::F32)? + .to_vec2()? + .into_iter() + .flatten() + .collect(); + all_scores.extend(batch_scores); + } + + let mut indices: Vec<usize> = (0..all_scores.len()).collect(); + indices.sort_by(|a, b| all_scores[*b].partial_cmp(&all_scores[*a]).unwrap()); + + let top_k_indices = indices[0..self.top_k].to_vec(); + + Ok(top_k_indices) + } +} + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, + + /// Enable tracing (generates a trace-timestamp.json file). + #[arg(long)] + tracing: bool, + + #[arg(long)] + prompt: String, + + /// number of top pages to show. + #[arg(long, default_value_t = 3)] + top_k: usize, + + #[arg(long)] + model_id: Option<String>, + + #[arg(long, default_value = "main")] + revision: String, + + #[arg(long)] + tokenizer_file: Option<String>, + + #[arg(long)] + weight_files: Option<String>, + + #[arg(long)] + pdf: String, + + #[arg(long)] + start: Option<u32>, + + #[arg(long)] + end: Option<u32>, +} + +fn main() -> Result<()> { + use tracing_chrome::ChromeLayerBuilder; + use tracing_subscriber::prelude::*; + + let args = Args::parse(); + let _guard = if args.tracing { + let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); + tracing_subscriber::registry().with(chrome_layer).init(); + Some(guard) + } else { + None + }; + println!( + "avx: {}, neon: {}, simd128: {}, f16c: {}", + candle::utils::with_avx(), + candle::utils::with_neon(), + candle::utils::with_simd128(), + candle::utils::with_f16c() + ); + + let api = Api::new()?; + let model_id = match &args.model_id { + Some(model_id) => model_id.to_string(), + None => "vidore/colpali-v1.2-merged".to_string(), + }; + let repo = api.repo(Repo::with_revision( + model_id, + RepoType::Model, + args.revision, + )); + + let tokenizer_filename = match args.tokenizer_file { + Some(file) => std::path::PathBuf::from(file), + None => api + .repo(Repo::with_revision( + "vidore/colpali".to_string(), + RepoType::Model, + "main".to_string(), + )) + .get("tokenizer.json")?, + }; + + let filenames = match args.weight_files { + Some(files) => files + .split(',') + .map(std::path::PathBuf::from) + .collect::<Vec<_>>(), + None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?, + }; + + let start = std::time::Instant::now(); + + let config: paligemma::Config = paligemma::Config::paligemma_3b_448(); + + println!("retrieved the files in {:?}", start.elapsed()); + + let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; + let device = candle_examples::device(false)?; + let dtype = if device.is_cuda() { + DType::BF16 + } else { + DType::F32 + }; + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? }; + let model = colpali::Model::new(&config, vb)?; + + let pdf = PDF::from_file(args.pdf)?; + + // check if start and end given in arg + let range = if let (Some(start), Some(end)) = (args.start, args.end) { + pdf2image::Pages::Range(start..=end) + } else { + pdf2image::Pages::Range(1..=pdf.page_count()) // can use pdf2image::Pages::All but there is a bug in the library which causes the first page to rendered twice. + }; + + let mut retriever = + PageRetriever::new(model, config, pdf, tokenizer, &device, Some(range), 4, 3); + let top_k_indices = retriever.retrieve(&args.prompt)?; + + println!("Prompt: {}", args.prompt); + println!( + "top {} page numbers that contain similarity to the prompt", + retriever.top_k + ); + println!("-----------------------------------"); + for index in top_k_indices { + println!("Page: {:?}", index + 1); + } + println!("-----------------------------------"); + Ok(()) +} |