diff options
author | Jani Monoses <jani.monoses@gmail.com> | 2024-08-29 16:38:58 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-08-29 15:38:58 +0200 |
commit | 86613c00e216750f32a326dbff5cc993d5e0067e (patch) | |
tree | 3b51fc0d3a1a930b57a86d6f1fd94041f68187cf /candle-transformers | |
parent | 29e25c458da869808502a88024f57b1c86efa090 (diff) | |
download | candle-86613c00e216750f32a326dbff5cc993d5e0067e.tar.gz candle-86613c00e216750f32a326dbff5cc993d5e0067e.tar.bz2 candle-86613c00e216750f32a326dbff5cc993d5e0067e.zip |
MobileCLIP models S1 and S2 (#2454)
* Allow loading images with given std and mean
* OpenCLIP text encoder component
* Two MobileCLIP models
* Clippy fixes.
---------
Co-authored-by: Laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-transformers')
-rw-r--r-- | candle-transformers/src/models/mobileclip.rs | 89 | ||||
-rw-r--r-- | candle-transformers/src/models/mod.rs | 2 | ||||
-rw-r--r-- | candle-transformers/src/models/openclip/mod.rs | 1 | ||||
-rw-r--r-- | candle-transformers/src/models/openclip/text_model.rs | 266 |
4 files changed, 358 insertions, 0 deletions
diff --git a/candle-transformers/src/models/mobileclip.rs b/candle-transformers/src/models/mobileclip.rs new file mode 100644 index 00000000..4953d835 --- /dev/null +++ b/candle-transformers/src/models/mobileclip.rs @@ -0,0 +1,89 @@ +use super::fastvit; +use super::openclip::text_model; +use candle::{Result, Tensor, D}; +use candle_nn::{Func, VarBuilder}; + +#[derive(Clone, Debug)] +pub struct MobileClipModel { + text_model: text_model::OpenClipTextTransformer, + vision_model: Func<'static>, + text_projection: Tensor, + logit_scale: Tensor, +} + +#[derive(Clone, Debug)] +pub struct MobileClipConfig { + pub text_config: text_model::Config, + pub vision_config: fastvit::Config, + pub image_size: usize, +} + +impl MobileClipConfig { + pub fn s1() -> Self { + let text_config = text_model::Config::vit_base_patch32(); + let vision_config = fastvit::Config::mci1(); + + Self { + text_config, + vision_config, + image_size: 256, + } + } + pub fn s2() -> Self { + let text_config = text_model::Config::vit_base_patch32(); + let vision_config = fastvit::Config::mci2(); + + Self { + text_config, + vision_config, + image_size: 256, + } + } +} + +impl MobileClipModel { + pub fn new(vs: VarBuilder, c: &MobileClipConfig) -> Result<Self> { + let vision_model = fastvit::fastvit(&c.vision_config, 512, vs.pp("visual.trunk"))?; + let text_model = text_model::OpenClipTextTransformer::new(vs.pp("text"), &c.text_config)?; + + let text_projection = vs.get( + (c.text_config.embed_dim, c.text_config.projection_dim), + "text.text_projection", + )?; + + let logit_scale = vs.get(&[], "logit_scale")?; + Ok(Self { + text_model, + vision_model, + text_projection, + logit_scale, + }) + } + + pub fn get_text_features(&self, input_ids: &Tensor) -> Result<Tensor> { + input_ids + .apply(&self.text_model)? + .matmul(&self.text_projection) + } + + pub fn get_image_features(&self, pixel_values: &Tensor) -> Result<Tensor> { + pixel_values.apply(&self.vision_model) + } + + pub fn forward(&self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<(Tensor, Tensor)> { + let image_features = self.get_image_features(pixel_values)?; + let text_features = self.get_text_features(input_ids)?; + let image_features_normalized = div_l2_norm(&image_features)?; + let text_features_normalized = div_l2_norm(&text_features)?; + let logits_per_text = text_features_normalized.matmul(&image_features_normalized.t()?)?; + let logit_scale = self.logit_scale.exp()?; + let logits_per_text = logits_per_text.broadcast_mul(&logit_scale)?; + let logits_per_image = logits_per_text.t()?; + Ok((logits_per_text, logits_per_image)) + } +} + +pub fn div_l2_norm(v: &Tensor) -> Result<Tensor> { + let l2_norm = v.sqr()?.sum_keepdim(D::Minus1)?.sqrt()?; + v.broadcast_div(&l2_norm) +} diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index a234b8bb..9f7856ea 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -37,11 +37,13 @@ pub mod mistral; pub mod mixformer; pub mod mixtral; pub mod mmdit; +pub mod mobileclip; pub mod mobilenetv4; pub mod mobileone; pub mod moondream; pub mod mpt; pub mod olmo; +pub mod openclip; pub mod parler_tts; pub mod persimmon; pub mod phi; diff --git a/candle-transformers/src/models/openclip/mod.rs b/candle-transformers/src/models/openclip/mod.rs new file mode 100644 index 00000000..ee2a501d --- /dev/null +++ b/candle-transformers/src/models/openclip/mod.rs @@ -0,0 +1 @@ +pub mod text_model; diff --git a/candle-transformers/src/models/openclip/text_model.rs b/candle-transformers/src/models/openclip/text_model.rs new file mode 100644 index 00000000..7b444e79 --- /dev/null +++ b/candle-transformers/src/models/openclip/text_model.rs @@ -0,0 +1,266 @@ +//! Text encoder as used in most OpenCLIP pretrained models +//! https://github.com/mlfoundations/open_clip + +use candle::{DType, IndexOp, Result, Tensor, D}; +use candle_nn::{ + embedding, layer_norm, linear, ops::softmax_last_dim, Embedding, LayerNorm, Linear, Module, + VarBuilder, +}; + +#[derive(Debug, Clone)] +pub struct Config { + pub vocab_size: usize, + pub embed_dim: usize, + pub intermediate_size: usize, + pub max_position_embeddings: usize, + pub pad_with: Option<String>, + pub num_hidden_layers: usize, + pub num_attention_heads: usize, + pub projection_dim: usize, +} + +impl Config { + pub fn vit_base_patch32() -> Self { + Self { + vocab_size: 49408, + embed_dim: 512, + intermediate_size: 2048, + max_position_embeddings: 77, + pad_with: None, + num_hidden_layers: 12, + num_attention_heads: 8, + projection_dim: 512, + } + } +} + +#[derive(Clone, Debug)] +struct TextEmbeddings { + token_embedding: Embedding, + position_embedding: Tensor, +} + +impl TextEmbeddings { + fn new(vs: VarBuilder, c: &Config) -> Result<Self> { + let token_embedding = embedding(c.vocab_size, c.embed_dim, vs.pp("token_embedding"))?; + let position_embedding = vs.get( + (c.max_position_embeddings, c.embed_dim), + "positional_embedding", + )?; + Ok(TextEmbeddings { + token_embedding, + position_embedding, + }) + } +} + +impl Module for TextEmbeddings { + fn forward(&self, input_ids: &Tensor) -> Result<Tensor> { + let seq_length = input_ids.dim(D::Minus1)?; + let inputs_embeds = self.token_embedding.forward(input_ids)?; + + let position_embedding = self.position_embedding.narrow(0, 0, seq_length)?; + + inputs_embeds.broadcast_add(&position_embedding) + } +} + +#[derive(Clone, Debug)] +struct Attention { + k_proj: candle_nn::Linear, + v_proj: candle_nn::Linear, + q_proj: candle_nn::Linear, + out_proj: Linear, + head_dim: usize, + scale: f64, + num_attention_heads: usize, +} + +impl Attention { + fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> { + let embed_dim = c.embed_dim; + let num_attention_heads = c.num_attention_heads; + + let in_proj_weights = vs + .get((embed_dim * 3, embed_dim), "in_proj_weight")? + .chunk(3, 0)?; + let (q_w, k_w, v_w) = ( + &in_proj_weights[0], + &in_proj_weights[1], + &in_proj_weights[2], + ); + let in_proj_biases = vs.get(embed_dim * 3, "in_proj_bias")?.chunk(3, 0)?; + let (q_b, k_b, v_b) = (&in_proj_biases[0], &in_proj_biases[1], &in_proj_biases[2]); + + let q_proj = Linear::new(q_w.clone(), Some(q_b.clone())); + let k_proj = Linear::new(k_w.clone(), Some(k_b.clone())); + let v_proj = Linear::new(v_w.clone(), Some(v_b.clone())); + let out_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("out_proj"))?; + let head_dim = embed_dim / num_attention_heads; + let scale = (head_dim as f64).powf(-0.5); + + Ok(Attention { + k_proj, + v_proj, + q_proj, + out_proj, + head_dim, + scale, + num_attention_heads, + }) + } + + fn shape_multihead(&self, xs: &Tensor, bsz: usize, seq_len: usize) -> Result<Tensor> { + xs.reshape((bsz, seq_len, self.num_attention_heads, self.head_dim))? + .transpose(1, 2)? + .contiguous()? + .to_dtype(DType::F32) + } + + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let in_dtype = xs.dtype(); + let (bsz, seq_len, embed_dim) = xs.dims3()?; + + let q = self.shape_multihead(&self.q_proj.forward(xs)?, bsz, seq_len)?; + let k = self.shape_multihead(&self.k_proj.forward(xs)?, bsz, seq_len)?; + let v = self.shape_multihead(&self.v_proj.forward(xs)?, bsz, seq_len)?; + let q = (q * self.scale)?; + + let attn_weights = q.matmul(&k.transpose(D::Minus1, D::Minus2)?)?; + + let attn_weights = softmax_last_dim(&attn_weights)?; + + let attn_output = attn_weights.matmul(&v)?.to_dtype(in_dtype)?; + let attn_output = attn_output + .transpose(1, 2)? + .contiguous()? + .reshape((bsz, seq_len, embed_dim))?; + let out = self.out_proj.forward(&attn_output)?; + Ok(out) + } +} + +#[derive(Clone, Debug)] +struct Mlp { + fc1: Linear, + fc2: Linear, +} + +impl Mlp { + fn new(vs: VarBuilder, c: &Config) -> Result<Self> { + let fc1 = linear(c.embed_dim, c.intermediate_size, vs.pp("c_fc"))?; + let fc2 = linear(c.intermediate_size, c.embed_dim, vs.pp("c_proj"))?; + + Ok(Mlp { fc1, fc2 }) + } +} + +impl Mlp { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let xs = self.fc1.forward(xs)?; + self.fc2.forward(&xs.gelu_erf()?) + } +} + +#[derive(Clone, Debug)] +struct EncoderLayer { + self_attn: Attention, + layer_norm1: LayerNorm, + mlp: Mlp, + layer_norm2: LayerNorm, +} + +impl EncoderLayer { + fn new(vs: VarBuilder, c: &Config) -> Result<Self> { + let self_attn = Attention::new(vs.pp("attn"), c)?; + let layer_norm1 = layer_norm(c.embed_dim, 1e-5, vs.pp("ln_1"))?; + let mlp = Mlp::new(vs.pp("mlp"), c)?; + let layer_norm2 = layer_norm(c.embed_dim, 1e-5, vs.pp("ln_2"))?; + + Ok(EncoderLayer { + self_attn, + layer_norm1, + mlp, + layer_norm2, + }) + } + + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let residual = xs; + let xs = self.layer_norm1.forward(xs)?; + let xs = self.self_attn.forward(&xs)?; + let xs = (xs + residual)?; + + let residual = &xs; + let xs = self.layer_norm2.forward(&xs)?; + let xs = self.mlp.forward(&xs)?; + let out = (xs + residual)?; + Ok(out) + } +} + +#[derive(Clone, Debug)] +pub struct Encoder { + layers: Vec<EncoderLayer>, +} + +impl Encoder { + pub fn new(vs: VarBuilder, c: &Config) -> Result<Self> { + let vs = vs.pp("resblocks"); + let mut layers: Vec<EncoderLayer> = Vec::new(); + for index in 0..c.num_hidden_layers { + let layer = EncoderLayer::new(vs.pp(index.to_string()), c)?; + layers.push(layer) + } + Ok(Encoder { layers }) + } + + pub fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let mut xs = xs.clone(); + for layer in self.layers.iter() { + xs = layer.forward(&xs)?; + } + Ok(xs) + } +} + +/// A text transformer as used in CLIP variants. +#[derive(Clone, Debug)] +pub struct OpenClipTextTransformer { + embeddings: TextEmbeddings, + encoder: Encoder, + final_layer_norm: LayerNorm, +} + +impl OpenClipTextTransformer { + pub fn new(vs: VarBuilder, c: &Config) -> Result<Self> { + let embeddings = TextEmbeddings::new(vs.clone(), c)?; + let final_layer_norm = layer_norm(c.embed_dim, 1e-5, vs.pp("ln_final"))?; + let encoder = Encoder::new(vs.pp("transformer"), c)?; + Ok(OpenClipTextTransformer { + embeddings, + encoder, + final_layer_norm, + }) + } + + pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor> { + let input_ids = self.embeddings.forward(input_ids)?; + let input_ids = self.encoder.forward(&input_ids)?; + self.final_layer_norm.forward(&input_ids) + } +} + +impl Module for OpenClipTextTransformer { + fn forward(&self, input_ids: &Tensor) -> Result<Tensor> { + let output = self.forward(input_ids)?; + let sequence_max_indices = input_ids.argmax(D::Minus1)?.to_dtype(DType::I64)?; + + let mut indices = Vec::new(); + for (batch_idx, &seq_idx) in sequence_max_indices.to_vec1::<i64>()?.iter().enumerate() { + let index = output.i((batch_idx, seq_idx as usize))?.unsqueeze(0)?; + indices.push(index); + } + Tensor::cat(&indices, 0) + } +} |