author | chenwanqq <wanqi0605@qq.com> | 2024-06-03 17:54:09 +0800
committer | GitHub <noreply@github.com> | 2024-06-03 11:54:09 +0200
commit | cd4d941ed10fd334333cf5793e311d2bef88a438 (patch)
tree | 02c4003cc3147986dd5f111ec58681930a7c6664 /candle-transformers
parent | 03344d3c19887f6e357d3667fc8e519dfd58b23a (diff)
Add LLaVA support (#2234)
* first commit
* llava
* clippy and fmt
* some fixes
* minor fixes
* remove useless file
* refactor: Remove llava/constants.rs and update llava/mod.rs
* modify variable name
* modify code after clippy
* Minor tweaks.
---------
Co-authored-by: laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-transformers')
-rw-r--r-- | candle-transformers/src/models/clip/text_model.rs | 14
-rw-r--r-- | candle-transformers/src/models/clip/vision_model.rs | 24
-rw-r--r-- | candle-transformers/src/models/llama.rs | 22
-rw-r--r-- | candle-transformers/src/models/llava/config.rs | 267
-rw-r--r-- | candle-transformers/src/models/llava/mod.rs | 407
-rw-r--r-- | candle-transformers/src/models/llava/utils.rs | 41
-rw-r--r-- | candle-transformers/src/models/mod.rs | 1
7 files changed, 776 insertions, 0 deletions
diff --git a/candle-transformers/src/models/clip/text_model.rs b/candle-transformers/src/models/clip/text_model.rs index d3ba26ff..4e4b4c90 100644 --- a/candle-transformers/src/models/clip/text_model.rs +++ b/candle-transformers/src/models/clip/text_model.rs @@ -262,6 +262,20 @@ impl ClipEncoder { } Ok(xs) } + // required by LLaVA + pub fn output_hidden_states( + &self, + xs: &Tensor, + causal_attention_mask: Option<&Tensor>, + ) -> Result<Vec<Tensor>> { + let mut xs = xs.clone(); + let mut hidden_states = Vec::new(); + for layer in self.layers.iter() { + xs = layer.forward(&xs, causal_attention_mask)?; + hidden_states.push(xs.clone()); + } + Ok(hidden_states) + } } /// A CLIP transformer based model. diff --git a/candle-transformers/src/models/clip/vision_model.rs b/candle-transformers/src/models/clip/vision_model.rs index 88992434..e64cab16 100644 --- a/candle-transformers/src/models/clip/vision_model.rs +++ b/candle-transformers/src/models/clip/vision_model.rs @@ -46,6 +46,19 @@ impl ClipVisionConfig { patch_size: 32, } } + pub fn clip_vit_large_patch14_336() -> Self { + Self { + embed_dim: 1024, + activation: Activation::QuickGelu, + intermediate_size: 4096, + num_hidden_layers: 24, + num_attention_heads: 16, + projection_dim: 768, + num_channels: 3, + image_size: 336, + patch_size: 14, + } + } } // https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L112 @@ -130,6 +143,17 @@ impl ClipVisionTransformer { pre_layer_norm, }) } + // required by LLaVA + pub fn output_hidden_states(&self, pixel_values: &Tensor) -> Result<Vec<Tensor>> { + let hidden_states = pixel_values + .apply(&self.embeddings)? + .apply(&self.pre_layer_norm)?; + let mut result = self.encoder.output_hidden_states(&hidden_states, None)?; + let encoder_outputs = result.last().unwrap(); + let pooled_output = encoder_outputs.i((.., 0, ..))?; + result.push(self.final_layer_norm.forward(&pooled_output)?.clone()); + Ok(result) + } } impl Module for ClipVisionTransformer { diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs index 57d2f593..a1f43d35 100644 --- a/candle-transformers/src/models/llama.rs +++ b/candle-transformers/src/models/llama.rs @@ -388,6 +388,28 @@ pub struct Llama { } impl Llama { + // required by LLaVA + pub fn embed(&self, x: &Tensor) -> Result<Tensor> { + self.wte.forward(x) + } + // required by LLaVA + pub fn forward_input_embed( + &self, + input_embed: &Tensor, + index_pos: usize, + cache: &mut Cache, + ) -> Result<Tensor> { + let (_, seq_len, _) = input_embed.dims3()?; + let mut x = input_embed.clone(); + for (block_idx, block) in self.blocks.iter().enumerate() { + x = block.forward(&x, index_pos, block_idx, cache)?; + } + let x = self.ln_f.forward(&x)?; + let x = x.i((.., seq_len - 1, ..))?.contiguous()?; + let logits = self.lm_head.forward(&x)?; + logits.to_dtype(DType::F32) + } + pub fn forward(&self, x: &Tensor, index_pos: usize, cache: &mut Cache) -> Result<Tensor> { let (_b_sz, seq_len) = x.dims2()?; let mut x = self.wte.forward(x)?; diff --git a/candle-transformers/src/models/llava/config.rs b/candle-transformers/src/models/llava/config.rs new file mode 100644 index 00000000..d2d47003 --- /dev/null +++ b/candle-transformers/src/models/llava/config.rs @@ -0,0 +1,267 @@ +use std::collections::HashMap; + +use crate::models::{ + clip::{text_model::Activation, vision_model::ClipVisionConfig}, + llama::Config, +}; +use serde::{Deserialize, Serialize}; + +// original config from 
liuhaotian/llava +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct LLaVAConfig { + pub architectures: Vec<String>, + pub bos_token_id: usize, + pub eos_token_id: usize, + pub hidden_size: usize, + #[serde(default = "default_image_aspect_ratio")] + pub image_aspect_ratio: String, + pub image_crop_resolution: usize, + pub image_grid_pinpoints: Vec<(u32, u32)>, + pub image_split_resolution: usize, + pub intermediate_size: usize, + pub max_position_embeddings: usize, + pub mm_hidden_size: usize, + #[serde(default = "default_mm_patch_merge_type")] + pub mm_patch_merge_type: String, + pub mm_projector_type: String, + pub mm_use_im_start_end: bool, + pub mm_vision_select_feature: String, + pub mm_vision_select_layer: isize, + pub mm_vision_tower: Option<String>, + pub model_type: String, + pub num_attention_heads: usize, + pub num_hidden_layers: usize, + pub num_key_value_heads: usize, + pub pad_token_id: usize, + pub rms_norm_eps: f32, + pub rope_theta: f32, + pub tokenizer_model_max_length: Option<usize>, + pub torch_dtype: String, + pub use_cache: bool, + pub vocab_size: usize, + #[serde(default = "default_image_token_index")] + pub image_token_index: isize, + #[serde(default = "default_hf")] + pub hf: bool, +} + +fn default_hf() -> bool { + false +} + +fn default_image_token_index() -> isize { + -200 +} + +fn default_mm_patch_merge_type() -> String { + "flat".to_string() +} + +fn default_image_aspect_ratio() -> String { + "square".to_string() +} + +impl LLaVAConfig { + pub fn to_llama_config(&self) -> Config { + Config { + hidden_size: self.hidden_size, + intermediate_size: self.intermediate_size, + vocab_size: self.vocab_size, + num_hidden_layers: self.num_hidden_layers, + num_attention_heads: self.num_attention_heads, + num_key_value_heads: self.num_key_value_heads, + rms_norm_eps: self.rms_norm_eps as f64, + rope_theta: self.rope_theta, + bos_token_id: Some(self.bos_token_id as u32), + eos_token_id: Some(self.eos_token_id as u32), + use_flash_attn: false, + } + } +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct HFLLaVATextConfig { + pub architectures: Vec<String>, + #[serde(default = "default_hidden_size")] + pub hidden_size: usize, + #[serde(default = "default_intermediate_size")] + pub intermediate_size: usize, + #[serde(default = "default_max_length")] + pub max_length: usize, + pub max_position_embeddings: usize, + pub model_type: String, + #[serde(default = "default_num_attention_heads")] + pub num_attention_heads: usize, + #[serde(default = "default_num_hidden_layers")] + pub num_hidden_layers: usize, + #[serde(default = "default_num_key_value_heads")] + pub num_key_value_heads: usize, + pub pad_token_id: usize, + pub rms_norm_eps: f32, + #[serde(default = "default_rope_theta")] + pub rope_theta: f32, + pub torch_dtype: String, + #[serde(default = "default_use_cache")] + pub use_cache: bool, + pub vocab_size: usize, +} + +fn default_num_hidden_layers() -> usize { + 32 +} + +fn default_use_cache() -> bool { + true +} + +fn default_hidden_size() -> usize { + 4096 +} + +fn default_intermediate_size() -> usize { + 11008 +} + +fn default_max_length() -> usize { + 4096 +} + +fn default_num_attention_heads() -> usize { + 32 +} + +fn default_num_key_value_heads() -> usize { + 32 +} + +fn default_rope_theta() -> f32 { + 10000.0 +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct HFLLaVAVisionConfig { + pub hidden_size: usize, + pub image_size: usize, + pub intermediate_size: usize, + pub model_type: String, + pub num_attention_heads: usize, + pub 
num_hidden_layers: usize, + pub patch_size: usize, + pub projection_dim: usize, + pub vocab_size: usize, +} + +// config from llava-v1.6-vicuna-7b-hf +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct HFLLaVAConfig { + pub architectures: Vec<String>, + pub ignore_index: isize, + pub image_grid_pinpoints: Vec<(u32, u32)>, + pub image_token_index: isize, + pub model_type: String, + pub projector_hidden_act: String, + pub text_config: HFLLaVATextConfig, + pub torch_dtype: String, + pub use_image_newline_parameter: bool, + pub vision_config: HFLLaVAVisionConfig, + pub vision_feature_layer: isize, + pub vision_feature_select_strategy: String, + pub vocab_size: usize, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct HFGenerationConfig { + pub bos_token_id: usize, + pub eos_token_id: usize, + #[serde(default = "default_max_length")] + pub max_length: usize, + pub pad_token_id: usize, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct HFPreProcessorConfig { + pub aspect_ratio_setting: String, + pub crop_size: HashMap<String, usize>, + pub do_center_crop: bool, + pub do_convert_rgb: bool, + pub do_normalize: bool, + pub do_rescale: bool, + pub do_resize: bool, + pub image_mean: Vec<f32>, + pub image_std: Vec<f32>, + pub resample: u32, + pub rescale_factor: f32, + pub size: HashMap<String, f32>, +} + +impl HFLLaVAConfig { + pub fn to_clip_vision_config(&self) -> ClipVisionConfig { + ClipVisionConfig { + embed_dim: self.vision_config.hidden_size, + activation: Activation::QuickGelu, + intermediate_size: self.vision_config.intermediate_size, + num_hidden_layers: self.vision_config.num_hidden_layers, + num_attention_heads: self.vision_config.num_attention_heads, + projection_dim: self.vision_config.projection_dim, + num_channels: 3, + image_size: self.vision_config.image_size, + patch_size: self.vision_config.patch_size, + } + } + fn map_projector_type(s: &str) -> String { + if s == "gelu" { + "mlp2x_gelu".to_string() + } else { + s.to_string() + } + } + + fn map_select_feature(s: &str) -> String { + if s == "default" { + "patch".to_string() + } else { + "cls_patch".to_string() + } + } + + pub fn to_llava_config( + &self, + generation_config: &HFGenerationConfig, + preprocessor_config: &HFPreProcessorConfig, + ) -> LLaVAConfig { + LLaVAConfig { + hf: true, + architectures: self.architectures.clone(), + bos_token_id: generation_config.bos_token_id, + eos_token_id: generation_config.eos_token_id, + hidden_size: self.text_config.hidden_size, + image_aspect_ratio: preprocessor_config.aspect_ratio_setting.clone(), + image_crop_resolution: 224, + image_grid_pinpoints: self.image_grid_pinpoints.clone(), + image_split_resolution: 224, + intermediate_size: self.text_config.intermediate_size, + max_position_embeddings: self.text_config.max_position_embeddings, + mm_hidden_size: 1024, + mm_patch_merge_type: "spatial_unpad".to_string(), + mm_projector_type: Self::map_projector_type(&self.projector_hidden_act), + mm_use_im_start_end: false, + mm_vision_select_feature: Self::map_select_feature( + &self.vision_feature_select_strategy, + ), + mm_vision_select_layer: self.vision_feature_layer, + mm_vision_tower: None, + model_type: self.model_type.clone(), + num_attention_heads: self.text_config.num_attention_heads, + num_hidden_layers: self.text_config.num_hidden_layers, + num_key_value_heads: self.text_config.num_key_value_heads, + pad_token_id: self.text_config.pad_token_id, + rms_norm_eps: self.text_config.rms_norm_eps, + rope_theta: self.text_config.rope_theta, + 
tokenizer_model_max_length: Some(4096), + torch_dtype: self.torch_dtype.clone(), + use_cache: self.text_config.use_cache, + vocab_size: self.vocab_size, + image_token_index: self.image_token_index, + } + } +} diff --git a/candle-transformers/src/models/llava/mod.rs b/candle-transformers/src/models/llava/mod.rs new file mode 100644 index 00000000..caa8737a --- /dev/null +++ b/candle-transformers/src/models/llava/mod.rs @@ -0,0 +1,407 @@ +pub mod config; +pub mod utils; + +use crate::models::clip::vision_model::{ClipVisionConfig, ClipVisionTransformer}; +use crate::models::llama::{Cache, Llama}; +use crate::models::with_tracing::linear; + +use candle::{bail, Device, IndexOp, Result, Tensor}; +use candle_nn::{seq, Activation, Module, Sequential, VarBuilder}; +use fancy_regex::Regex; +use utils::get_anyres_image_grid_shape; + +use config::LLaVAConfig; + +fn mlp_gelu_match(mm_projector_type: &str) -> Option<usize> { + let mlp_gelu_regex = Regex::new(r"^mlp(\d+)x_gelu$").unwrap(); + + if let Ok(Some(captures)) = mlp_gelu_regex.captures(mm_projector_type) { + if let Some(match_str) = captures.get(1) { + let match_str = match_str.as_str(); + match_str.parse::<usize>().ok() + } else { + None + } + } else { + None + } +} + +fn unpad_image(tensor: &Tensor, original_size: &(u32, u32)) -> Result<Tensor> { + assert_eq!(tensor.dims().len(), 3); + let (original_width, original_height) = *original_size; + let tensor_dims = tensor.dims(); + let current_height = tensor_dims[1]; + let current_width = tensor_dims[2]; + let original_aspect_ratio = (original_width as f32) / (original_height as f32); + let current_aspect_ratio = (current_width as f32) / (current_height as f32); + if original_aspect_ratio > current_aspect_ratio { + let scale_factor = (current_width as f32) / (original_width as f32); + let new_height = (original_height as f32 * scale_factor).floor() as usize; + let padding = (current_height - new_height) / 2; + tensor.i((.., padding..current_width - padding, ..)) + } else { + let scale_factor = (current_height as f32) / (original_height as f32); + let new_width = (original_width as f32 * scale_factor).floor() as usize; + let padding = (current_width - new_width) / 2; + tensor.i((.., .., padding..current_width - padding)) + } +} + +pub struct IdentityMap {} + +impl Module for IdentityMap { + fn forward(&self, x: &Tensor) -> Result<Tensor> { + Ok(x.clone()) + } +} + +pub struct MMProjector { + pub modules: Sequential, +} + +impl MMProjector { + pub fn load(vb: &VarBuilder, config: &LLaVAConfig) -> Result<Self> { + if config.mm_projector_type == "linear" { + let vb_prefix = if config.hf { + "multi_modal_projector.linear_1" + } else { + "model.mm_projector.0" + }; + let linear = linear(config.mm_hidden_size, config.hidden_size, vb.pp(vb_prefix))?; + let modules = seq().add(linear); + Ok(Self { modules }) + } else if let Some(mlp_depth) = mlp_gelu_match(&config.mm_projector_type) { + let modules = if config.hf { + let mut modules = seq().add(linear( + config.mm_hidden_size, + config.hidden_size, + vb.pp("multi_modal_projector.linear_1"), + )?); + for i in 1..mlp_depth { + modules = modules.add(Activation::Gelu).add(linear( + config.hidden_size, + config.hidden_size, + vb.pp(format!("multi_modal_projector.linear_{}", i + 1)), + )?); + } + modules + } else { + let mut modules = seq().add(linear( + config.mm_hidden_size, + config.hidden_size, + vb.pp("model.mm_projector.0"), + )?); + for i in 1..mlp_depth { + modules = modules.add(Activation::Gelu).add(linear( + config.hidden_size, + config.hidden_size, + 
vb.pp(format!("model.mm_projector.{}", i * 2)), + )?); + } + modules + }; + Ok(Self { modules }) + } else if config.mm_projector_type == "identity" { + Ok(Self { + modules: seq().add(IdentityMap {}), + }) + } else { + bail!( + "Unsupported MM projector type: {}", + config.mm_projector_type + ) + } + } + + pub fn forward(&self, x: &Tensor) -> Result<Tensor> { + self.modules.forward(x) + } +} + +pub struct ClipVisionTower { + model: ClipVisionTransformer, + select_layer: isize, + select_feature_method: String, + pub config: ClipVisionConfig, +} + +impl ClipVisionTower { + pub fn new( + vb: VarBuilder, + select_layer: isize, + select_feature_method: &str, + config: &Option<ClipVisionConfig>, + ) -> Result<Self> { + let config = if config.is_none() { + ClipVisionConfig::clip_vit_large_patch14_336() + } else { + config.clone().unwrap() + }; + let select_layer = match select_layer { + -1 | -2 => select_layer, + _ => bail!("Unsupported select layer: {}", select_layer), + }; + let model = ClipVisionTransformer::new(vb, &config)?; + Ok(Self { + model, + select_layer, + select_feature_method: select_feature_method.to_string(), + config, + }) + } + + pub fn forward(&self, x: &Tensor) -> Result<Tensor> { + let result = self.model.output_hidden_states(x)?; + let index = result.len() as isize + self.select_layer; + let result = result[index as usize].clone(); + if self.select_feature_method == "cls_patch" { + Ok(result) + } else { + result.i((.., 1..)) + } + } + + pub fn num_patches_per_side(&self) -> usize { + self.config.image_size / self.config.patch_size + } +} + +pub struct LLaVA { + pub clip_vision_tower: ClipVisionTower, + pub image_newline: Tensor, + pub mm_projector: MMProjector, + pub llama: Llama, + config: LLaVAConfig, + device: Device, +} + +impl LLaVA { + pub fn load( + vb: VarBuilder, + config: &LLaVAConfig, + clip_vision_config: Option<ClipVisionConfig>, + ) -> Result<Self> { + let device = vb.device().clone(); + let llama_config = config.to_llama_config(); + let mm_projector = MMProjector::load(&vb, config)?; + let (clip_vision_tower, image_newline, llama) = if config.hf { + ( + ClipVisionTower::new( + vb.pp("vision_tower.vision_model"), + config.mm_vision_select_layer, + &config.mm_vision_select_feature, + &clip_vision_config, + )?, + vb.get(&[config.hidden_size], "image_newline")? + .to_device(&device)?, + Llama::load(vb.pp("language_model"), &llama_config)?, + ) + } else { + ( + ClipVisionTower::new( + vb.pp("model.vision_tower.vision_tower.vision_model"), + config.mm_vision_select_layer, + &config.mm_vision_select_feature, + &clip_vision_config, + )?, + vb.get(&[config.hidden_size], "model.image_newline")? 
+ .to_device(&device)?, + Llama::load(vb, &llama_config)?, + ) + }; + Ok(Self { + clip_vision_tower, + image_newline, + mm_projector, + llama, + config: (*config).clone(), + device, + }) + } + + pub fn encode_images(&self, x: &Tensor) -> Result<Tensor> { + let image_features = self.clip_vision_tower.forward(x)?; + let image_features = self.mm_projector.forward(&image_features)?; + Ok(image_features) + } + // currently only for single image, 4 dim tensor + pub fn prepare_inputs_labels_for_multimodal( + &self, + input_ids: &Tensor, + images: &[Tensor], + image_sizes: &[(u32, u32)], + ) -> Result<Tensor> { + //TODO: process of multiple images/ new line + // 576: 336(input size)/14(patch size)=24 24*24+1(class)=577 577-1=576 + let concat_images = Tensor::cat(images, 0)?; + let image_features_together = self.encode_images(&concat_images)?; + let split_sizes = images + .iter() + .map(|x| x.shape().dims()[0]) + .collect::<Vec<usize>>(); + // can be replaced by split + let mut index_pos = 0; + let mut image_features = Vec::new(); + for split_size in split_sizes.iter() { + image_features.push(image_features_together.i(index_pos..index_pos + (*split_size))?); + index_pos += *split_size; + } + let mm_patch_merge_type = &self.config.mm_patch_merge_type; + let image_aspect_ratio = &self.config.image_aspect_ratio; + + let image_features = if mm_patch_merge_type == "flat" { + image_features + .iter() + .map(|x| x.flatten(0, 1).unwrap()) + .collect::<Vec<Tensor>>() + } else if mm_patch_merge_type.starts_with("spatial") { + let mut new_image_features = Vec::new(); + for (image_idx, image_feature) in image_features.iter().enumerate() { + let new_image_feature = if image_feature.dims()[0] > 1 { + let base_image_feature = image_feature.get(0).unwrap(); + let patch_image_feature = image_feature.i(1..).unwrap(); + let height = self.clip_vision_tower.num_patches_per_side(); + let width = height; + assert_eq!(height * width, base_image_feature.dims()[0]); + let image_size = image_sizes[image_idx]; + let new_image_feature = if image_aspect_ratio == "anyres" { + let (num_patch_width, num_patch_height) = get_anyres_image_grid_shape( + image_size, + &self.config.image_grid_pinpoints, + self.clip_vision_tower.config.image_size as u32, + ); + patch_image_feature.reshape(( + num_patch_height as usize, + num_patch_width as usize, + height, + width, + (), + ))? + } else { + todo!("not implemented in original python LLaVA yet") + }; + let new_image_feature = if mm_patch_merge_type.contains("unpad") { + let new_image_feature = new_image_feature + .permute((4, 0, 2, 1, 3))? + .flatten(1, 2)? + .flatten(2, 3)?; + let new_image_feature = unpad_image(&new_image_feature, &image_size)?; + let new_image_feature_dims = new_image_feature.dims(); + let image_new_line = self + .image_newline + .reshape((self.config.hidden_size, 1, 1))? + .broadcast_as(( + new_image_feature_dims[0], + new_image_feature_dims[1], + 1, + ))?; + let new_image_feature = + Tensor::cat(&[new_image_feature, image_new_line], 2)?; + new_image_feature.flatten(1, 2)?.transpose(0, 1)? + } else { + new_image_feature.permute((0, 2, 1, 3, 4))?.flatten(0, 3)? + }; + Tensor::cat(&[base_image_feature, new_image_feature], 0)? 
+ } else { + let new_image_feature = image_feature.get(0).unwrap(); + if mm_patch_merge_type.contains("unpad") { + Tensor::cat( + &[ + new_image_feature, + self.image_newline.clone().unsqueeze(0).unwrap(), + ], + 0, + ) + .unwrap() + } else { + new_image_feature + } + }; + new_image_features.push(new_image_feature); + } + new_image_features + } else { + bail!("Unexpected mm_patch_merge_type: {mm_patch_merge_type}") + }; + // can easily be replaced by nonzero if it is implemented in candle + let input_ids_vec = input_ids.squeeze(0)?.to_vec1::<i64>()?; + let mut image_indices = { + let mut image_indices = vec![0_i64]; + image_indices.extend( + input_ids_vec + .iter() + .enumerate() + .filter_map(|(i, x)| { + if *x == self.config.image_token_index as i64 { + Some(i as i64) + } else { + None + } + }) + .collect::<Vec<i64>>(), + ); + image_indices + }; + if image_indices.len() == 1 { + //no image, only [0], + return self.llama.embed(input_ids); + } + + let input_ids_noim = input_ids_vec + .iter() + .filter_map(|x| { + if *x != self.config.image_token_index as i64 { + Some(*x) + } else { + None + } + }) + .collect::<Vec<i64>>(); + let input_ids_noim_len = input_ids_noim.len(); + image_indices.push((input_ids_noim_len) as i64); + let input_ids_noim = Tensor::from_vec(input_ids_noim, input_ids_noim_len, &self.device)?; + let cur_input_embeds = self.llama.embed(&input_ids_noim)?; + // can be replace by split if it is implemented in candle + let input_embed_no_ims = { + let mut input_embeds = Vec::new(); + for i in 0..image_indices.len() - 1 { + let start = (image_indices[i]) as usize; + let end = image_indices[i + 1] as usize; + input_embeds.push(cur_input_embeds.i((start..end, ..))?) + } + input_embeds + }; + + let mut cur_new_input_embeds = Vec::new(); + for (i, image_feature) in image_features.iter().enumerate() { + cur_new_input_embeds.push(input_embed_no_ims[i].clone()); + cur_new_input_embeds.push(image_feature.clone()); + } + cur_new_input_embeds.push(input_embed_no_ims[image_features.len()].clone()); + let new_input_embeds = Tensor::cat(&cur_new_input_embeds, 0)?; + //trancate + let new_input_embeds = + if let Some(tokenizer_model_max_length) = self.config.tokenizer_model_max_length { + let (new_input_embeds_length, _) = new_input_embeds.shape().dims2()?; + if new_input_embeds_length > tokenizer_model_max_length { + new_input_embeds.i((..tokenizer_model_max_length, ..))? 
+ } else { + new_input_embeds + } + } else { + new_input_embeds + }; + new_input_embeds.unsqueeze(0) + } + + pub fn forward( + &self, + input_embeds: &Tensor, + position_id: usize, + cache: &mut Cache, + ) -> Result<Tensor> { + self.llama + .forward_input_embed(input_embeds, position_id, cache) + } +} diff --git a/candle-transformers/src/models/llava/utils.rs b/candle-transformers/src/models/llava/utils.rs new file mode 100644 index 00000000..3b4c18bb --- /dev/null +++ b/candle-transformers/src/models/llava/utils.rs @@ -0,0 +1,41 @@ +pub fn get_anyres_image_grid_shape( + image_size: (u32, u32), + grid_pinpoints: &[(u32, u32)], + patch_size: u32, +) -> (u32, u32) { + let (width, height) = select_best_resolution(image_size, grid_pinpoints); + (width / patch_size, height / patch_size) +} + +pub fn select_best_resolution( + original_size: (u32, u32), + possible_resolutions: &[(u32, u32)], +) -> (u32, u32) { + let (original_width, original_height) = original_size; + let mut best_fit = (0, 0); + let original_width_f = original_width as f32; + let original_height_f = original_height as f32; + let mut max_effective_resolution = 0_u32; + let mut min_wasted_resolution = u32::MAX; + for (width, height) in possible_resolutions { + let width_f = *width as f32; + let height_f = *height as f32; + let scale = (width_f / original_width_f).min(height_f / original_height_f); + let (downscaled_width, downscaled_height) = ( + (original_width_f * scale) as u32, + (original_height_f * scale) as u32, + ); + let effective_resolution = + std::cmp::min((*width) * (*height), downscaled_width * downscaled_height); + let wasted_resolution = (*width) * (*height) - effective_resolution; + if effective_resolution > max_effective_resolution + || (effective_resolution == max_effective_resolution + && wasted_resolution < min_wasted_resolution) + { + best_fit = (*width, *height); + max_effective_resolution = effective_resolution; + min_wasted_resolution = wasted_resolution; + } + } + best_fit +} diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index de2430a2..4628a3de 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -17,6 +17,7 @@ pub mod jina_bert; pub mod llama; pub mod llama2_c; pub mod llama2_c_weights; +pub mod llava; pub mod mamba; pub mod marian; pub mod metavoice; |
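For readers skimming the patch, a minimal sketch of how the anyres helpers added in candle-transformers/src/models/llava/utils.rs fit together. The pinpoint list and the 800x600 image size are illustrative values chosen for this example, not taken from the commit; the 336 tile size matches the ClipVisionConfig::clip_vit_large_patch14_336 preset introduced above, and the crate is assumed to be importable as candle_transformers.

```rust
// Minimal sketch (not part of the patch): exercising the anyres helpers from
// candle-transformers/src/models/llava/utils.rs. The pinpoints and the 800x600
// image size are made-up example values.
use candle_transformers::models::llava::utils::{
    get_anyres_image_grid_shape, select_best_resolution,
};

fn main() {
    // Candidate grid resolutions in (width, height), as stored in image_grid_pinpoints.
    let pinpoints = [(336, 672), (672, 336), (672, 672), (1008, 336), (336, 1008)];

    // select_best_resolution keeps the candidate with the largest effective (non-padded)
    // area after downscaling and, on ties, the least wasted area. For 800x600 that is 672x672.
    let best = select_best_resolution((800, 600), &pinpoints);
    assert_eq!(best, (672, 672));

    // The grid shape divides the chosen resolution by the vision-tower input size
    // (336 for clip_vit_large_patch14_336, passed via the `patch_size` argument),
    // i.e. the image is tiled into a 2x2 grid of 336x336 crops.
    let (grid_w, grid_h) = get_anyres_image_grid_shape((800, 600), &pinpoints, 336);
    assert_eq!((grid_w, grid_h), (2, 2));
}
```

This is the same calculation prepare_inputs_labels_for_multimodal performs when mm_patch_merge_type starts with "spatial" and image_aspect_ratio is "anyres", before the per-tile CLIP features are reshaped and concatenated with the image_newline embedding.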