Diffstat (limited to 'candle-transformers/src')
-rw-r--r--  candle-transformers/src/models/colpali.rs    | 42
-rw-r--r--  candle-transformers/src/models/gemma.rs      | 16
-rw-r--r--  candle-transformers/src/models/mod.rs        |  1
-rw-r--r--  candle-transformers/src/models/paligemma.rs  | 45
4 files changed, 103 insertions(+), 1 deletion(-)
diff --git a/candle-transformers/src/models/colpali.rs b/candle-transformers/src/models/colpali.rs
new file mode 100644
index 00000000..1299b0a4
--- /dev/null
+++ b/candle-transformers/src/models/colpali.rs
@@ -0,0 +1,42 @@
+use candle::{Module, Result, Tensor};
+use candle_nn::VarBuilder;
+
+use super::paligemma;
+use candle_nn::{linear, Linear};
+
+pub struct Model {
+    pub model: paligemma::Model,
+    pub custom_text_projection: Linear,
+}
+
+impl Model {
+    pub fn new(config: &paligemma::Config, vb: VarBuilder) -> Result<Self> {
+        let model = paligemma::Model::new(config, vb.pp("model"))?;
+        let custom_text_projection = linear(
+            config.text_config.hidden_size,
+            128,
+            vb.pp("custom_text_proj"),
+        )?;
+
+        Ok(Self {
+            model,
+            custom_text_projection,
+        })
+    }
+
+    pub fn forward_images(&mut self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<Tensor> {
+        let outputs = self
+            .model
+            .setup_without_projection(pixel_values, input_ids)?;
+        let outputs = self.custom_text_projection.forward(&outputs)?;
+        let outputs = outputs.broadcast_div(&outputs.sqr()?.sum_keepdim(2)?.sqrt()?)?;
+        Ok(outputs)
+    }
+
+    pub fn forward_text(&mut self, input_ids: &Tensor) -> Result<Tensor> {
+        let outputs = self.model.forward_without_projection(input_ids)?;
+        let outputs = self.custom_text_projection.forward(&outputs)?;
+        let outputs = outputs.broadcast_div(&outputs.sqr()?.sum_keepdim(2)?.sqrt()?)?;
+        Ok(outputs)
+    }
+}
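
The new colpali::Model wraps paligemma with a 128-dim linear head ("custom_text_proj") and L2-normalizes along the last dimension, giving one normalized 128-dim vector per image patch or text token, the multi-vector layout ColPali uses for late-interaction retrieval. A minimal usage sketch; the checkpoint filename, dtype, and dummy shapes are assumptions for illustration, not part of this commit:

    use candle::{DType, Device, Result, Tensor};
    use candle_nn::VarBuilder;
    use candle_transformers::models::{colpali, paligemma};

    fn main() -> Result<()> {
        let device = Device::Cpu;
        let config = paligemma::Config::paligemma_3b_448();
        // Hypothetical weight file; real ColPali checkpoints are distributed
        // as safetensors on the Hugging Face hub.
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&["colpali.safetensors"], DType::F32, &device)?
        };
        let mut model = colpali::Model::new(&config, vb)?;

        // Dummy inputs: one 448x448 RGB image with its prompt tokens, and a query.
        let pixel_values = Tensor::zeros((1, 3, 448, 448), DType::F32, &device)?;
        let prompt_ids = Tensor::zeros((1, 16), DType::U32, &device)?;
        let query_ids = Tensor::zeros((1, 12), DType::U32, &device)?;

        let image_embeds = model.forward_images(&pixel_values, &prompt_ids)?; // (1, patches + 16, 128)
        let query_embeds = model.forward_text(&query_ids)?; // (1, 12, 128)
        println!("{:?} {:?}", image_embeds.dims(), query_embeds.dims());
        Ok(())
    }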
diff --git a/candle-transformers/src/models/gemma.rs b/candle-transformers/src/models/gemma.rs
index 69e22678..c22a3948 100644
--- a/candle-transformers/src/models/gemma.rs
+++ b/candle-transformers/src/models/gemma.rs
@@ -403,7 +403,6 @@ impl Model {
             .apply(&self.norm)?
             .apply(&self.lm_head)
     }
-
     pub fn forward_embeds(
         &mut self,
         xs: &Tensor,
@@ -420,6 +419,21 @@ impl Model {
             .apply(&self.lm_head)
     }
 
+    // Forward the model and return the hidden states, skipping the final norm and lm_head.
+    pub fn forward_embeds_without_projection(
+        &mut self,
+        xs: &Tensor,
+        attn_mask: Option<&Tensor>,
+        seqlen_offset: usize,
+    ) -> Result<Tensor> {
+        let (_, _, _) = xs.dims3()?; // shape check: expect (b, seq_len, hidden_size)
+        let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
+        for layer in self.layers.iter_mut() {
+            xs = layer.forward(&xs, attn_mask, seqlen_offset)?
+        }
+        Ok(xs)
+    }
+
     pub fn clear_kv_cache(&mut self) {
         for layer in self.layers.iter_mut() {
             layer.clear_kv_cache()
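
gemma's new forward_embeds_without_projection takes pre-computed input embeddings, applies the usual sqrt(hidden_size) scaling, and runs only the decoder layers, so callers receive hidden states of width hidden_size instead of vocabulary logits. A hedged sketch of how a caller drives it; the helper name is hypothetical:

    use candle::{Module, Result, Tensor};
    use candle_transformers::models::gemma;

    // Hypothetical helper: token ids -> final-layer hidden states of shape
    // (batch, seq_len, hidden_size); no final norm, no lm_head projection.
    fn hidden_states(model: &mut gemma::Model, input_ids: &Tensor) -> Result<Tensor> {
        let embeds = model.embed_tokens().forward(input_ids)?;
        model.forward_embeds_without_projection(&embeds, None, 0)
    }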
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index 09876503..80cd4f81 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -7,6 +7,7 @@ pub mod blip_text;
 pub mod chatglm;
 pub mod clip;
 pub mod codegeex4_9b;
+pub mod colpali;
 pub mod convmixer;
 pub mod convnext;
 pub mod dac;
diff --git a/candle-transformers/src/models/paligemma.rs b/candle-transformers/src/models/paligemma.rs
index e22ab241..a5e7f694 100644
--- a/candle-transformers/src/models/paligemma.rs
+++ b/candle-transformers/src/models/paligemma.rs
@@ -33,6 +33,29 @@ impl Config {
             projection_dim: 2048,
         }
     }
+
+    pub fn paligemma_3b_448() -> Self {
+        Self {
+            vision_config: siglip::VisionConfig::paligemma_3b_448(),
+            text_config: gemma::Config {
+                hidden_size: 2048,
+                intermediate_size: 16384,
+                num_attention_heads: 8,
+                num_hidden_layers: 18,
+                num_key_value_heads: 1,
+                // Default values.
+                rope_theta: 10000.,
+                head_dim: 256,
+                hidden_act: Some(candle_nn::Activation::GeluPytorchTanh),
+                hidden_activation: None,
+                attention_bias: false,
+                max_position_embeddings: 8192,
+                rms_norm_eps: 1e-6,
+                vocab_size: 257216,
+            },
+            projection_dim: 2048,
+        }
+    }
 }
 
 #[derive(Clone, Debug)]
@@ -102,6 +125,28 @@ impl Model {
         self.language_model.forward(input_ids, pos)
     }
+    pub fn forward_without_projection(&mut self, input_ids: &Tensor) -> Result<Tensor> {
+        self.clear_kv_cache();
+        let input_embeds = self.language_model.embed_tokens().forward(input_ids)?;
+        self.language_model
+            .forward_embeds_without_projection(&input_embeds, None, 0)
+    }
+    pub fn setup_without_projection(
+        &mut self,
+        pixel_values: &Tensor,
+        input_ids: &Tensor,
+    ) -> Result<Tensor> {
+        self.clear_kv_cache();
+        let image_features = self
+            .vision_tower
+            .forward(pixel_values)?
+            .apply(&self.multi_modal_projector)?;
+        let image_features = crate::models::clip::div_l2_norm(&image_features)?;
+        let text_features = self.language_model.embed_tokens().forward(input_ids)?;
+        let input_embeds = Tensor::cat(&[image_features, text_features], 1)?;
+        self.language_model
+            .forward_embeds_without_projection(&input_embeds, None, 0)
+    }
 
     pub fn clear_kv_cache(&mut self) {
         self.pos = 0;
         self.language_model.clear_kv_cache()
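
Together these methods give ColPali its retrieval primitives: forward_images embeds every image patch (plus prompt token) and forward_text embeds every query token as a normalized 128-dim vector, and pages are then ranked by late interaction (MaxSim). A sketch of that scoring step, assuming both tensors come from the forward_images/forward_text calls above; the function itself is illustrative, not part of this commit:

    use candle::{Result, Tensor};

    // MaxSim: for each query token, take the similarity of its best-matching
    // image patch, then sum over query tokens. Inputs are L2-normalized,
    // shapes (1, q, 128) and (1, p, 128), so the matmul yields cosines.
    fn maxsim_score(query: &Tensor, image: &Tensor) -> Result<f32> {
        let sims = query.matmul(&image.transpose(1, 2)?.contiguous()?)?; // (1, q, p)
        let best = sims.max(2)?; // (1, q)
        best.sum_all()?.to_scalar::<f32>()
    }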