author     Amélie Royer <amelie.royer@ens-rennes.fr>  2024-12-23 13:22:35 +0100
committer  GitHub <noreply@github.com>                2024-12-23 13:22:35 +0100
commit     1be6b090c7920c35f5492845d219e3a99ce4d115 (patch)
tree       c9afb051b9850bf386aa2670d50b30cd6ef48c6f
parent     62ced44ea94da7062430ed6c21ff17b36f41737d (diff)
Fix position encodings for Pixtral (#2678)
* init commit: add position id in meshgrid
* pass in subsampled positions
* clippy fix
* clippy fix
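For reference, a minimal plain-Rust sketch (not part of the commit; the helper name is illustrative) of the position ids the new meshgrid produces: each patch at row r, column c of the patch grid maps to r * max_image_width + c, with max_image_width = image_size / patch_size, so a smaller image selects the matching subset of the precomputed rotary table rather than simply the first rows.

fn position_ids(num_patches_h: usize, num_patches_w: usize, max_image_width: u32) -> Vec<u32> {
    // Same arithmetic as the meshgrid in `position_ids_in_meshgrid` below:
    // mesh[0] * max_image_width + mesh[1], flattened in row-major order.
    let mut ids = Vec::with_capacity(num_patches_h * num_patches_w);
    for row in 0..num_patches_h as u32 {
        for col in 0..num_patches_w as u32 {
            ids.push(row * max_image_width + col);
        }
    }
    ids
}

For a 2x3 patch grid with max_image_width = 64 this yields [0, 1, 2, 64, 65, 66], i.e. the top-left corner of the full 64-wide position grid.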
-rw-r--r--  candle-transformers/src/models/pixtral/vision_model.rs  68
1 file changed, 55 insertions(+), 13 deletions(-)
diff --git a/candle-transformers/src/models/pixtral/vision_model.rs b/candle-transformers/src/models/pixtral/vision_model.rs
index 20d8f082..3f884aaf 100644
--- a/candle-transformers/src/models/pixtral/vision_model.rs
+++ b/candle-transformers/src/models/pixtral/vision_model.rs
@@ -1,8 +1,8 @@
-use candle::{DType, Module, Result, Tensor, D};
+use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{linear_b, rms_norm, Linear, RmsNorm, VarBuilder};
fn default_act() -> candle_nn::Activation {
- candle_nn::Activation::Gelu
+ candle_nn::Activation::Silu
}
fn default_hidden_size() -> usize {
@@ -58,7 +58,7 @@ impl Config {
num_attention_heads: 16,
head_dim: None,
// Default
- hidden_act: candle_nn::Activation::Gelu,
+ hidden_act: candle_nn::Activation::Silu,
}
}
@@ -104,6 +104,7 @@ impl Attention {
&self,
xs: &Tensor,
emb: &RotaryEmbedding,
+ subsampled_positions: Option<&Tensor>,
attention_mask: Option<&Tensor>,
) -> Result<Tensor> {
let (b, patches, _) = xs.dims3()?;
@@ -116,7 +117,8 @@ impl Attention {
let key_states = key_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
let value_states = value_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
- let (query_states, key_states) = emb.apply_rotary_emb_qkv(&query_states, &key_states)?;
+ let (query_states, key_states) =
+ emb.apply_rotary_emb_qkv(&query_states, &key_states, subsampled_positions)?;
let attn_weights = (query_states.matmul(&key_states.t()?)? * self.scale)?;
let attn_weights = match attention_mask {
@@ -189,12 +191,16 @@ impl AttentionLayer {
&self,
xs: &Tensor,
emb: &RotaryEmbedding,
+ subsampled_positions: Option<&Tensor>,
attention_mask: Option<&Tensor>,
) -> Result<Tensor> {
let residual = xs;
- let xs = self
- .attention
- .forward(&xs.apply(&self.attention_norm)?, emb, attention_mask)?;
+ let xs = self.attention.forward(
+ &xs.apply(&self.attention_norm)?,
+ emb,
+ subsampled_positions,
+ attention_mask,
+ )?;
let xs = (residual + xs)?;
let residual = &xs;
let xs = xs.apply(&self.ffn_norm)?.apply(&self.feed_forward)?;
@@ -222,11 +228,12 @@ impl Transformer {
&self,
xs: &Tensor,
emb: &RotaryEmbedding,
+ subsampled_positions: Option<&Tensor>,
attention_mask: Option<&Tensor>,
) -> Result<Tensor> {
let mut xs = xs.clone();
for layer in self.layers.iter() {
- xs = layer.forward(&xs, emb, attention_mask)?
+ xs = layer.forward(&xs, emb, subsampled_positions, attention_mask)?
}
Ok(xs)
}
@@ -270,10 +277,20 @@ impl RotaryEmbedding {
Ok(Self { cos, sin })
}
- fn apply_rotary_emb_qkv(&self, q: &Tensor, k: &Tensor) -> Result<(Tensor, Tensor)> {
+ fn apply_rotary_emb_qkv(
+ &self,
+ q: &Tensor,
+ k: &Tensor,
+ subsampled_positions: Option<&Tensor>,
+ ) -> Result<(Tensor, Tensor)> {
let (_b_sz, _h, _seq_len, _n_embd) = q.dims4()?;
- let cos = &self.cos;
- let sin = &self.sin;
+ let (cos, sin) = match subsampled_positions {
+ None => (&self.cos, &self.sin),
+ Some(pos) => (
+ &self.cos.index_select(pos, 0)?,
+ &self.sin.index_select(pos, 0)?,
+ ),
+ };
let q_embed = candle_nn::rotary_emb::rope(q, cos, sin)?;
let k_embed = candle_nn::rotary_emb::rope(k, cos, sin)?;
Ok((q_embed, k_embed))
@@ -286,6 +303,7 @@ pub struct Model {
ln_pre: RmsNorm,
transformer: Transformer,
patch_positional_embedding: RotaryEmbedding,
+ max_image_width: u32,
}
impl Model {
@@ -305,20 +323,44 @@ impl Model {
let transformer = Transformer::new(cfg, vb.pp("transformer"))?;
let patch_positional_embedding =
RotaryEmbedding::new(cfg, vb.pp("patch_positional_embedding"))?;
+ let max_image_width = (cfg.image_size / cfg.patch_size) as u32;
Ok(Self {
patch_conv,
ln_pre,
transformer,
patch_positional_embedding,
+ max_image_width,
})
}
+
+ pub fn position_ids_in_meshgrid(
+ &self,
+ num_patches_h: usize,
+ num_patches_w: usize,
+ device: &Device,
+ ) -> Result<Tensor> {
+ let idx = Tensor::arange(0, num_patches_h as u32, device)?;
+ let idy = Tensor::arange(0, num_patches_w as u32, device)?;
+ let mesh = Tensor::meshgrid(&[idx, idy], false)?;
+ let ids = (&mesh[0] * (self.max_image_width as f64) + &mesh[1])?.flatten_all()?;
+ Ok(ids)
+ }
}
impl Module for Model {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let patch_embeds = xs.apply(&self.patch_conv)?;
+ let subsampled_positions = Some(self.position_ids_in_meshgrid(
+ patch_embeds.dim(2)?,
+ patch_embeds.dim(3)?,
+ patch_embeds.device(),
+ )?);
let patch_embeds = patch_embeds.flatten_from(2)?.t()?.apply(&self.ln_pre)?;
- self.transformer
- .forward(&patch_embeds, &self.patch_positional_embedding, None)
+ self.transformer.forward(
+ &patch_embeds,
+ &self.patch_positional_embedding,
+ subsampled_positions.as_ref(),
+ None,
+ )
}
}
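Likewise, a minimal sketch (illustrative only, with plain slices standing in for the 2D cos/sin tensors and gather_rows standing in for the index_select call in apply_rotary_emb_qkv) of what passing subsampled positions changes: the rows at the patch-grid positions are gathered before rope is applied, instead of implicitly using the first seq_len rows of the precomputed table.

fn gather_rows(table: &[Vec<f32>], positions: &[u32]) -> Vec<Vec<f32>> {
    // Pick one precomputed cos/sin row per patch position, in the order
    // given by the meshgrid ids (row * max_image_width + col).
    positions.iter().map(|&p| table[p as usize].clone()).collect()
}

With the meshgrid ids above, each gathered row corresponds to the patch's 2D location in the full image_size / patch_size grid rather than to its flattened order, which is the position-encoding fix the commit title refers to.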