summaryrefslogtreecommitdiff
path: root/candle-transformers
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-09-30 19:31:14 +0200
committerGitHub <noreply@github.com>2024-09-30 19:31:14 +0200
commit683ab698def755c24cec9987069d25efcf831fc4 (patch)
tree84d0bd8ad2f5d7a00f67050c83520326d947b2fe /candle-transformers
parent2f49e1b5349f4e56677ec0d3dc3fe98f9cbb87c7 (diff)
downloadcandle-683ab698def755c24cec9987069d25efcf831fc4.tar.gz
candle-683ab698def755c24cec9987069d25efcf831fc4.tar.bz2
candle-683ab698def755c24cec9987069d25efcf831fc4.zip
Add Pixtral. (#2521)
* Add Pixtral. * More pixtral vision encoder. * Sketch a pixtral example. * Sketch a pixtral example. * Better image loading. * Support loading images embedded in safetensor files. * Clippy fixes. * Add the llava multimodal adapter. * Add more of the llava bits. * Add the pixtral config. * More pixtral inference. * Add the text generation bits. * Get the example to work. * Bugfix. * Run some bits of the model in f32. * Blessed version :) * Better rope frequency computations. * README update.
Diffstat (limited to 'candle-transformers')
-rw-r--r--candle-transformers/src/models/llava/mod.rs2
-rw-r--r--candle-transformers/src/models/mistral.rs38
-rw-r--r--candle-transformers/src/models/mod.rs1
-rw-r--r--candle-transformers/src/models/pixtral/llava.rs72
-rw-r--r--candle-transformers/src/models/pixtral/mod.rs4
-rw-r--r--candle-transformers/src/models/pixtral/vision_model.rs324
6 files changed, 436 insertions, 5 deletions
diff --git a/candle-transformers/src/models/llava/mod.rs b/candle-transformers/src/models/llava/mod.rs
index caa8737a..1ed3b50c 100644
--- a/candle-transformers/src/models/llava/mod.rs
+++ b/candle-transformers/src/models/llava/mod.rs
@@ -279,7 +279,7 @@ impl LLaVA {
(),
))?
} else {
- todo!("not implemented in original python LLaVA yet")
+ bail!("not implemented in original python LLaVA yet")
};
let new_image_feature = if mm_patch_merge_type.contains("unpad") {
let new_image_feature = new_image_feature
diff --git a/candle-transformers/src/models/mistral.rs b/candle-transformers/src/models/mistral.rs
index 7e3b21c9..e8f7a7c4 100644
--- a/candle-transformers/src/models/mistral.rs
+++ b/candle-transformers/src/models/mistral.rs
@@ -4,19 +4,29 @@ use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{Activation, VarBuilder};
use std::sync::Arc;
+fn default_num_attention_heads() -> usize {
+ 32
+}
+
fn default_use_flash_attn() -> bool {
false
}
+fn default_hidden_act() -> candle_nn::Activation {
+ candle_nn::Activation::Silu
+}
+
#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
pub struct Config {
pub vocab_size: usize,
pub hidden_size: usize,
pub intermediate_size: usize,
pub num_hidden_layers: usize,
+ #[serde(default = "default_num_attention_heads")]
pub num_attention_heads: usize,
pub head_dim: Option<usize>,
pub num_key_value_heads: usize,
+ #[serde(default = "default_hidden_act")]
pub hidden_act: Activation,
pub max_position_embeddings: usize,
pub rms_norm_eps: f64,
@@ -107,14 +117,14 @@ impl RotaryEmbedding {
.map(|i| 1f32 / rope_theta.powf(i as f32 / dim as f32))
.collect();
let inv_freq_len = inv_freq.len();
- let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
+ let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(DType::F32)?;
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
- .to_dtype(dtype)?
+ .to_dtype(DType::F32)?
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;
Ok(Self {
- sin: freqs.sin()?,
- cos: freqs.cos()?,
+ sin: freqs.sin()?.to_dtype(dtype)?,
+ cos: freqs.cos()?.to_dtype(dtype)?,
})
}
@@ -404,6 +414,10 @@ impl Model {
.to_dtype(self.dtype)
}
+ pub fn embed_tokens(&self) -> &candle_nn::Embedding {
+ &self.embed_tokens
+ }
+
pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
let (_b_size, seq_len) = input_ids.dims2()?;
let attention_mask = if seq_len <= 1 {
@@ -421,6 +435,22 @@ impl Model {
.apply(&self.lm_head)
}
+ pub fn forward_embeds(
+ &mut self,
+ xs: &Tensor,
+ attn_mask: Option<&Tensor>,
+ seqlen_offset: usize,
+ ) -> Result<Tensor> {
+ let (_b_size, seq_len, _) = xs.dims3()?;
+ let mut xs = xs.clone();
+ for layer in self.layers.iter_mut() {
+ xs = layer.forward(&xs, attn_mask, seqlen_offset)?
+ }
+ xs.narrow(1, seq_len - 1, 1)?
+ .apply(&self.norm)?
+ .apply(&self.lm_head)
+ }
+
pub fn clear_kv_cache(&mut self) {
for layer in self.layers.iter_mut() {
layer.clear_kv_cache()
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index bba701bd..09876503 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -51,6 +51,7 @@ pub mod parler_tts;
pub mod persimmon;
pub mod phi;
pub mod phi3;
+pub mod pixtral;
pub mod quantized_blip;
pub mod quantized_blip_text;
pub mod quantized_llama;
diff --git a/candle-transformers/src/models/pixtral/llava.rs b/candle-transformers/src/models/pixtral/llava.rs
new file mode 100644
index 00000000..33e0aca0
--- /dev/null
+++ b/candle-transformers/src/models/pixtral/llava.rs
@@ -0,0 +1,72 @@
+use candle::{Module, Result, Tensor};
+use candle_nn::{linear, Linear, VarBuilder};
+
+use super::vision_model;
+use crate::models::mistral;
+
+#[derive(serde::Deserialize, Debug, Clone)]
+pub struct Config {
+ pub projector_hidden_act: candle_nn::Activation,
+ pub text_config: mistral::Config,
+ pub vision_config: vision_model::Config,
+ pub image_token_index: usize,
+ pub image_seq_length: usize,
+}
+
+#[derive(Debug, Clone)]
+pub struct MultiModalProjector {
+ linear_1: Linear,
+ act: candle_nn::Activation,
+ linear_2: Linear,
+}
+
+impl MultiModalProjector {
+ pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let (hidden_v, hidden_t) = (cfg.vision_config.hidden_size, cfg.text_config.hidden_size);
+ let linear_1 = linear(hidden_v, hidden_t, vb.pp("linear_1"))?;
+ let linear_2 = linear(hidden_t, hidden_t, vb.pp("linear_2"))?;
+ Ok(Self {
+ linear_1,
+ act: cfg.projector_hidden_act,
+ linear_2,
+ })
+ }
+}
+
+impl Module for MultiModalProjector {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ xs.apply(&self.linear_1)?
+ .apply(&self.act)?
+ .apply(&self.linear_2)
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct Model {
+ pub multi_modal_projector: MultiModalProjector,
+ pub language_model: mistral::Model,
+ pub vision_tower: vision_model::Model,
+ pub patch_size: usize,
+ pub dtype: candle::DType,
+}
+
+impl Model {
+ pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let language_model = mistral::Model::new(&cfg.text_config, vb.pp("language_model"))?;
+ let vision_tower = vision_model::Model::new(
+ &cfg.vision_config,
+ vb.pp("vision_tower").to_dtype(candle::DType::F32),
+ )?;
+ let multi_modal_projector = MultiModalProjector::new(
+ cfg,
+ vb.pp("multi_modal_projector").to_dtype(candle::DType::F32),
+ )?;
+ Ok(Self {
+ multi_modal_projector,
+ language_model,
+ vision_tower,
+ patch_size: cfg.vision_config.patch_size,
+ dtype: vb.dtype(),
+ })
+ }
+}
diff --git a/candle-transformers/src/models/pixtral/mod.rs b/candle-transformers/src/models/pixtral/mod.rs
new file mode 100644
index 00000000..9d0eccfb
--- /dev/null
+++ b/candle-transformers/src/models/pixtral/mod.rs
@@ -0,0 +1,4 @@
+pub mod llava;
+pub mod vision_model;
+
+pub use llava::{Config, Model};
diff --git a/candle-transformers/src/models/pixtral/vision_model.rs b/candle-transformers/src/models/pixtral/vision_model.rs
new file mode 100644
index 00000000..20d8f082
--- /dev/null
+++ b/candle-transformers/src/models/pixtral/vision_model.rs
@@ -0,0 +1,324 @@
+use candle::{DType, Module, Result, Tensor, D};
+use candle_nn::{linear_b, rms_norm, Linear, RmsNorm, VarBuilder};
+
+fn default_act() -> candle_nn::Activation {
+ candle_nn::Activation::Gelu
+}
+
+fn default_hidden_size() -> usize {
+ 1024
+}
+
+fn default_intermediate_size() -> usize {
+ 4096
+}
+
+fn default_num_channels() -> usize {
+ 3
+}
+
+fn default_num_hidden_layers() -> usize {
+ 24
+}
+
+fn default_num_attention_heads() -> usize {
+ 16
+}
+
+#[derive(serde::Deserialize, Debug, Clone)]
+pub struct Config {
+ #[serde(default = "default_hidden_size")]
+ pub hidden_size: usize,
+ #[serde(default = "default_num_channels")]
+ pub num_channels: usize,
+ pub image_size: usize,
+ pub patch_size: usize,
+ pub rope_theta: f64,
+ #[serde(default = "default_intermediate_size")]
+ pub intermediate_size: usize,
+ #[serde(default = "default_num_hidden_layers")]
+ pub num_hidden_layers: usize,
+ pub head_dim: Option<usize>,
+ #[serde(default = "default_num_attention_heads")]
+ pub num_attention_heads: usize,
+ #[serde(default = "default_act")]
+ pub hidden_act: candle_nn::Activation,
+}
+
+impl Config {
+ pub fn pixtral_12b_2409() -> Self {
+ Self {
+ hidden_size: 1024,
+ num_channels: 3,
+ image_size: 1024,
+ patch_size: 16,
+ rope_theta: 10000.0,
+ intermediate_size: 4096,
+ num_hidden_layers: 24,
+ num_attention_heads: 16,
+ head_dim: None,
+ // Default
+ hidden_act: candle_nn::Activation::Gelu,
+ }
+ }
+
+ fn head_dim(&self) -> usize {
+ self.head_dim
+ .unwrap_or(self.hidden_size / self.num_attention_heads)
+ }
+}
+
+#[derive(Debug, Clone)]
+struct Attention {
+ q_proj: Linear,
+ k_proj: Linear,
+ v_proj: Linear,
+ o_proj: Linear,
+ scale: f64,
+ num_heads: usize,
+ head_dim: usize,
+}
+
+impl Attention {
+ fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let h = cfg.hidden_size;
+ let num_heads = cfg.num_attention_heads;
+ let head_dim = cfg.head_dim();
+ let q_proj = linear_b(h, h, false, vb.pp("q_proj"))?;
+ let k_proj = linear_b(h, h, false, vb.pp("k_proj"))?;
+ let v_proj = linear_b(h, h, false, vb.pp("v_proj"))?;
+ let o_proj = linear_b(h, h, false, vb.pp("o_proj"))?;
+ let scale = (head_dim as f64).powf(-0.5);
+ Ok(Self {
+ q_proj,
+ k_proj,
+ v_proj,
+ o_proj,
+ scale,
+ num_heads,
+ head_dim,
+ })
+ }
+
+ fn forward(
+ &self,
+ xs: &Tensor,
+ emb: &RotaryEmbedding,
+ attention_mask: Option<&Tensor>,
+ ) -> Result<Tensor> {
+ let (b, patches, _) = xs.dims3()?;
+ let query_states = xs.apply(&self.q_proj)?;
+ let key_states = xs.apply(&self.k_proj)?;
+ let value_states = xs.apply(&self.v_proj)?;
+
+ let shape = (b, patches, self.num_heads, self.head_dim);
+ let query_states = query_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
+ let key_states = key_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
+ let value_states = value_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
+
+ let (query_states, key_states) = emb.apply_rotary_emb_qkv(&query_states, &key_states)?;
+ let attn_weights = (query_states.matmul(&key_states.t()?)? * self.scale)?;
+
+ let attn_weights = match attention_mask {
+ None => attn_weights,
+ Some(mask) => attn_weights.broadcast_add(mask)?,
+ };
+
+ let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
+ attn_weights
+ .matmul(&value_states)?
+ .transpose(1, 2)?
+ .reshape((b, patches, ()))?
+ .apply(&self.o_proj)
+ }
+}
+
+#[derive(Debug, Clone)]
+struct Mlp {
+ gate_proj: Linear,
+ up_proj: Linear,
+ down_proj: Linear,
+ act_fn: candle_nn::Activation,
+}
+
+impl Mlp {
+ fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let (h, i) = (cfg.hidden_size, cfg.intermediate_size);
+ let gate_proj = linear_b(h, i, false, vb.pp("gate_proj"))?;
+ let up_proj = linear_b(h, i, false, vb.pp("up_proj"))?;
+ let down_proj = linear_b(i, h, false, vb.pp("down_proj"))?;
+ Ok(Self {
+ gate_proj,
+ up_proj,
+ down_proj,
+ act_fn: cfg.hidden_act,
+ })
+ }
+}
+
+impl Module for Mlp {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ (xs.apply(&self.gate_proj)?.apply(&self.act_fn)? * xs.apply(&self.up_proj))?
+ .apply(&self.down_proj)
+ }
+}
+
+#[derive(Debug, Clone)]
+struct AttentionLayer {
+ attention_norm: RmsNorm,
+ feed_forward: Mlp,
+ attention: Attention,
+ ffn_norm: RmsNorm,
+}
+
+impl AttentionLayer {
+ fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let attention_norm = rms_norm(cfg.hidden_size, 1e-5, vb.pp("attention_norm"))?;
+ let feed_forward = Mlp::new(cfg, vb.pp("feed_forward"))?;
+ let attention = Attention::new(cfg, vb.pp("attention"))?;
+ let ffn_norm = rms_norm(cfg.hidden_size, 1e-5, vb.pp("ffn_norm"))?;
+ Ok(Self {
+ attention_norm,
+ feed_forward,
+ attention,
+ ffn_norm,
+ })
+ }
+
+ fn forward(
+ &self,
+ xs: &Tensor,
+ emb: &RotaryEmbedding,
+ attention_mask: Option<&Tensor>,
+ ) -> Result<Tensor> {
+ let residual = xs;
+ let xs = self
+ .attention
+ .forward(&xs.apply(&self.attention_norm)?, emb, attention_mask)?;
+ let xs = (residual + xs)?;
+ let residual = &xs;
+ let xs = xs.apply(&self.ffn_norm)?.apply(&self.feed_forward)?;
+ xs + residual
+ }
+}
+
+#[derive(Debug, Clone)]
+struct Transformer {
+ layers: Vec<AttentionLayer>,
+}
+
+impl Transformer {
+ fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
+ let vb = vb.pp("layers");
+ for layer_idx in 0..cfg.num_hidden_layers {
+ let layer = AttentionLayer::new(cfg, vb.pp(layer_idx))?;
+ layers.push(layer)
+ }
+ Ok(Self { layers })
+ }
+
+ fn forward(
+ &self,
+ xs: &Tensor,
+ emb: &RotaryEmbedding,
+ attention_mask: Option<&Tensor>,
+ ) -> Result<Tensor> {
+ let mut xs = xs.clone();
+ for layer in self.layers.iter() {
+ xs = layer.forward(&xs, emb, attention_mask)?
+ }
+ Ok(xs)
+ }
+}
+
+#[derive(Debug, Clone)]
+struct RotaryEmbedding {
+ cos: Tensor,
+ sin: Tensor,
+}
+
+impl RotaryEmbedding {
+ fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let dtype = vb.dtype();
+ let dev = vb.device();
+ let dim = cfg.head_dim();
+ let rope_theta = cfg.rope_theta as f32;
+ let max_patches_per_side = cfg.image_size / cfg.patch_size;
+ let freqs: Vec<_> = (0..dim)
+ .step_by(2)
+ .map(|i| 1f32 / rope_theta.powf(i as f32 / dim as f32))
+ .collect();
+ let freqs_h = freqs.iter().step_by(2).copied().collect::<Vec<_>>();
+ let freqs_h = Tensor::new(freqs_h, dev)?;
+ let freqs_w = freqs.iter().skip(1).step_by(2).copied().collect::<Vec<_>>();
+ let freqs_w = Tensor::new(freqs_w, dev)?;
+ let h = Tensor::arange(0u32, max_patches_per_side as u32, dev)?.to_dtype(DType::F32)?;
+ let w = Tensor::arange(0u32, max_patches_per_side as u32, dev)?.to_dtype(DType::F32)?;
+ let freqs_h = h.unsqueeze(1)?.matmul(&freqs_h.unsqueeze(0)?)?;
+ let freqs_w = w.unsqueeze(1)?.matmul(&freqs_w.unsqueeze(0)?)?;
+ let inv_freq = Tensor::cat(
+ &[
+ freqs_h.unsqueeze(1)?.repeat((1, max_patches_per_side, 1))?,
+ freqs_w.unsqueeze(0)?.repeat((max_patches_per_side, 1, 1))?,
+ ],
+ D::Minus1,
+ )?
+ .reshape(((), dim / 2))?;
+ let cos = inv_freq.cos()?.to_dtype(dtype)?;
+ let sin = inv_freq.sin()?.to_dtype(dtype)?;
+ Ok(Self { cos, sin })
+ }
+
+ fn apply_rotary_emb_qkv(&self, q: &Tensor, k: &Tensor) -> Result<(Tensor, Tensor)> {
+ let (_b_sz, _h, _seq_len, _n_embd) = q.dims4()?;
+ let cos = &self.cos;
+ let sin = &self.sin;
+ let q_embed = candle_nn::rotary_emb::rope(q, cos, sin)?;
+ let k_embed = candle_nn::rotary_emb::rope(k, cos, sin)?;
+ Ok((q_embed, k_embed))
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct Model {
+ patch_conv: candle_nn::Conv2d,
+ ln_pre: RmsNorm,
+ transformer: Transformer,
+ patch_positional_embedding: RotaryEmbedding,
+}
+
+impl Model {
+ pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let conv2d_cfg = candle_nn::Conv2dConfig {
+ stride: cfg.patch_size,
+ ..Default::default()
+ };
+ let patch_conv = candle_nn::conv2d_no_bias(
+ cfg.num_channels,
+ cfg.hidden_size,
+ cfg.patch_size,
+ conv2d_cfg,
+ vb.pp("patch_conv"),
+ )?;
+ let ln_pre = candle_nn::rms_norm(cfg.hidden_size, 1e-5, vb.pp("ln_pre"))?;
+ let transformer = Transformer::new(cfg, vb.pp("transformer"))?;
+ let patch_positional_embedding =
+ RotaryEmbedding::new(cfg, vb.pp("patch_positional_embedding"))?;
+ Ok(Self {
+ patch_conv,
+ ln_pre,
+ transformer,
+ patch_positional_embedding,
+ })
+ }
+}
+
+impl Module for Model {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let patch_embeds = xs.apply(&self.patch_conv)?;
+ let patch_embeds = patch_embeds.flatten_from(2)?.t()?.apply(&self.ln_pre)?;
+ self.transformer
+ .forward(&patch_embeds, &self.patch_positional_embedding, None)
+ }
+}