summary refs log tree commit diff
path: root/candle-transformers
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-09-29 19:56:56 +0200
committerGitHub <noreply@github.com>2024-09-29 19:56:56 +0200
commit2f49e1b5349f4e56677ec0d3dc3fe98f9cbb87c7 (patch)
treef5816678f4c8ebe84098081b1b121677e70604dc /candle-transformers
parent0ebb38813b152432249dde6f64004f682b50975b (diff)
downloadcandle-2f49e1b5349f4e56677ec0d3dc3fe98f9cbb87c7.tar.gz
candle-2f49e1b5349f4e56677ec0d3dc3fe98f9cbb87c7.tar.bz2
candle-2f49e1b5349f4e56677ec0d3dc3fe98f9cbb87c7.zip
Add PaliGemma. (#2519)
* Add PaliGemma. * PaliGemma inference loop. * Running PaliGemma example. * Tweak the prompt.
Diffstat (limited to 'candle-transformers')
-rw-r--r--candle-transformers/src/models/gemma.rs20
-rw-r--r--candle-transformers/src/models/mod.rs1
-rw-r--r--candle-transformers/src/models/paligemma.rs109
3 files changed, 130 insertions, 0 deletions
diff --git a/candle-transformers/src/models/gemma.rs b/candle-transformers/src/models/gemma.rs
index 1cfef59e..69e22678 100644
--- a/candle-transformers/src/models/gemma.rs
+++ b/candle-transformers/src/models/gemma.rs
@@ -362,6 +362,10 @@ impl Model {
})
}
+ pub fn embed_tokens(&self) -> &candle_nn::Embedding {
+ &self.embed_tokens
+ }
+
fn prepare_decoder_attention_mask(
&self,
b_size: usize,
@@ -400,6 +404,22 @@ impl Model {
.apply(&self.lm_head)
}
+ pub fn forward_embeds(
+ &mut self,
+ xs: &Tensor,
+ attn_mask: Option<&Tensor>,
+ seqlen_offset: usize,
+ ) -> Result<Tensor> {
+ let (_, seq_len, _) = xs.dims3()?;
+ let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
+ for layer in self.layers.iter_mut() {
+ xs = layer.forward(&xs, attn_mask, seqlen_offset)?
+ }
+ xs.narrow(1, seq_len - 1, 1)?
+ .apply(&self.norm)?
+ .apply(&self.lm_head)
+ }
+
pub fn clear_kv_cache(&mut self) {
for layer in self.layers.iter_mut() {
layer.clear_kv_cache()
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index a0e7a922..bba701bd 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -46,6 +46,7 @@ pub mod moondream;
pub mod mpt;
pub mod olmo;
pub mod openclip;
+pub mod paligemma;
pub mod parler_tts;
pub mod persimmon;
pub mod phi;
diff --git a/candle-transformers/src/models/paligemma.rs b/candle-transformers/src/models/paligemma.rs
new file mode 100644
index 00000000..e22ab241
--- /dev/null
+++ b/candle-transformers/src/models/paligemma.rs
@@ -0,0 +1,109 @@
+use crate::models::{gemma, siglip};
+use candle::{Module, Result, Tensor};
+use candle_nn::{linear, Linear, VarBuilder};
+
+#[derive(serde::Deserialize, Clone, Debug)]
+pub struct Config {
+ pub vision_config: siglip::VisionConfig,
+ pub text_config: gemma::Config,
+ pub projection_dim: usize,
+}
+
+impl Config {
+ pub fn paligemma_3b_224() -> Self {
+ // https://huggingface.co/google/paligemma-3b-pt-224/blob/main/config.json
+ Self {
+ vision_config: siglip::VisionConfig::paligemma_3b_224(),
+ text_config: gemma::Config {
+ hidden_size: 2048,
+ intermediate_size: 16384,
+ num_attention_heads: 8,
+ num_hidden_layers: 18,
+ num_key_value_heads: 1,
+ vocab_size: 257216,
+ // Default values.
+ rope_theta: 10000.,
+ head_dim: 256,
+ hidden_act: Some(candle_nn::Activation::GeluPytorchTanh),
+ hidden_activation: None,
+ attention_bias: false,
+ max_position_embeddings: 8192,
+ rms_norm_eps: 1e-6,
+ },
+ projection_dim: 2048,
+ }
+ }
+}
+
+#[derive(Clone, Debug)]
+pub struct MultiModalProjector {
+ linear: Linear,
+}
+
+impl MultiModalProjector {
+ fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let linear = linear(
+ cfg.vision_config.hidden_size,
+ cfg.projection_dim,
+ vb.pp("linear"),
+ )?;
+ Ok(Self { linear })
+ }
+}
+
+impl Module for MultiModalProjector {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ xs.apply(&self.linear)
+ }
+}
+
+#[derive(Clone, Debug)]
+pub struct Model {
+ pos: usize,
+ vision_tower: siglip::VisionModel,
+ multi_modal_projector: MultiModalProjector,
+ language_model: gemma::Model,
+}
+
+impl Model {
+ pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let vision_tower = siglip::VisionModel::new(
+ &cfg.vision_config,
+ false,
+ vb.pp("vision_tower.vision_model"),
+ )?;
+ let multi_modal_projector = MultiModalProjector::new(cfg, vb.pp("multi_modal_projector"))?;
+ let language_model = gemma::Model::new(false, &cfg.text_config, vb.pp("language_model"))?;
+ Ok(Self {
+ pos: 0,
+ language_model,
+ vision_tower,
+ multi_modal_projector,
+ })
+ }
+
+ pub fn setup(&mut self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<Tensor> {
+ self.clear_kv_cache();
+ let image_features = self
+ .vision_tower
+ .forward(pixel_values)?
+ .apply(&self.multi_modal_projector)?;
+ let image_features = crate::models::clip::div_l2_norm(&image_features)?;
+ let text_features = self.language_model.embed_tokens().forward(input_ids)?;
+ let input_embeds = Tensor::cat(&[image_features, text_features], 1)?;
+ self.pos = input_embeds.dim(1)?;
+ self.language_model.forward_embeds(&input_embeds, None, 0)
+ }
+
+ pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
+ let pos = self.pos;
+ let seq_len = input_ids.dim(1)?;
+ self.pos = pos + seq_len;
+ self.language_model.forward(input_ids, pos)
+ }
+
+ pub fn clear_kv_cache(&mut self) {
+ self.pos = 0;
+ self.language_model.clear_kv_cache()
+ }
+}