summaryrefslogtreecommitdiff
path: root/candle-examples/examples
diff options
context:
space:
mode:
authorAkshay Ballal <61191840+akshayballal95@users.noreply.github.com>2024-10-01 11:48:39 +0200
committerGitHub <noreply@github.com>2024-10-01 11:48:39 +0200
commit888d886dd8d5cac2558064060c59a4b51b6aa530 (patch)
tree7bf0848bc3211453b7e07b26edf5c108e45dc7cf /candle-examples/examples
parent6110ad8d4ff8272bdd10687eae4edee59a07b517 (diff)
downloadcandle-888d886dd8d5cac2558064060c59a4b51b6aa530.tar.gz
candle-888d886dd8d5cac2558064060c59a4b51b6aa530.tar.bz2
candle-888d886dd8d5cac2558064060c59a4b51b6aa530.zip
Add ColPali (#2524)
* add colpali * cleanup * fix clippy
Diffstat (limited to 'candle-examples/examples')
-rw-r--r--candle-examples/examples/colpali/README.md18
-rw-r--r--candle-examples/examples/colpali/main.rs268
2 files changed, 286 insertions, 0 deletions
diff --git a/candle-examples/examples/colpali/README.md b/candle-examples/examples/colpali/README.md
new file mode 100644
index 00000000..e6a55798
--- /dev/null
+++ b/candle-examples/examples/colpali/README.md
@@ -0,0 +1,18 @@
+# Colpali
+
+[HuggingFace Model Card](https://huggingface.co/vidore/colpali-v1.2-merged)
+
+```
+wget https://arxiv.org/pdf/1706.03762.pdf
+cargo run --features cuda,pdf2image --release --example colpali -- --prompt "What is Positional Encoding" --pdf "1706.03762.pdf"
+```
+
+```
+Prompt: what is position encoding?
+top 3 page numbers that contain similarity to the prompt
+-----------------------------------
+Page: 6
+Page: 11
+Page: 15
+-----------------------------------
+``` \ No newline at end of file
diff --git a/candle-examples/examples/colpali/main.rs b/candle-examples/examples/colpali/main.rs
new file mode 100644
index 00000000..2a1cc96b
--- /dev/null
+++ b/candle-examples/examples/colpali/main.rs
@@ -0,0 +1,268 @@
+use anyhow::{Error as E, Result};
+use candle::{DType, Device, Tensor};
+use candle_nn::VarBuilder;
+use candle_transformers::models::colpali::Model;
+use candle_transformers::models::{colpali, paligemma};
+use clap::Parser;
+use hf_hub::{api::sync::Api, Repo, RepoType};
+use image::DynamicImage;
+use pdf2image::{RenderOptionsBuilder, PDF};
+use tokenizers::Tokenizer;
+
+struct PageRetriever {
+ model: Model,
+ config: paligemma::Config,
+ pdf: PDF,
+ device: Device,
+ tokenizer: Tokenizer,
+ range: pdf2image::Pages,
+ batch_size: usize,
+ top_k: usize,
+}
+
+impl PageRetriever {
+ fn new(
+ model: Model,
+ config: paligemma::Config,
+ pdf: PDF,
+ tokenizer: Tokenizer,
+ device: &Device,
+ range: Option<pdf2image::Pages>,
+ batch_size: usize,
+ top_k: usize,
+ ) -> Self {
+ let page_count = pdf.page_count();
+ Self {
+ model,
+ config,
+ pdf,
+ device: device.clone(),
+ tokenizer,
+ range: range.unwrap_or_else(|| pdf2image::Pages::Range(1..=page_count)),
+ batch_size,
+ top_k,
+ }
+ }
+
+ fn get_images_from_pdf(&self) -> Result<Vec<DynamicImage>> {
+ let pages = self
+ .pdf
+ .render(self.range.clone(), RenderOptionsBuilder::default().build()?)?;
+ Ok(pages)
+ }
+
+ fn tokenize_batch(&self, prompts: Vec<&str>) -> Result<Tensor> {
+ let tokens = self.tokenizer.encode_batch(prompts, true).map_err(E::msg)?;
+ let token_ids = tokens
+ .iter()
+ .map(|tokens| {
+ let tokens = tokens.get_ids().to_vec();
+ Tensor::new(tokens.as_slice(), &self.device)
+ })
+ .collect::<candle::Result<Vec<_>>>()?;
+ let input = Tensor::stack(&token_ids, 0)?;
+ Ok(input)
+ }
+
+ fn images_to_tensor(
+ &self,
+ pages: &[DynamicImage],
+ image_size: usize,
+ ) -> anyhow::Result<Tensor> {
+ let mut images = vec![];
+ for page in pages.iter() {
+ let img = page.resize_to_fill(
+ image_size as u32,
+ image_size as u32,
+ image::imageops::FilterType::Triangle,
+ );
+ let img = img.to_rgb8();
+ let img = img.into_raw();
+ let img = Tensor::from_vec(img, (image_size, image_size, 3), &Device::Cpu)?
+ .permute((2, 0, 1))?
+ .to_dtype(DType::F32)?
+ .affine(2. / 255., -1.)?;
+ images.push(img);
+ }
+ let images = Tensor::stack(&images, 0)?;
+ Ok(images)
+ }
+
+ fn retrieve(&mut self, prompt: &str) -> Result<Vec<usize>> {
+ let dtype = if self.device.is_cuda() {
+ DType::BF16
+ } else {
+ DType::F32
+ };
+
+ let dummy_prompt: &str = "Describe the image";
+
+ let input = self.tokenize_batch(vec![prompt])?;
+ let dummy_input = self.tokenize_batch(vec![dummy_prompt])?;
+
+ let pages = self.get_images_from_pdf()?;
+ let mut all_scores = Vec::new();
+ for batch in pages.chunks(self.batch_size) {
+ let page_images = self
+ .images_to_tensor(batch, self.config.vision_config.image_size)?
+ .to_device(&self.device)?
+ .to_dtype(dtype)?;
+ let dummy_input = dummy_input.repeat((page_images.dims()[0], 0))?;
+
+ let image_embeddings = self.model.forward_images(&page_images, &dummy_input)?;
+ let text_embeddings = self.model.forward_text(&input)?;
+
+ let scores = text_embeddings
+ .unsqueeze(1)?
+ .broadcast_matmul(&image_embeddings.unsqueeze(0)?.transpose(3, 2)?)?
+ .max(3)?
+ .sum(2)?;
+ let batch_scores: Vec<f32> = scores
+ .to_dtype(DType::F32)?
+ .to_vec2()?
+ .into_iter()
+ .flatten()
+ .collect();
+ all_scores.extend(batch_scores);
+ }
+
+ let mut indices: Vec<usize> = (0..all_scores.len()).collect();
+ indices.sort_by(|a, b| all_scores[*b].partial_cmp(&all_scores[*a]).unwrap());
+
+ let top_k_indices = indices[0..self.top_k].to_vec();
+
+ Ok(top_k_indices)
+ }
+}
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+ /// Run on CPU rather than on GPU.
+ #[arg(long)]
+ cpu: bool,
+
+ /// Enable tracing (generates a trace-timestamp.json file).
+ #[arg(long)]
+ tracing: bool,
+
+ #[arg(long)]
+ prompt: String,
+
+ /// number of top pages to show.
+ #[arg(long, default_value_t = 3)]
+ top_k: usize,
+
+ #[arg(long)]
+ model_id: Option<String>,
+
+ #[arg(long, default_value = "main")]
+ revision: String,
+
+ #[arg(long)]
+ tokenizer_file: Option<String>,
+
+ #[arg(long)]
+ weight_files: Option<String>,
+
+ #[arg(long)]
+ pdf: String,
+
+ #[arg(long)]
+ start: Option<u32>,
+
+ #[arg(long)]
+ end: Option<u32>,
+}
+
+fn main() -> Result<()> {
+ use tracing_chrome::ChromeLayerBuilder;
+ use tracing_subscriber::prelude::*;
+
+ let args = Args::parse();
+ let _guard = if args.tracing {
+ let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+ tracing_subscriber::registry().with(chrome_layer).init();
+ Some(guard)
+ } else {
+ None
+ };
+ println!(
+ "avx: {}, neon: {}, simd128: {}, f16c: {}",
+ candle::utils::with_avx(),
+ candle::utils::with_neon(),
+ candle::utils::with_simd128(),
+ candle::utils::with_f16c()
+ );
+
+ let api = Api::new()?;
+ let model_id = match &args.model_id {
+ Some(model_id) => model_id.to_string(),
+ None => "vidore/colpali-v1.2-merged".to_string(),
+ };
+ let repo = api.repo(Repo::with_revision(
+ model_id,
+ RepoType::Model,
+ args.revision,
+ ));
+
+ let tokenizer_filename = match args.tokenizer_file {
+ Some(file) => std::path::PathBuf::from(file),
+ None => api
+ .repo(Repo::with_revision(
+ "vidore/colpali".to_string(),
+ RepoType::Model,
+ "main".to_string(),
+ ))
+ .get("tokenizer.json")?,
+ };
+
+ let filenames = match args.weight_files {
+ Some(files) => files
+ .split(',')
+ .map(std::path::PathBuf::from)
+ .collect::<Vec<_>>(),
+ None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
+ };
+
+ let start = std::time::Instant::now();
+
+ let config: paligemma::Config = paligemma::Config::paligemma_3b_448();
+
+ println!("retrieved the files in {:?}", start.elapsed());
+
+ let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+ let device = candle_examples::device(false)?;
+ let dtype = if device.is_cuda() {
+ DType::BF16
+ } else {
+ DType::F32
+ };
+ let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
+ let model = colpali::Model::new(&config, vb)?;
+
+ let pdf = PDF::from_file(args.pdf)?;
+
+ // check if start and end given in arg
+ let range = if let (Some(start), Some(end)) = (args.start, args.end) {
+ pdf2image::Pages::Range(start..=end)
+ } else {
+ pdf2image::Pages::Range(1..=pdf.page_count()) // can use pdf2image::Pages::All but there is a bug in the library which causes the first page to rendered twice.
+ };
+
+ let mut retriever =
+ PageRetriever::new(model, config, pdf, tokenizer, &device, Some(range), 4, 3);
+ let top_k_indices = retriever.retrieve(&args.prompt)?;
+
+ println!("Prompt: {}", args.prompt);
+ println!(
+ "top {} page numbers that contain similarity to the prompt",
+ retriever.top_k
+ );
+ println!("-----------------------------------");
+ for index in top_k_indices {
+ println!("Page: {:?}", index + 1);
+ }
+ println!("-----------------------------------");
+ Ok(())
+}