author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-09-30 19:31:14 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-09-30 19:31:14 +0200 |
commit | 683ab698def755c24cec9987069d25efcf831fc4 (patch) | |
tree | 84d0bd8ad2f5d7a00f67050c83520326d947b2fe /candle-transformers | |
parent | 2f49e1b5349f4e56677ec0d3dc3fe98f9cbb87c7 (diff) | |
Add Pixtral. (#2521)
* Add Pixtral.
* More pixtral vision encoder.
* Sketch a pixtral example.
* Sketch a pixtral example.
* Better image loading.
* Support loading images embedded in safetensor files.
* Clippy fixes.
* Add the llava multimodal adapter.
* Add more of the llava bits.
* Add the pixtral config.
* More pixtral inference.
* Add the text generation bits.
* Get the example to work.
* Bugfix.
* Run some bits of the model in f32.
* Blessed version :)
* Better rope frequency computations.
* README update.
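The bullet list above describes wiring the new Pixtral vision encoder into a LLaVA-style adapter on top of the existing Mistral text model: the image goes through the vision tower, is projected into the text embedding space, and is then decoded together with the prompt tokens. The sketch below shows roughly how the pieces added in this commit fit together for a prefill step. It is not code from the commit: the function name `prefill`, the variable names, and the simplified image-then-text embedding layout are illustrative assumptions (the real example also builds a causal attention mask and places image embeddings at the [IMG] token positions).

```rust
use candle::{DType, Module, Result, Tensor};
use candle_transformers::models::pixtral::Model;

// Illustrative sketch only: combine image and text embeddings and run one
// decoding step with the APIs added in this diff.
fn prefill(model: &mut Model, image: &Tensor, token_ids: &Tensor) -> Result<Tensor> {
    // The vision tower and projector are loaded in f32 (see `Model::new` in
    // pixtral/llava.rs), so the image is fed as f32 as well.
    let image_embeds = model.vision_tower.forward(&image.to_dtype(DType::F32)?)?;
    let image_embeds = model
        .multi_modal_projector
        .forward(&image_embeds)?
        .to_dtype(model.dtype)?;
    // Text tokens are embedded with the Mistral table exposed by the new
    // `embed_tokens()` accessor.
    let text_embeds = model.language_model.embed_tokens().forward(token_ids)?;
    // Concatenate along the sequence dimension and decode raw embeddings via
    // the new `forward_embeds` entry point. A real caller would pass a causal
    // attention mask instead of `None`.
    let input_embeds = Tensor::cat(&[&image_embeds, &text_embeds], 1)?;
    model.language_model.forward_embeds(&input_embeds, None, 0)
}
```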
Diffstat (limited to 'candle-transformers')

| | File | Lines |
|---|---|---|
| -rw-r--r-- | candle-transformers/src/models/llava/mod.rs | 2 |
| -rw-r--r-- | candle-transformers/src/models/mistral.rs | 38 |
| -rw-r--r-- | candle-transformers/src/models/mod.rs | 1 |
| -rw-r--r-- | candle-transformers/src/models/pixtral/llava.rs | 72 |
| -rw-r--r-- | candle-transformers/src/models/pixtral/mod.rs | 4 |
| -rw-r--r-- | candle-transformers/src/models/pixtral/vision_model.rs | 324 |

6 files changed, 436 insertions(+), 5 deletions(-)
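One detail worth calling out before the diff: the vision encoder uses a 2D rotary embedding ("Better rope frequency computations" above). Inverse frequencies are computed as in standard RoPE, then alternately assigned to the height and width axes and expanded over all patch positions. The standalone sketch below mirrors the math in `RotaryEmbedding::new` from `vision_model.rs`; the helper name `pixtral_2d_rope_freqs` is illustrative and the code has no candle dependency.

```rust
// Standalone sketch of the frequency table built by `RotaryEmbedding::new`
// in vision_model.rs below; the function name is illustrative.
fn pixtral_2d_rope_freqs(dim: usize, rope_theta: f32, patches_per_side: usize) -> Vec<Vec<f32>> {
    // Standard RoPE inverse frequencies: theta^(-i/dim) for even i in 0..dim.
    let freqs: Vec<f32> = (0..dim)
        .step_by(2)
        .map(|i| 1f32 / rope_theta.powf(i as f32 / dim as f32))
        .collect();
    // Even-indexed frequencies rotate with the patch row (height), odd-indexed
    // ones with the patch column (width).
    let freqs_h: Vec<f32> = freqs.iter().step_by(2).copied().collect();
    let freqs_w: Vec<f32> = freqs.iter().skip(1).step_by(2).copied().collect();
    let mut table = Vec::with_capacity(patches_per_side * patches_per_side);
    for h in 0..patches_per_side {
        for w in 0..patches_per_side {
            // dim/2 angles per patch: row position times the "h" frequencies,
            // followed by column position times the "w" frequencies.
            let mut row: Vec<f32> = freqs_h.iter().map(|f| h as f32 * f).collect();
            row.extend(freqs_w.iter().map(|f| w as f32 * f));
            table.push(row);
        }
    }
    table
}

fn main() {
    // With `Config::pixtral_12b_2409`: head_dim = 1024 / 16 = 64 and
    // image_size / patch_size = 1024 / 16 = 64 patches per side.
    let table = pixtral_2d_rope_freqs(64, 10000.0, 64);
    assert_eq!(table.len(), 64 * 64); // 4096 patch positions
    assert_eq!(table[0].len(), 32); // dim / 2 angles per position
}
```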
diff --git a/candle-transformers/src/models/llava/mod.rs b/candle-transformers/src/models/llava/mod.rs
index caa8737a..1ed3b50c 100644
--- a/candle-transformers/src/models/llava/mod.rs
+++ b/candle-transformers/src/models/llava/mod.rs
@@ -279,7 +279,7 @@ impl LLaVA {
                         (),
                     ))?
                 } else {
-                    todo!("not implemented in original python LLaVA yet")
+                    bail!("not implemented in original python LLaVA yet")
                 };
                 let new_image_feature = if mm_patch_merge_type.contains("unpad") {
                     let new_image_feature = new_image_feature
diff --git a/candle-transformers/src/models/mistral.rs b/candle-transformers/src/models/mistral.rs
index 7e3b21c9..e8f7a7c4 100644
--- a/candle-transformers/src/models/mistral.rs
+++ b/candle-transformers/src/models/mistral.rs
@@ -4,19 +4,29 @@ use candle::{DType, Device, Module, Result, Tensor, D};
 use candle_nn::{Activation, VarBuilder};
 use std::sync::Arc;
 
+fn default_num_attention_heads() -> usize {
+    32
+}
+
 fn default_use_flash_attn() -> bool {
     false
 }
 
+fn default_hidden_act() -> candle_nn::Activation {
+    candle_nn::Activation::Silu
+}
+
 #[derive(Debug, Clone, PartialEq, serde::Deserialize)]
 pub struct Config {
     pub vocab_size: usize,
     pub hidden_size: usize,
     pub intermediate_size: usize,
     pub num_hidden_layers: usize,
+    #[serde(default = "default_num_attention_heads")]
     pub num_attention_heads: usize,
     pub head_dim: Option<usize>,
     pub num_key_value_heads: usize,
+    #[serde(default = "default_hidden_act")]
     pub hidden_act: Activation,
     pub max_position_embeddings: usize,
     pub rms_norm_eps: f64,
@@ -107,14 +117,14 @@ impl RotaryEmbedding {
             .map(|i| 1f32 / rope_theta.powf(i as f32 / dim as f32))
             .collect();
         let inv_freq_len = inv_freq.len();
-        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
+        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(DType::F32)?;
         let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
-            .to_dtype(dtype)?
+            .to_dtype(DType::F32)?
             .reshape((max_seq_len, 1))?;
         let freqs = t.matmul(&inv_freq)?;
         Ok(Self {
-            sin: freqs.sin()?,
-            cos: freqs.cos()?,
+            sin: freqs.sin()?.to_dtype(dtype)?,
+            cos: freqs.cos()?.to_dtype(dtype)?,
         })
     }
 
@@ -404,6 +414,10 @@ impl Model {
             .to_dtype(self.dtype)
     }
 
+    pub fn embed_tokens(&self) -> &candle_nn::Embedding {
+        &self.embed_tokens
+    }
+
     pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
         let (_b_size, seq_len) = input_ids.dims2()?;
         let attention_mask = if seq_len <= 1 {
@@ -421,6 +435,22 @@ impl Model {
             .apply(&self.lm_head)
     }
 
+    pub fn forward_embeds(
+        &mut self,
+        xs: &Tensor,
+        attn_mask: Option<&Tensor>,
+        seqlen_offset: usize,
+    ) -> Result<Tensor> {
+        let (_b_size, seq_len, _) = xs.dims3()?;
+        let mut xs = xs.clone();
+        for layer in self.layers.iter_mut() {
+            xs = layer.forward(&xs, attn_mask, seqlen_offset)?
+        }
+        xs.narrow(1, seq_len - 1, 1)?
+            .apply(&self.norm)?
+            .apply(&self.lm_head)
+    }
+
     pub fn clear_kv_cache(&mut self) {
         for layer in self.layers.iter_mut() {
             layer.clear_kv_cache()
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index bba701bd..09876503 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -51,6 +51,7 @@ pub mod parler_tts;
 pub mod persimmon;
 pub mod phi;
 pub mod phi3;
+pub mod pixtral;
 pub mod quantized_blip;
 pub mod quantized_blip_text;
 pub mod quantized_llama;
diff --git a/candle-transformers/src/models/pixtral/llava.rs b/candle-transformers/src/models/pixtral/llava.rs
new file mode 100644
index 00000000..33e0aca0
--- /dev/null
+++ b/candle-transformers/src/models/pixtral/llava.rs
@@ -0,0 +1,72 @@
+use candle::{Module, Result, Tensor};
+use candle_nn::{linear, Linear, VarBuilder};
+
+use super::vision_model;
+use crate::models::mistral;
+
+#[derive(serde::Deserialize, Debug, Clone)]
+pub struct Config {
+    pub projector_hidden_act: candle_nn::Activation,
+    pub text_config: mistral::Config,
+    pub vision_config: vision_model::Config,
+    pub image_token_index: usize,
+    pub image_seq_length: usize,
+}
+
+#[derive(Debug, Clone)]
+pub struct MultiModalProjector {
+    linear_1: Linear,
+    act: candle_nn::Activation,
+    linear_2: Linear,
+}
+
+impl MultiModalProjector {
+    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let (hidden_v, hidden_t) = (cfg.vision_config.hidden_size, cfg.text_config.hidden_size);
+        let linear_1 = linear(hidden_v, hidden_t, vb.pp("linear_1"))?;
+        let linear_2 = linear(hidden_t, hidden_t, vb.pp("linear_2"))?;
+        Ok(Self {
+            linear_1,
+            act: cfg.projector_hidden_act,
+            linear_2,
+        })
+    }
+}
+
+impl Module for MultiModalProjector {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        xs.apply(&self.linear_1)?
+            .apply(&self.act)?
+            .apply(&self.linear_2)
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct Model {
+    pub multi_modal_projector: MultiModalProjector,
+    pub language_model: mistral::Model,
+    pub vision_tower: vision_model::Model,
+    pub patch_size: usize,
+    pub dtype: candle::DType,
+}
+
+impl Model {
+    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let language_model = mistral::Model::new(&cfg.text_config, vb.pp("language_model"))?;
+        let vision_tower = vision_model::Model::new(
+            &cfg.vision_config,
+            vb.pp("vision_tower").to_dtype(candle::DType::F32),
+        )?;
+        let multi_modal_projector = MultiModalProjector::new(
+            cfg,
+            vb.pp("multi_modal_projector").to_dtype(candle::DType::F32),
+        )?;
+        Ok(Self {
+            multi_modal_projector,
+            language_model,
+            vision_tower,
+            patch_size: cfg.vision_config.patch_size,
+            dtype: vb.dtype(),
+        })
+    }
+}
diff --git a/candle-transformers/src/models/pixtral/mod.rs b/candle-transformers/src/models/pixtral/mod.rs
new file mode 100644
index 00000000..9d0eccfb
--- /dev/null
+++ b/candle-transformers/src/models/pixtral/mod.rs
@@ -0,0 +1,4 @@
+pub mod llava;
+pub mod vision_model;
+
+pub use llava::{Config, Model};
diff --git a/candle-transformers/src/models/pixtral/vision_model.rs b/candle-transformers/src/models/pixtral/vision_model.rs
new file mode 100644
index 00000000..20d8f082
--- /dev/null
+++ b/candle-transformers/src/models/pixtral/vision_model.rs
@@ -0,0 +1,324 @@
+use candle::{DType, Module, Result, Tensor, D};
+use candle_nn::{linear_b, rms_norm, Linear, RmsNorm, VarBuilder};
+
+fn default_act() -> candle_nn::Activation {
+    candle_nn::Activation::Gelu
+}
+
+fn default_hidden_size() -> usize {
+    1024
+}
+
+fn default_intermediate_size() -> usize {
+    4096
+}
+
+fn default_num_channels() -> usize {
+    3
+}
+
+fn default_num_hidden_layers() -> usize {
+    24
+}
+
+fn default_num_attention_heads() -> usize {
+    16
+}
+
+#[derive(serde::Deserialize, Debug, Clone)]
+pub struct Config {
+    #[serde(default = "default_hidden_size")]
+    pub hidden_size: usize,
+    #[serde(default = "default_num_channels")]
+    pub num_channels: usize,
+    pub image_size: usize,
+    pub patch_size: usize,
+    pub rope_theta: f64,
+    #[serde(default = "default_intermediate_size")]
+    pub intermediate_size: usize,
+    #[serde(default = "default_num_hidden_layers")]
+    pub num_hidden_layers: usize,
+    pub head_dim: Option<usize>,
+    #[serde(default = "default_num_attention_heads")]
+    pub num_attention_heads: usize,
+    #[serde(default = "default_act")]
+    pub hidden_act: candle_nn::Activation,
+}
+
+impl Config {
+    pub fn pixtral_12b_2409() -> Self {
+        Self {
+            hidden_size: 1024,
+            num_channels: 3,
+            image_size: 1024,
+            patch_size: 16,
+            rope_theta: 10000.0,
+            intermediate_size: 4096,
+            num_hidden_layers: 24,
+            num_attention_heads: 16,
+            head_dim: None,
+            // Default
+            hidden_act: candle_nn::Activation::Gelu,
+        }
+    }
+
+    fn head_dim(&self) -> usize {
+        self.head_dim
+            .unwrap_or(self.hidden_size / self.num_attention_heads)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct Attention {
+    q_proj: Linear,
+    k_proj: Linear,
+    v_proj: Linear,
+    o_proj: Linear,
+    scale: f64,
+    num_heads: usize,
+    head_dim: usize,
+}
+
+impl Attention {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let h = cfg.hidden_size;
+        let num_heads = cfg.num_attention_heads;
+        let head_dim = cfg.head_dim();
+        let q_proj = linear_b(h, h, false, vb.pp("q_proj"))?;
+        let k_proj = linear_b(h, h, false, vb.pp("k_proj"))?;
+        let v_proj = linear_b(h, h, false, vb.pp("v_proj"))?;
+        let o_proj = linear_b(h, h, false, vb.pp("o_proj"))?;
+        let scale = (head_dim as f64).powf(-0.5);
+        Ok(Self {
+            q_proj,
+            k_proj,
+            v_proj,
+            o_proj,
+            scale,
+            num_heads,
+            head_dim,
+        })
+    }
+
+    fn forward(
+        &self,
+        xs: &Tensor,
+        emb: &RotaryEmbedding,
+        attention_mask: Option<&Tensor>,
+    ) -> Result<Tensor> {
+        let (b, patches, _) = xs.dims3()?;
+        let query_states = xs.apply(&self.q_proj)?;
+        let key_states = xs.apply(&self.k_proj)?;
+        let value_states = xs.apply(&self.v_proj)?;
+
+        let shape = (b, patches, self.num_heads, self.head_dim);
+        let query_states = query_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
+        let key_states = key_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
+        let value_states = value_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
+
+        let (query_states, key_states) = emb.apply_rotary_emb_qkv(&query_states, &key_states)?;
+        let attn_weights = (query_states.matmul(&key_states.t()?)? * self.scale)?;
+
+        let attn_weights = match attention_mask {
+            None => attn_weights,
+            Some(mask) => attn_weights.broadcast_add(mask)?,
+        };
+
+        let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
+        attn_weights
+            .matmul(&value_states)?
+            .transpose(1, 2)?
+            .reshape((b, patches, ()))?
+            .apply(&self.o_proj)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct Mlp {
+    gate_proj: Linear,
+    up_proj: Linear,
+    down_proj: Linear,
+    act_fn: candle_nn::Activation,
+}
+
+impl Mlp {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let (h, i) = (cfg.hidden_size, cfg.intermediate_size);
+        let gate_proj = linear_b(h, i, false, vb.pp("gate_proj"))?;
+        let up_proj = linear_b(h, i, false, vb.pp("up_proj"))?;
+        let down_proj = linear_b(i, h, false, vb.pp("down_proj"))?;
+        Ok(Self {
+            gate_proj,
+            up_proj,
+            down_proj,
+            act_fn: cfg.hidden_act,
+        })
+    }
+}
+
+impl Module for Mlp {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        (xs.apply(&self.gate_proj)?.apply(&self.act_fn)? * xs.apply(&self.up_proj))?
+            .apply(&self.down_proj)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct AttentionLayer {
+    attention_norm: RmsNorm,
+    feed_forward: Mlp,
+    attention: Attention,
+    ffn_norm: RmsNorm,
+}
+
+impl AttentionLayer {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let attention_norm = rms_norm(cfg.hidden_size, 1e-5, vb.pp("attention_norm"))?;
+        let feed_forward = Mlp::new(cfg, vb.pp("feed_forward"))?;
+        let attention = Attention::new(cfg, vb.pp("attention"))?;
+        let ffn_norm = rms_norm(cfg.hidden_size, 1e-5, vb.pp("ffn_norm"))?;
+        Ok(Self {
+            attention_norm,
+            feed_forward,
+            attention,
+            ffn_norm,
+        })
+    }
+
+    fn forward(
+        &self,
+        xs: &Tensor,
+        emb: &RotaryEmbedding,
+        attention_mask: Option<&Tensor>,
+    ) -> Result<Tensor> {
+        let residual = xs;
+        let xs = self
+            .attention
+            .forward(&xs.apply(&self.attention_norm)?, emb, attention_mask)?;
+        let xs = (residual + xs)?;
+        let residual = &xs;
+        let xs = xs.apply(&self.ffn_norm)?.apply(&self.feed_forward)?;
+        xs + residual
+    }
+}
+
+#[derive(Debug, Clone)]
+struct Transformer {
+    layers: Vec<AttentionLayer>,
+}
+
+impl Transformer {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
+        let vb = vb.pp("layers");
+        for layer_idx in 0..cfg.num_hidden_layers {
+            let layer = AttentionLayer::new(cfg, vb.pp(layer_idx))?;
+            layers.push(layer)
+        }
+        Ok(Self { layers })
+    }
+
+    fn forward(
+        &self,
+        xs: &Tensor,
+        emb: &RotaryEmbedding,
+        attention_mask: Option<&Tensor>,
+    ) -> Result<Tensor> {
+        let mut xs = xs.clone();
+        for layer in self.layers.iter() {
+            xs = layer.forward(&xs, emb, attention_mask)?
+        }
+        Ok(xs)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct RotaryEmbedding {
+    cos: Tensor,
+    sin: Tensor,
+}
+
+impl RotaryEmbedding {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let dtype = vb.dtype();
+        let dev = vb.device();
+        let dim = cfg.head_dim();
+        let rope_theta = cfg.rope_theta as f32;
+        let max_patches_per_side = cfg.image_size / cfg.patch_size;
+        let freqs: Vec<_> = (0..dim)
+            .step_by(2)
+            .map(|i| 1f32 / rope_theta.powf(i as f32 / dim as f32))
+            .collect();
+        let freqs_h = freqs.iter().step_by(2).copied().collect::<Vec<_>>();
+        let freqs_h = Tensor::new(freqs_h, dev)?;
+        let freqs_w = freqs.iter().skip(1).step_by(2).copied().collect::<Vec<_>>();
+        let freqs_w = Tensor::new(freqs_w, dev)?;
+        let h = Tensor::arange(0u32, max_patches_per_side as u32, dev)?.to_dtype(DType::F32)?;
+        let w = Tensor::arange(0u32, max_patches_per_side as u32, dev)?.to_dtype(DType::F32)?;
+        let freqs_h = h.unsqueeze(1)?.matmul(&freqs_h.unsqueeze(0)?)?;
+        let freqs_w = w.unsqueeze(1)?.matmul(&freqs_w.unsqueeze(0)?)?;
+        let inv_freq = Tensor::cat(
+            &[
+                freqs_h.unsqueeze(1)?.repeat((1, max_patches_per_side, 1))?,
+                freqs_w.unsqueeze(0)?.repeat((max_patches_per_side, 1, 1))?,
+            ],
+            D::Minus1,
+        )?
+        .reshape(((), dim / 2))?;
+        let cos = inv_freq.cos()?.to_dtype(dtype)?;
+        let sin = inv_freq.sin()?.to_dtype(dtype)?;
+        Ok(Self { cos, sin })
+    }
+
+    fn apply_rotary_emb_qkv(&self, q: &Tensor, k: &Tensor) -> Result<(Tensor, Tensor)> {
+        let (_b_sz, _h, _seq_len, _n_embd) = q.dims4()?;
+        let cos = &self.cos;
+        let sin = &self.sin;
+        let q_embed = candle_nn::rotary_emb::rope(q, cos, sin)?;
+        let k_embed = candle_nn::rotary_emb::rope(k, cos, sin)?;
+        Ok((q_embed, k_embed))
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct Model {
+    patch_conv: candle_nn::Conv2d,
+    ln_pre: RmsNorm,
+    transformer: Transformer,
+    patch_positional_embedding: RotaryEmbedding,
+}
+
+impl Model {
+    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let conv2d_cfg = candle_nn::Conv2dConfig {
+            stride: cfg.patch_size,
+            ..Default::default()
+        };
+        let patch_conv = candle_nn::conv2d_no_bias(
+            cfg.num_channels,
+            cfg.hidden_size,
+            cfg.patch_size,
+            conv2d_cfg,
+            vb.pp("patch_conv"),
+        )?;
+        let ln_pre = candle_nn::rms_norm(cfg.hidden_size, 1e-5, vb.pp("ln_pre"))?;
+        let transformer = Transformer::new(cfg, vb.pp("transformer"))?;
+        let patch_positional_embedding =
+            RotaryEmbedding::new(cfg, vb.pp("patch_positional_embedding"))?;
+        Ok(Self {
+            patch_conv,
+            ln_pre,
+            transformer,
+            patch_positional_embedding,
+        })
+    }
+}
+
+impl Module for Model {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let patch_embeds = xs.apply(&self.patch_conv)?;
+        let patch_embeds = patch_embeds.flatten_from(2)?.t()?.apply(&self.ln_pre)?;
+        self.transformer
+            .forward(&patch_embeds, &self.patch_positional_embedding, None)
+    }
+}
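Finally, a rough end-to-end loading sketch for the new module. Everything here is an assumption for illustration: the local file names `config.json` and `model.safetensors`, the `serde_json` and `anyhow` dependencies, and the all-zero dummy image. The actual example added alongside this commit (in candle-examples) handles tokenization, image preprocessing, and sampling.

```rust
use candle::{DType, Device, Module, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::pixtral::{Config, Model};

fn main() -> anyhow::Result<()> {
    let device = Device::Cpu;
    // Hypothetical local copies of the Hugging Face checkpoint files.
    let config: Config = serde_json::from_slice(&std::fs::read("config.json")?)?;
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["model.safetensors"], DType::F32, &device)?
    };
    let model = Model::new(&config, vb)?;

    // A dummy 1024x1024 RGB image; a real caller would load and normalize
    // actual pixels (or read an image embedded in a safetensors file).
    let image = Tensor::zeros((1, 3, 1024, 1024), DType::F32, &device)?;
    let image_embeds = model.vision_tower.forward(&image)?;
    let image_embeds = model
        .multi_modal_projector
        .forward(&image_embeds)?
        .to_dtype(model.dtype)?;
    println!("projected image embeddings: {:?}", image_embeds.dims());
    // These embeddings would then be combined with the prompt embeddings and
    // passed to `model.language_model.forward_embeds(...)` as sketched earlier.
    Ok(())
}
```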