author     Laurent Mazare <laurent.mazare@gmail.com>  2024-09-29 19:56:56 +0200
committer  GitHub <noreply@github.com>                2024-09-29 19:56:56 +0200
commit     2f49e1b5349f4e56677ec0d3dc3fe98f9cbb87c7
tree       f5816678f4c8ebe84098081b1b121677e70604dc /candle-transformers
parent     0ebb38813b152432249dde6f64004f682b50975b
Add PaliGemma. (#2519)
* Add PaliGemma.
* PaliGemma inference loop.
* Running PaliGemma example.
* Tweak the prompt.
Diffstat (limited to 'candle-transformers')
-rw-r--r--  candle-transformers/src/models/gemma.rs      |  20
-rw-r--r--  candle-transformers/src/models/mod.rs        |   1
-rw-r--r--  candle-transformers/src/models/paligemma.rs  | 109

3 files changed, 130 insertions, 0 deletions
diff --git a/candle-transformers/src/models/gemma.rs b/candle-transformers/src/models/gemma.rs
index 1cfef59e..69e22678 100644
--- a/candle-transformers/src/models/gemma.rs
+++ b/candle-transformers/src/models/gemma.rs
@@ -362,6 +362,10 @@ impl Model {
         })
     }
 
+    pub fn embed_tokens(&self) -> &candle_nn::Embedding {
+        &self.embed_tokens
+    }
+
     fn prepare_decoder_attention_mask(
         &self,
         b_size: usize,
@@ -400,6 +404,22 @@ impl Model {
             .apply(&self.lm_head)
     }
 
+    pub fn forward_embeds(
+        &mut self,
+        xs: &Tensor,
+        attn_mask: Option<&Tensor>,
+        seqlen_offset: usize,
+    ) -> Result<Tensor> {
+        let (_, seq_len, _) = xs.dims3()?;
+        let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
+        for layer in self.layers.iter_mut() {
+            xs = layer.forward(&xs, attn_mask, seqlen_offset)?
+        }
+        xs.narrow(1, seq_len - 1, 1)?
+            .apply(&self.norm)?
+            .apply(&self.lm_head)
+    }
+
     pub fn clear_kv_cache(&mut self) {
         for layer in self.layers.iter_mut() {
             layer.clear_kv_cache()
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index a0e7a922..bba701bd 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -46,6 +46,7 @@ pub mod moondream;
 pub mod mpt;
 pub mod olmo;
 pub mod openclip;
+pub mod paligemma;
 pub mod parler_tts;
 pub mod persimmon;
 pub mod phi;
diff --git a/candle-transformers/src/models/paligemma.rs b/candle-transformers/src/models/paligemma.rs
new file mode 100644
index 00000000..e22ab241
--- /dev/null
+++ b/candle-transformers/src/models/paligemma.rs
@@ -0,0 +1,109 @@
+use crate::models::{gemma, siglip};
+use candle::{Module, Result, Tensor};
+use candle_nn::{linear, Linear, VarBuilder};
+
+#[derive(serde::Deserialize, Clone, Debug)]
+pub struct Config {
+    pub vision_config: siglip::VisionConfig,
+    pub text_config: gemma::Config,
+    pub projection_dim: usize,
+}
+
+impl Config {
+    pub fn paligemma_3b_224() -> Self {
+        // https://huggingface.co/google/paligemma-3b-pt-224/blob/main/config.json
+        Self {
+            vision_config: siglip::VisionConfig::paligemma_3b_224(),
+            text_config: gemma::Config {
+                hidden_size: 2048,
+                intermediate_size: 16384,
+                num_attention_heads: 8,
+                num_hidden_layers: 18,
+                num_key_value_heads: 1,
+                vocab_size: 257216,
+                // Default values.
+                rope_theta: 10000.,
+                head_dim: 256,
+                hidden_act: Some(candle_nn::Activation::GeluPytorchTanh),
+                hidden_activation: None,
+                attention_bias: false,
+                max_position_embeddings: 8192,
+                rms_norm_eps: 1e-6,
+            },
+            projection_dim: 2048,
+        }
+    }
+}
+
+#[derive(Clone, Debug)]
+pub struct MultiModalProjector {
+    linear: Linear,
+}
+
+impl MultiModalProjector {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let linear = linear(
+            cfg.vision_config.hidden_size,
+            cfg.projection_dim,
+            vb.pp("linear"),
+        )?;
+        Ok(Self { linear })
+    }
+}
+
+impl Module for MultiModalProjector {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        xs.apply(&self.linear)
+    }
+}
+
+#[derive(Clone, Debug)]
+pub struct Model {
+    pos: usize,
+    vision_tower: siglip::VisionModel,
+    multi_modal_projector: MultiModalProjector,
+    language_model: gemma::Model,
+}
+
+impl Model {
+    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let vision_tower = siglip::VisionModel::new(
+            &cfg.vision_config,
+            false,
+            vb.pp("vision_tower.vision_model"),
+        )?;
+        let multi_modal_projector = MultiModalProjector::new(cfg, vb.pp("multi_modal_projector"))?;
+        let language_model = gemma::Model::new(false, &cfg.text_config, vb.pp("language_model"))?;
+        Ok(Self {
+            pos: 0,
+            language_model,
+            vision_tower,
+            multi_modal_projector,
+        })
+    }
+
+    pub fn setup(&mut self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<Tensor> {
+        self.clear_kv_cache();
+        let image_features = self
+            .vision_tower
+            .forward(pixel_values)?
+            .apply(&self.multi_modal_projector)?;
+        let image_features = crate::models::clip::div_l2_norm(&image_features)?;
+        let text_features = self.language_model.embed_tokens().forward(input_ids)?;
+        let input_embeds = Tensor::cat(&[image_features, text_features], 1)?;
+        self.pos = input_embeds.dim(1)?;
+        self.language_model.forward_embeds(&input_embeds, None, 0)
+    }
+
+    pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
+        let pos = self.pos;
+        let seq_len = input_ids.dim(1)?;
+        self.pos = pos + seq_len;
+        self.language_model.forward(input_ids, pos)
+    }
+
+    pub fn clear_kv_cache(&mut self) {
+        self.pos = 0;
+        self.language_model.clear_kv_cache()
+    }
+}
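For context, the new setup/forward split lends itself to a two-phase inference loop: setup runs the image plus prompt through the model once to fill the KV cache, then forward is called one token at a time. The sketch below is a hypothetical driver, not the candle-examples code from this PR; the generate helper name, the greedy argmax sampling, and the fixed 64-token budget are illustrative assumptions, and weight loading, tokenization, and detokenization are elided.

    // Hypothetical driver for the new paligemma::Model (not part of this patch).
    // `vb` is a VarBuilder over the model weights, `pixel_values` a (1, 3, 224, 224)
    // image tensor, `prompt_ids` the tokenized prompt.
    use candle::{Device, Result, Tensor};
    use candle_nn::VarBuilder;
    use candle_transformers::models::paligemma;

    fn generate(
        vb: VarBuilder,
        pixel_values: &Tensor,
        prompt_ids: &[u32],
        device: &Device,
    ) -> Result<Vec<u32>> {
        let cfg = paligemma::Config::paligemma_3b_224();
        let mut model = paligemma::Model::new(&cfg, vb)?;

        // Phase 1: `setup` clears the KV cache, runs the image through the SigLIP
        // tower and projector, concatenates it with the prompt embeddings, and
        // returns the logits for the first generated token.
        let input_ids = Tensor::new(prompt_ids, device)?.unsqueeze(0)?;
        let logits = model.setup(pixel_values, &input_ids)?;
        let mut token = logits.squeeze(0)?.squeeze(0)?.argmax(0)?.to_scalar::<u32>()?;

        // Phase 2: autoregressive loop, one token per call; `forward` tracks the
        // position offset internally. Greedy (argmax) sampling for simplicity.
        let mut out = vec![token];
        for _ in 0..64 {
            let input = Tensor::new(&[token], device)?.unsqueeze(0)?;
            let logits = model.forward(&input)?;
            token = logits.squeeze(0)?.squeeze(0)?.argmax(0)?.to_scalar::<u32>()?;
            out.push(token);
            // A real driver would stop on the EOS token and detokenize here.
        }
        Ok(out)
    }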