summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-examples/Cargo.toml5
-rw-r--r--candle-examples/examples/colpali/README.md18
-rw-r--r--candle-examples/examples/colpali/main.rs268
-rw-r--r--candle-transformers/src/models/colpali.rs42
-rw-r--r--candle-transformers/src/models/gemma.rs16
-rw-r--r--candle-transformers/src/models/mod.rs1
-rw-r--r--candle-transformers/src/models/paligemma.rs45
7 files changed, 394 insertions, 1 deletions
diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml
index 2c96f87d..4edde7a9 100644
--- a/candle-examples/Cargo.toml
+++ b/candle-examples/Cargo.toml
@@ -36,6 +36,7 @@ serde_json = { workspace = true }
symphonia = { version = "0.5.3", features = ["all"], optional = true }
tokenizers = { workspace = true, features = ["onig"] }
cpal = { version = "0.15.2", optional = true }
+pdf2image = { version = "0.1.2" , optional = true}
[dev-dependencies]
anyhow = { workspace = true }
@@ -117,3 +118,7 @@ required-features = ["depth_anything_v2"]
[[example]]
name = "silero-vad"
required-features = ["onnx"]
+
+[[example]]
+name = "colpali"
+required-features = ["pdf2image"] \ No newline at end of file
diff --git a/candle-examples/examples/colpali/README.md b/candle-examples/examples/colpali/README.md
new file mode 100644
index 00000000..e6a55798
--- /dev/null
+++ b/candle-examples/examples/colpali/README.md
@@ -0,0 +1,18 @@
+# Colpali
+
+[HuggingFace Model Card](https://huggingface.co/vidore/colpali-v1.2-merged)
+
+```
+wget https://arxiv.org/pdf/1706.03762.pdf
+cargo run --features cuda,pdf2image --release --example colpali -- --prompt "What is Positional Encoding" --pdf "1706.03762.pdf"
+```
+
+```
+Prompt: What is Positional Encoding
+top 3 page numbers that contain similarity to the prompt
+-----------------------------------
+Page: 6
+Page: 11
+Page: 15
+-----------------------------------
+``` \ No newline at end of file
diff --git a/candle-examples/examples/colpali/main.rs b/candle-examples/examples/colpali/main.rs
new file mode 100644
index 00000000..2a1cc96b
--- /dev/null
+++ b/candle-examples/examples/colpali/main.rs
@@ -0,0 +1,268 @@
+use anyhow::{Error as E, Result};
+use candle::{DType, Device, Tensor};
+use candle_nn::VarBuilder;
+use candle_transformers::models::colpali::Model;
+use candle_transformers::models::{colpali, paligemma};
+use clap::Parser;
+use hf_hub::{api::sync::Api, Repo, RepoType};
+use image::DynamicImage;
+use pdf2image::{RenderOptionsBuilder, PDF};
+use tokenizers::Tokenizer;
+
+/// Ranks the pages of a PDF against a text prompt using a ColPali model.
+struct PageRetriever {
+ /// ColPali model producing the text and page-image embeddings.
+ model: Model,
+ /// PaliGemma configuration; `retrieve` reads the vision image size from it.
+ config: paligemma::Config,
+ /// Document whose pages are rendered and scored.
+ pdf: PDF,
+ /// Device on which token tensors are created and to which page images are moved.
+ device: Device,
+ /// Tokenizer used to encode the prompts.
+ tokenizer: Tokenizer,
+ /// Pages handed to `PDF::render`.
+ range: pdf2image::Pages,
+ /// Number of rendered pages scored per model call.
+ batch_size: usize,
+ /// Number of best-scoring page indices returned by `retrieve`.
+ top_k: usize,
+}
+
+impl PageRetriever {
+ /// Builds a retriever over `pdf`.
+ ///
+ /// When `range` is `None`, every page from 1 to `pdf.page_count()` is selected.
+ /// The device handle is cloned so the retriever owns its own copy.
+ fn new(
+     model: Model,
+     config: paligemma::Config,
+     pdf: PDF,
+     tokenizer: Tokenizer,
+     device: &Device,
+     range: Option<pdf2image::Pages>,
+     batch_size: usize,
+     top_k: usize,
+ ) -> Self {
+     // Resolve the range first: the page count is needed before `pdf` is moved below.
+     let range = match range {
+         Some(pages) => pages,
+         None => pdf2image::Pages::Range(1..=pdf.page_count()),
+     };
+     let device = device.clone();
+     Self {
+         model,
+         config,
+         pdf,
+         device,
+         tokenizer,
+         range,
+         batch_size,
+         top_k,
+     }
+ }
+
+ /// Renders the pages selected by `self.range` using the default render options.
+ fn get_images_from_pdf(&self) -> Result<Vec<DynamicImage>> {
+     let options = RenderOptionsBuilder::default().build()?;
+     let rendered = self.pdf.render(self.range.clone(), options)?;
+     Ok(rendered)
+ }
+
+ /// Tokenizes `prompts` and stacks the resulting id rows into one tensor on `self.device`.
+ ///
+ /// NOTE(review): no padding is applied here, so the rows presumably need to have the
+ /// same length for the stack to succeed; current callers pass a single prompt.
+ fn tokenize_batch(&self, prompts: Vec<&str>) -> Result<Tensor> {
+     let encodings = self.tokenizer.encode_batch(prompts, true).map_err(E::msg)?;
+     let mut rows = Vec::with_capacity(encodings.len());
+     for encoding in &encodings {
+         rows.push(Tensor::new(encoding.get_ids(), &self.device)?);
+     }
+     Ok(Tensor::stack(&rows, 0)?)
+ }
+
+ /// Converts rendered pages into a single `(pages, 3, image_size, image_size)` F32 tensor
+ /// on the CPU.
+ ///
+ /// Each page is resized to fill a square of side `image_size`, converted to RGB8, moved
+ /// from height-width-channel to channel-first layout, and rescaled from `0..=255` to
+ /// `-1.0..=1.0`.
+ fn images_to_tensor(
+     &self,
+     pages: &[DynamicImage],
+     image_size: usize,
+ ) -> anyhow::Result<Tensor> {
+     let side = image_size as u32;
+     let images = pages
+         .iter()
+         .map(|page| {
+             let pixels = page
+                 .resize_to_fill(side, side, image::imageops::FilterType::Triangle)
+                 .to_rgb8()
+                 .into_raw();
+             Tensor::from_vec(pixels, (image_size, image_size, 3), &Device::Cpu)?
+                 .permute((2, 0, 1))?
+                 .to_dtype(DType::F32)?
+                 // x * 2/255 - 1 maps the byte range onto [-1, 1].
+                 .affine(2. / 255., -1.)
+         })
+         .collect::<candle::Result<Vec<_>>>()?;
+     Ok(Tensor::stack(&images, 0)?)
+ }
+
+ /// Scores every rendered page against `prompt` and returns the indices of the best pages.
+ ///
+ /// The returned indices are zero-based positions within the rendered page range (not
+ /// PDF page numbers), ordered from highest to lowest score. At most `self.top_k`
+ /// indices are returned; fewer when the range holds fewer pages.
+ ///
+ /// # Errors
+ /// Propagates tokenizer, PDF rendering and tensor/model errors.
+ fn retrieve(&mut self, prompt: &str) -> Result<Vec<usize>> {
+     let dtype = if self.device.is_cuda() {
+         DType::BF16
+     } else {
+         DType::F32
+     };
+
+     // Text fed alongside every page image when computing the image embeddings.
+     let dummy_prompt: &str = "Describe the image";
+
+     let input = self.tokenize_batch(vec![prompt])?;
+     let dummy_input = self.tokenize_batch(vec![dummy_prompt])?;
+
+     // The query embedding does not depend on the page batch: compute it once instead
+     // of once per batch. Shape after unsqueeze: (1, 1, query_tokens, dim).
+     let text_embeddings = self.model.forward_text(&input)?.unsqueeze(1)?;
+
+     let pages = self.get_images_from_pdf()?;
+     let mut all_scores: Vec<f32> = Vec::with_capacity(pages.len());
+     // `chunks` panics on a chunk size of zero, so treat 0 as 1.
+     for batch in pages.chunks(self.batch_size.max(1)) {
+         let page_images = self
+             .images_to_tensor(batch, self.config.vision_config.image_size)?
+             .to_device(&self.device)?
+             .to_dtype(dtype)?;
+         // One copy of the dummy prompt per page; the token dimension is left as is.
+         let dummy_input = dummy_input.repeat((page_images.dims()[0], 1))?;
+
+         let image_embeddings = self.model.forward_images(&page_images, &dummy_input)?;
+
+         // Late-interaction score: for each query token take the best matching image
+         // token (max over dim 3), then sum over the query tokens (dim 2).
+         let scores = text_embeddings
+             .broadcast_matmul(&image_embeddings.unsqueeze(0)?.transpose(3, 2)?)?
+             .max(3)?
+             .sum(2)?;
+         let batch_scores = scores.to_dtype(DType::F32)?.to_vec2::<f32>()?;
+         all_scores.extend(batch_scores.into_iter().flatten());
+     }
+
+     let mut indices: Vec<usize> = (0..all_scores.len()).collect();
+     // `total_cmp` is a total order, so a NaN score cannot panic the sort.
+     indices.sort_by(|a, b| all_scores[*b].total_cmp(&all_scores[*a]));
+     // `truncate` is a no-op when fewer than `top_k` pages were scored, where slicing
+     // `[0..top_k]` would panic.
+     indices.truncate(self.top_k);
+
+     Ok(indices)
+ }
+}
+
+// Command-line arguments of the ColPali page-retrieval example.
+// NOTE: fields without an existing `///` comment are documented with plain `//` comments
+// on purpose: clap turns `///` comments into `--help` text, which would change output.
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+ /// Run on CPU rather than on GPU.
+ #[arg(long)]
+ cpu: bool,
+
+ /// Enable tracing (generates a trace-timestamp.json file).
+ #[arg(long)]
+ tracing: bool,
+
+ // Text query the PDF pages are ranked against.
+ #[arg(long)]
+ prompt: String,
+
+ /// number of top pages to show.
+ #[arg(long, default_value_t = 3)]
+ top_k: usize,
+
+ // Hub model id; `main` falls back to "vidore/colpali-v1.2-merged" when absent.
+ #[arg(long)]
+ model_id: Option<String>,
+
+ // Hub revision of the model repository.
+ #[arg(long, default_value = "main")]
+ revision: String,
+
+ // Path to a local tokenizer file; when absent `main` fetches "tokenizer.json" from the hub.
+ #[arg(long)]
+ tokenizer_file: Option<String>,
+
+ // Comma-separated list of local weight files; when absent the weights come from the hub.
+ #[arg(long)]
+ weight_files: Option<String>,
+
+ // Path to the PDF file to search.
+ #[arg(long)]
+ pdf: String,
+
+ // First page of the page range to search.
+ #[arg(long)]
+ start: Option<u32>,
+
+ // Last page of the page range to search.
+ #[arg(long)]
+ end: Option<u32>,
+}
+
+/// Loads the ColPali model and tokenizer, ranks the pages of the given PDF against the
+/// prompt, and prints the page numbers of the best matches.
+///
+/// # Errors
+/// Fails when the hub files, tokenizer, weights or PDF cannot be loaded, when the
+/// requested page range is inverted, or when retrieval itself fails.
+fn main() -> Result<()> {
+    use tracing_chrome::ChromeLayerBuilder;
+    use tracing_subscriber::prelude::*;
+
+    // Number of pages scored per model call.
+    const BATCH_SIZE: usize = 4;
+
+    let args = Args::parse();
+    let _guard = if args.tracing {
+        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+        tracing_subscriber::registry().with(chrome_layer).init();
+        Some(guard)
+    } else {
+        None
+    };
+    println!(
+        "avx: {}, neon: {}, simd128: {}, f16c: {}",
+        candle::utils::with_avx(),
+        candle::utils::with_neon(),
+        candle::utils::with_simd128(),
+        candle::utils::with_f16c()
+    );
+
+    // Start the clock before touching the hub so the message below reports the time
+    // actually spent retrieving files.
+    let start = std::time::Instant::now();
+
+    let api = Api::new()?;
+    let model_id = match &args.model_id {
+        Some(model_id) => model_id.to_string(),
+        None => "vidore/colpali-v1.2-merged".to_string(),
+    };
+    let repo = api.repo(Repo::with_revision(
+        model_id,
+        RepoType::Model,
+        args.revision,
+    ));
+
+    let tokenizer_filename = match args.tokenizer_file {
+        Some(file) => std::path::PathBuf::from(file),
+        None => api
+            .repo(Repo::with_revision(
+                "vidore/colpali".to_string(),
+                RepoType::Model,
+                "main".to_string(),
+            ))
+            .get("tokenizer.json")?,
+    };
+
+    let filenames = match args.weight_files {
+        Some(files) => files
+            .split(',')
+            .map(std::path::PathBuf::from)
+            .collect::<Vec<_>>(),
+        None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
+    };
+
+    let config: paligemma::Config = paligemma::Config::paligemma_3b_448();
+
+    println!("retrieved the files in {:?}", start.elapsed());
+
+    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+    // Honour `--cpu` instead of always requesting an accelerator.
+    let device = candle_examples::device(args.cpu)?;
+    let dtype = if device.is_cuda() {
+        DType::BF16
+    } else {
+        DType::F32
+    };
+    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
+    let model = colpali::Model::new(&config, vb)?;
+
+    let pdf = PDF::from_file(args.pdf)?;
+
+    // `--start` and `--end` are independent: a missing bound defaults to the first or
+    // last page. An explicit range is used rather than `pdf2image::Pages::All` because
+    // of a bug in the library which causes the first page to be rendered twice.
+    let first_page = args.start.unwrap_or(1);
+    let last_page = args.end.unwrap_or_else(|| pdf.page_count());
+    if first_page > last_page {
+        anyhow::bail!(
+            "invalid page range: start ({first_page}) is greater than end ({last_page})"
+        );
+    }
+    let range = pdf2image::Pages::Range(first_page..=last_page);
+
+    let mut retriever = PageRetriever::new(
+        model,
+        config,
+        pdf,
+        tokenizer,
+        &device,
+        Some(range),
+        BATCH_SIZE,
+        args.top_k,
+    );
+    let top_k_indices = retriever.retrieve(&args.prompt)?;
+
+    println!("Prompt: {}", args.prompt);
+    println!(
+        "top {} page numbers that contain similarity to the prompt",
+        top_k_indices.len()
+    );
+    println!("-----------------------------------");
+    for index in top_k_indices {
+        // Indices are relative to the rendered range, so offset by its first page.
+        println!("Page: {:?}", first_page as usize + index);
+    }
+    println!("-----------------------------------");
+    Ok(())
+}
diff --git a/candle-transformers/src/models/colpali.rs b/candle-transformers/src/models/colpali.rs
new file mode 100644
index 00000000..1299b0a4
--- /dev/null
+++ b/candle-transformers/src/models/colpali.rs
@@ -0,0 +1,42 @@
+use candle::{Module, Result, Tensor};
+use candle_nn::VarBuilder;
+
+use super::paligemma;
+use candle_nn::{linear, Linear};
+
+/// ColPali model: a PaliGemma backbone plus a linear projection applied to its hidden states.
+pub struct Model {
+ /// Underlying PaliGemma model, loaded from the `model` weight prefix.
+ pub model: paligemma::Model,
+ /// Linear layer loaded from the `custom_text_proj` weight prefix; maps hidden states
+ /// of size `text_config.hidden_size` to 128-dimensional embeddings.
+ pub custom_text_projection: Linear,
+}
+
+impl Model {
+    /// Size of the embeddings produced by the custom projection.
+    const EMBEDDING_DIM: usize = 128;
+
+    /// Loads the PaliGemma backbone from the `model` prefix and the projection layer
+    /// from the `custom_text_proj` prefix of `vb`.
+    pub fn new(config: &paligemma::Config, vb: VarBuilder) -> Result<Self> {
+        let model = paligemma::Model::new(config, vb.pp("model"))?;
+        let custom_text_projection = linear(
+            config.text_config.hidden_size,
+            Self::EMBEDDING_DIM,
+            vb.pp("custom_text_proj"),
+        )?;
+
+        Ok(Self {
+            model,
+            custom_text_projection,
+        })
+    }
+
+    /// Embeds images together with their accompanying `input_ids`.
+    ///
+    /// Returns one L2-normalised embedding per position of the backbone output.
+    pub fn forward_images(&mut self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<Tensor> {
+        let hidden_states = self
+            .model
+            .setup_without_projection(pixel_values, input_ids)?;
+        self.project_and_normalize(&hidden_states)
+    }
+
+    /// Embeds text-only `input_ids`.
+    ///
+    /// Returns one L2-normalised embedding per position of the backbone output.
+    pub fn forward_text(&mut self, input_ids: &Tensor) -> Result<Tensor> {
+        let hidden_states = self.model.forward_without_projection(input_ids)?;
+        self.project_and_normalize(&hidden_states)
+    }
+
+    /// Applies the custom projection and divides each vector by its L2 norm along dim 2.
+    fn project_and_normalize(&self, hidden_states: &Tensor) -> Result<Tensor> {
+        let projected = self.custom_text_projection.forward(hidden_states)?;
+        let norm = projected.sqr()?.sum_keepdim(2)?.sqrt()?;
+        projected.broadcast_div(&norm)
+    }
+}
diff --git a/candle-transformers/src/models/gemma.rs b/candle-transformers/src/models/gemma.rs
index 69e22678..c22a3948 100644
--- a/candle-transformers/src/models/gemma.rs
+++ b/candle-transformers/src/models/gemma.rs
@@ -403,7 +403,6 @@ impl Model {
.apply(&self.norm)?
.apply(&self.lm_head)
}
-
pub fn forward_embeds(
&mut self,
xs: &Tensor,
@@ -420,6 +419,21 @@ impl Model {
.apply(&self.lm_head)
}
+ /// Runs the decoder layers over already-embedded inputs and returns the hidden states.
+ ///
+ /// Neither the final norm nor the `lm_head` is applied here. The input is scaled by
+ /// `sqrt(hidden_size)` before entering the first layer.
+ pub fn forward_embeds_without_projection(
+     &mut self,
+     xs: &Tensor,
+     attn_mask: Option<&Tensor>,
+     seqlen_offset: usize,
+ ) -> Result<Tensor> {
+     // Only checks that the input is rank 3; the dimensions themselves are unused.
+     xs.dims3()?;
+     let scaled = (xs * (self.hidden_size as f64).sqrt())?;
+     self.layers
+         .iter_mut()
+         .try_fold(scaled, |hidden, layer| {
+             layer.forward(&hidden, attn_mask, seqlen_offset)
+         })
+ }
+
+
pub fn clear_kv_cache(&mut self) {
for layer in self.layers.iter_mut() {
layer.clear_kv_cache()
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index 09876503..80cd4f81 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -7,6 +7,7 @@ pub mod blip_text;
pub mod chatglm;
pub mod clip;
pub mod codegeex4_9b;
+pub mod colpali;
pub mod convmixer;
pub mod convnext;
pub mod dac;
diff --git a/candle-transformers/src/models/paligemma.rs b/candle-transformers/src/models/paligemma.rs
index e22ab241..a5e7f694 100644
--- a/candle-transformers/src/models/paligemma.rs
+++ b/candle-transformers/src/models/paligemma.rs
@@ -33,6 +33,29 @@ impl Config {
projection_dim: 2048,
}
}
+
+ /// Returns the `paligemma_3b_448` preset: the matching SigLIP vision preset combined
+ /// with the Gemma text configuration spelled out below.
+ pub fn paligemma_3b_448() -> Self {
+     let vision_config = siglip::VisionConfig::paligemma_3b_448();
+     let text_config = gemma::Config {
+         hidden_size: 2048,
+         intermediate_size: 16384,
+         num_attention_heads: 8,
+         num_hidden_layers: 18,
+         num_key_value_heads: 1,
+         // Default values.
+         rope_theta: 10000.,
+         head_dim: 256,
+         hidden_act: Some(candle_nn::Activation::GeluPytorchTanh),
+         hidden_activation: None,
+         attention_bias: false,
+         max_position_embeddings: 8192,
+         rms_norm_eps: 1e-6,
+         vocab_size: 257216,
+     };
+     Self {
+         vision_config,
+         text_config,
+         projection_dim: 2048,
+     }
+ }
}
#[derive(Clone, Debug)]
@@ -102,6 +125,28 @@ impl Model {
self.language_model.forward(input_ids, pos)
}
+ /// Embeds `input_ids` and runs the language model on them from offset 0 with no
+ /// attention mask, returning the hidden states without the output projection.
+ ///
+ /// The KV cache is cleared first.
+ pub fn forward_without_projection(&mut self, input_ids: &Tensor) -> Result<Tensor> {
+     self.clear_kv_cache();
+     let token_embeds = input_ids.apply(self.language_model.embed_tokens())?;
+     self.language_model
+         .forward_embeds_without_projection(&token_embeds, None, 0)
+ }
+ /// Encodes `pixel_values` and `input_ids` together and returns the language model
+ /// hidden states without the output projection.
+ ///
+ /// Image features (vision tower, multi-modal projector, then L2 normalisation) are
+ /// concatenated in front of the token embeddings along dim 1. The KV cache is
+ /// cleared first.
+ pub fn setup_without_projection(
+     &mut self,
+     pixel_values: &Tensor,
+     input_ids: &Tensor,
+ ) -> Result<Tensor> {
+     self.clear_kv_cache();
+     let projected = self
+         .vision_tower
+         .forward(pixel_values)?
+         .apply(&self.multi_modal_projector)?;
+     let image_part = crate::models::clip::div_l2_norm(&projected)?;
+     let text_part = input_ids.apply(self.language_model.embed_tokens())?;
+     let combined = Tensor::cat(&[image_part, text_part], 1)?;
+     self.language_model
+         .forward_embeds_without_projection(&combined, None, 0)
+ }
pub fn clear_kv_cache(&mut self) {
self.pos = 0;
self.language_model.clear_kv_cache()