Diffstat (limited to 'candle-transformers/src')
-rw-r--r--   candle-transformers/src/models/colpali.rs   | 42
-rw-r--r--   candle-transformers/src/models/gemma.rs     | 16
-rw-r--r--   candle-transformers/src/models/mod.rs       |  1
-rw-r--r--   candle-transformers/src/models/paligemma.rs | 45
4 files changed, 103 insertions(+), 1 deletion(-)
diff --git a/candle-transformers/src/models/colpali.rs b/candle-transformers/src/models/colpali.rs
new file mode 100644
index 00000000..1299b0a4
--- /dev/null
+++ b/candle-transformers/src/models/colpali.rs
@@ -0,0 +1,42 @@
+use candle::{Module, Result, Tensor};
+use candle_nn::VarBuilder;
+
+use super::paligemma;
+use candle_nn::{linear, Linear};
+
+pub struct Model {
+    pub model: paligemma::Model,
+    pub custom_text_projection: Linear,
+}
+
+impl Model {
+    pub fn new(config: &paligemma::Config, vb: VarBuilder) -> Result<Self> {
+        let model = paligemma::Model::new(config, vb.pp("model"))?;
+        let custom_text_projection = linear(
+            config.text_config.hidden_size,
+            128,
+            vb.pp("custom_text_proj"),
+        )?;
+
+        Ok(Self {
+            model,
+            custom_text_projection,
+        })
+    }
+
+    pub fn forward_images(&mut self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<Tensor> {
+        let outputs = self
+            .model
+            .setup_without_projection(pixel_values, input_ids)?;
+        let outputs = self.custom_text_projection.forward(&outputs)?;
+        let outputs = outputs.broadcast_div(&outputs.sqr()?.sum_keepdim(2)?.sqrt()?)?;
+        Ok(outputs)
+    }
+
+    pub fn forward_text(&mut self, input_ids: &Tensor) -> Result<Tensor> {
+        let outputs = self.model.forward_without_projection(input_ids)?;
+        let outputs = self.custom_text_projection.forward(&outputs)?;
+        let outputs = outputs.broadcast_div(&outputs.sqr()?.sum_keepdim(2)?.sqrt()?)?;
+        Ok(outputs)
+    }
+}
diff --git a/candle-transformers/src/models/gemma.rs b/candle-transformers/src/models/gemma.rs
index 69e22678..c22a3948 100644
--- a/candle-transformers/src/models/gemma.rs
+++ b/candle-transformers/src/models/gemma.rs
@@ -403,7 +403,6 @@ impl Model {
             .apply(&self.norm)?
             .apply(&self.lm_head)
     }
-
     pub fn forward_embeds(
         &mut self,
         xs: &Tensor,
@@ -420,6 +419,21 @@ impl Model {
             .apply(&self.lm_head)
     }
 
+    // Forward the model and return the hidden states without the lm_head
+    pub fn forward_embeds_without_projection(
+        &mut self,
+        xs: &Tensor,
+        attn_mask: Option<&Tensor>,
+        seqlen_offset: usize,
+    ) -> Result<Tensor> {
+        let (_, _, _) = xs.dims3()?;
+        let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
+        for layer in self.layers.iter_mut() {
+            xs = layer.forward(&xs, attn_mask, seqlen_offset)?
+        }
+        Ok(xs)
+    }
+
     pub fn clear_kv_cache(&mut self) {
         for layer in self.layers.iter_mut() {
             layer.clear_kv_cache()
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index 09876503..80cd4f81 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -7,6 +7,7 @@ pub mod blip_text;
 pub mod chatglm;
 pub mod clip;
 pub mod codegeex4_9b;
+pub mod colpali;
 pub mod convmixer;
 pub mod convnext;
 pub mod dac;
diff --git a/candle-transformers/src/models/paligemma.rs b/candle-transformers/src/models/paligemma.rs
index e22ab241..a5e7f694 100644
--- a/candle-transformers/src/models/paligemma.rs
+++ b/candle-transformers/src/models/paligemma.rs
@@ -33,6 +33,29 @@ impl Config {
             projection_dim: 2048,
         }
     }
+
+    pub fn paligemma_3b_448() -> Self {
+        Self {
+            vision_config: siglip::VisionConfig::paligemma_3b_448(),
+            text_config: gemma::Config {
+                hidden_size: 2048,
+                intermediate_size: 16384,
+                num_attention_heads: 8,
+                num_hidden_layers: 18,
+                num_key_value_heads: 1,
+                // Default values.
+                rope_theta: 10000.,
+                head_dim: 256,
+                hidden_act: Some(candle_nn::Activation::GeluPytorchTanh),
+                hidden_activation: None,
+                attention_bias: false,
+                max_position_embeddings: 8192,
+                rms_norm_eps: 1e-6,
+                vocab_size: 257216,
+            },
+            projection_dim: 2048,
+        }
+    }
 }
 
 #[derive(Clone, Debug)]
@@ -102,6 +125,28 @@ impl Model {
         self.language_model.forward(input_ids, pos)
     }
 
+    pub fn forward_without_projection(&mut self, input_ids: &Tensor) -> Result<Tensor> {
+        self.clear_kv_cache();
+        let input_embeds = self.language_model.embed_tokens().forward(input_ids)?;
+        self.language_model
+            .forward_embeds_without_projection(&input_embeds, None, 0)
+    }
+    pub fn setup_without_projection(
+        &mut self,
+        pixel_values: &Tensor,
+        input_ids: &Tensor,
+    ) -> Result<Tensor> {
+        self.clear_kv_cache();
+        let image_features = self
+            .vision_tower
+            .forward(pixel_values)?
+            .apply(&self.multi_modal_projector)?;
+        let image_features = crate::models::clip::div_l2_norm(&image_features)?;
+        let text_features = self.language_model.embed_tokens().forward(input_ids)?;
+        let input_embeds = Tensor::cat(&[image_features, text_features], 1)?;
+        self.language_model
+            .forward_embeds_without_projection(&input_embeds, None, 0)
+    }
     pub fn clear_kv_cache(&mut self) {
         self.pos = 0;
         self.language_model.clear_kv_cache()
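For orientation, here is a minimal usage sketch of the new colpali::Model added by this diff. It assumes the ColPali weights are available as a local safetensors file and that image preprocessing and tokenization have already produced the input tensors; the weight file name, input shapes, and the maxsim helper are illustrative assumptions and are not part of this change.

// A minimal sketch only: the weight file name, the input shapes, and the
// maxsim helper below are assumptions for illustration, not APIs from this diff.
use candle::{DType, Device, Result, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::{colpali, paligemma};

// Late-interaction (MaxSim) score: for each query token embedding, take the
// similarity of its best-matching image token embedding and sum the maxima.
fn maxsim(query: &Tensor, image: &Tensor) -> Result<f32> {
    // query: (1, q_len, 128), image: (1, img_len, 128), both L2-normalized.
    let sims = query.matmul(&image.transpose(1, 2)?)?; // (1, q_len, img_len)
    sims.max(2)?.sum_all()?.to_scalar::<f32>()
}

fn main() -> Result<()> {
    let device = Device::Cpu;
    let config = paligemma::Config::paligemma_3b_448();
    // "colpali.safetensors" is a placeholder path for locally downloaded weights.
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["colpali.safetensors"], DType::F32, &device)?
    };
    let mut model = colpali::Model::new(&config, vb)?;

    // Placeholder tensors standing in for a preprocessed 448x448 image and
    // tokenized image prompt / query (real inputs come from the processor).
    let pixel_values = Tensor::zeros((1, 3, 448, 448), DType::F32, &device)?;
    let image_input_ids = Tensor::zeros((1, 16), DType::U32, &device)?;
    let query_input_ids = Tensor::zeros((1, 16), DType::U32, &device)?;

    // Both calls return L2-normalized per-token embeddings of width 128.
    let image_embeds = model.forward_images(&pixel_values, &image_input_ids)?;
    let query_embeds = model.forward_text(&query_input_ids)?;
    println!("score: {}", maxsim(&query_embeds, &image_embeds)?);
    Ok(())
}

Because forward_images and forward_text already divide by the per-token L2 norm along the last dimension, the plain dot products in maxsim act as cosine similarities.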