summaryrefslogtreecommitdiff
path: root/candle-transformers
diff options
context:
space:
mode:
authorchenwanqq <wanqi0605@qq.com>2024-06-03 17:54:09 +0800
committerGitHub <noreply@github.com>2024-06-03 11:54:09 +0200
commitcd4d941ed10fd334333cf5793e311d2bef88a438 (patch)
tree02c4003cc3147986dd5f111ec58681930a7c6664 /candle-transformers
parent03344d3c19887f6e357d3667fc8e519dfd58b23a (diff)
downloadcandle-cd4d941ed10fd334333cf5793e311d2bef88a438.tar.gz
candle-cd4d941ed10fd334333cf5793e311d2bef88a438.tar.bz2
candle-cd4d941ed10fd334333cf5793e311d2bef88a438.zip
Add LLaVA support (#2234)
* first commit * llava * clippy and fmt * some fixes * minor fixes * remove useless file * refactor: Remove llava/constants.rs and update llava/mod.rs * modify variable name * modify code after clippy * Minor tweaks. --------- Co-authored-by: laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-transformers')
-rw-r--r--candle-transformers/src/models/clip/text_model.rs14
-rw-r--r--candle-transformers/src/models/clip/vision_model.rs24
-rw-r--r--candle-transformers/src/models/llama.rs22
-rw-r--r--candle-transformers/src/models/llava/config.rs267
-rw-r--r--candle-transformers/src/models/llava/mod.rs407
-rw-r--r--candle-transformers/src/models/llava/utils.rs41
-rw-r--r--candle-transformers/src/models/mod.rs1
7 files changed, 776 insertions, 0 deletions
diff --git a/candle-transformers/src/models/clip/text_model.rs b/candle-transformers/src/models/clip/text_model.rs
index d3ba26ff..4e4b4c90 100644
--- a/candle-transformers/src/models/clip/text_model.rs
+++ b/candle-transformers/src/models/clip/text_model.rs
@@ -262,6 +262,20 @@ impl ClipEncoder {
}
Ok(xs)
}
+    // required by LLaVA
+    /// Feeds `xs` through the encoder layers in order and returns the output of
+    /// every layer; element `i` is what layer `i` produced.
+    pub fn output_hidden_states(
+        &self,
+        xs: &Tensor,
+        causal_attention_mask: Option<&Tensor>,
+    ) -> Result<Vec<Tensor>> {
+        let mut per_layer: Vec<Tensor> = Vec::new();
+        for encoder_layer in self.layers.iter() {
+            // Each layer consumes the previous layer's output (or `xs` for the first).
+            let layer_input = per_layer.last().unwrap_or(xs);
+            let layer_output = encoder_layer.forward(layer_input, causal_attention_mask)?;
+            per_layer.push(layer_output);
+        }
+        Ok(per_layer)
+    }
}
/// A CLIP transformer based model.
diff --git a/candle-transformers/src/models/clip/vision_model.rs b/candle-transformers/src/models/clip/vision_model.rs
index 88992434..e64cab16 100644
--- a/candle-transformers/src/models/clip/vision_model.rs
+++ b/candle-transformers/src/models/clip/vision_model.rs
@@ -46,6 +46,19 @@ impl ClipVisionConfig {
patch_size: 32,
}
}
+ /// Preset configuration with `image_size` 336 and `patch_size` 14, a 1024-wide
+ /// embedding, 24 hidden layers, 16 attention heads and `QuickGelu` activation.
+ pub fn clip_vit_large_patch14_336() -> Self {
+ Self {
+ embed_dim: 1024,
+ activation: Activation::QuickGelu,
+ intermediate_size: 4096,
+ num_hidden_layers: 24,
+ num_attention_heads: 16,
+ projection_dim: 768,
+ num_channels: 3,
+ image_size: 336,
+ patch_size: 14,
+ }
+ }
}
// https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L112
@@ -130,6 +143,17 @@ impl ClipVisionTransformer {
pre_layer_norm,
})
}
+    // required by LLaVA
+    /// Returns the encoder's per-layer outputs followed by one extra element:
+    /// the final layer norm applied to index 0 along dimension 1 of the last
+    /// encoder output.
+    ///
+    /// # Errors
+    /// Fails if any tensor operation fails, or if the encoder returns no hidden
+    /// state at all (previously this case panicked).
+    pub fn output_hidden_states(&self, pixel_values: &Tensor) -> Result<Vec<Tensor>> {
+        let hidden_states = pixel_values
+            .apply(&self.embeddings)?
+            .apply(&self.pre_layer_norm)?;
+        let mut result = self.encoder.output_hidden_states(&hidden_states, None)?;
+        // Nothing to pool when the encoder has no layers; report it instead of panicking.
+        let pooled_output = match result.last() {
+            Some(encoder_outputs) => encoder_outputs.i((.., 0, ..))?,
+            None => candle::bail!("the CLIP vision encoder produced no hidden states"),
+        };
+        result.push(self.final_layer_norm.forward(&pooled_output)?);
+        Ok(result)
+    }
}
impl Module for ClipVisionTransformer {
diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs
index 57d2f593..a1f43d35 100644
--- a/candle-transformers/src/models/llama.rs
+++ b/candle-transformers/src/models/llama.rs
@@ -388,6 +388,28 @@ pub struct Llama {
}
impl Llama {
+ // required by LLaVA
+ /// Applies the `wte` module to `x` and returns the result; this is the same
+ /// first step that `forward` performs on its input.
+ pub fn embed(&self, x: &Tensor) -> Result<Tensor> {
+ self.wte.forward(x)
+ }
+    // required by LLaVA
+    /// Runs the transformer blocks on already-embedded inputs and returns the
+    /// logits for the last sequence position, converted to `F32`.
+    ///
+    /// `input_embed` must be a rank-3 tensor; its middle dimension is treated as
+    /// the sequence length. `index_pos` and `cache` are forwarded unchanged to
+    /// every block.
+    ///
+    /// # Errors
+    /// Fails if `input_embed` is not rank 3, if the sequence is empty, or if any
+    /// block / tensor operation fails.
+    pub fn forward_input_embed(
+        &self,
+        input_embed: &Tensor,
+        index_pos: usize,
+        cache: &mut Cache,
+    ) -> Result<Tensor> {
+        let (_, seq_len, _) = input_embed.dims3()?;
+        // `seq_len - 1` below would underflow on an empty sequence.
+        if seq_len == 0 {
+            candle::bail!("forward_input_embed requires at least one input embedding")
+        }
+        let mut x = input_embed.clone();
+        for (block_idx, block) in self.blocks.iter().enumerate() {
+            x = block.forward(&x, index_pos, block_idx, cache)?;
+        }
+        let x = self.ln_f.forward(&x)?;
+        // Only the last position is projected to logits.
+        let x = x.i((.., seq_len - 1, ..))?.contiguous()?;
+        let logits = self.lm_head.forward(&x)?;
+        logits.to_dtype(DType::F32)
+    }
+
pub fn forward(&self, x: &Tensor, index_pos: usize, cache: &mut Cache) -> Result<Tensor> {
let (_b_sz, seq_len) = x.dims2()?;
let mut x = self.wte.forward(x)?;
diff --git a/candle-transformers/src/models/llava/config.rs b/candle-transformers/src/models/llava/config.rs
new file mode 100644
index 00000000..d2d47003
--- /dev/null
+++ b/candle-transformers/src/models/llava/config.rs
@@ -0,0 +1,267 @@
+use std::collections::HashMap;
+
+use crate::models::{
+ clip::{text_model::Activation, vision_model::ClipVisionConfig},
+ llama::Config,
+};
+use serde::{Deserialize, Serialize};
+
+// original config from liuhaotian/llava
+/// LLaVA model configuration.
+///
+/// Fields carrying `#[serde(default = "...")]` fall back to the named function
+/// when the key is absent from the deserialized data.
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct LLaVAConfig {
+ pub architectures: Vec<String>,
+ pub bos_token_id: usize,
+ pub eos_token_id: usize,
+ pub hidden_size: usize,
+ // Falls back to "square" when absent.
+ #[serde(default = "default_image_aspect_ratio")]
+ pub image_aspect_ratio: String,
+ pub image_crop_resolution: usize,
+ pub image_grid_pinpoints: Vec<(u32, u32)>,
+ pub image_split_resolution: usize,
+ pub intermediate_size: usize,
+ pub max_position_embeddings: usize,
+ pub mm_hidden_size: usize,
+ // Falls back to "flat" when absent.
+ #[serde(default = "default_mm_patch_merge_type")]
+ pub mm_patch_merge_type: String,
+ pub mm_projector_type: String,
+ pub mm_use_im_start_end: bool,
+ pub mm_vision_select_feature: String,
+ pub mm_vision_select_layer: isize,
+ pub mm_vision_tower: Option<String>,
+ pub model_type: String,
+ pub num_attention_heads: usize,
+ pub num_hidden_layers: usize,
+ pub num_key_value_heads: usize,
+ pub pad_token_id: usize,
+ pub rms_norm_eps: f32,
+ pub rope_theta: f32,
+ pub tokenizer_model_max_length: Option<usize>,
+ pub torch_dtype: String,
+ pub use_cache: bool,
+ pub vocab_size: usize,
+ // Falls back to -200 when absent.
+ #[serde(default = "default_image_token_index")]
+ pub image_token_index: isize,
+ // Falls back to false when absent; `HFLLaVAConfig::to_llava_config` sets it to true.
+ #[serde(default = "default_hf")]
+ pub hf: bool,
+}
+
+/// Serde default for `LLaVAConfig::hf`.
+fn default_hf() -> bool {
+ false
+}
+
+/// Serde default for `LLaVAConfig::image_token_index`.
+fn default_image_token_index() -> isize {
+ -200
+}
+
+/// Serde default for `LLaVAConfig::mm_patch_merge_type`.
+fn default_mm_patch_merge_type() -> String {
+ "flat".to_string()
+}
+
+/// Serde default for `LLaVAConfig::image_aspect_ratio`.
+fn default_image_aspect_ratio() -> String {
+ "square".to_string()
+}
+
+impl LLaVAConfig {
+ /// Builds the llama `Config` from the corresponding fields of this
+ /// configuration.
+ ///
+ /// `use_flash_attn` is always set to `false`.
+ pub fn to_llama_config(&self) -> Config {
+ Config {
+ hidden_size: self.hidden_size,
+ intermediate_size: self.intermediate_size,
+ vocab_size: self.vocab_size,
+ num_hidden_layers: self.num_hidden_layers,
+ num_attention_heads: self.num_attention_heads,
+ num_key_value_heads: self.num_key_value_heads,
+ // widened from f32 to f64
+ rms_norm_eps: self.rms_norm_eps as f64,
+ rope_theta: self.rope_theta,
+ // NOTE(review): `as u32` silently truncates ids that do not fit in 32 bits.
+ bos_token_id: Some(self.bos_token_id as u32),
+ eos_token_id: Some(self.eos_token_id as u32),
+ use_flash_attn: false,
+ }
+ }
+}
+
+/// The `text_config` section of `HFLLaVAConfig`.
+///
+/// Fields carrying `#[serde(default = "...")]` fall back to the named function
+/// when the key is absent from the deserialized data.
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct HFLLaVATextConfig {
+ pub architectures: Vec<String>,
+ #[serde(default = "default_hidden_size")]
+ pub hidden_size: usize,
+ #[serde(default = "default_intermediate_size")]
+ pub intermediate_size: usize,
+ #[serde(default = "default_max_length")]
+ pub max_length: usize,
+ pub max_position_embeddings: usize,
+ pub model_type: String,
+ #[serde(default = "default_num_attention_heads")]
+ pub num_attention_heads: usize,
+ #[serde(default = "default_num_hidden_layers")]
+ pub num_hidden_layers: usize,
+ #[serde(default = "default_num_key_value_heads")]
+ pub num_key_value_heads: usize,
+ pub pad_token_id: usize,
+ pub rms_norm_eps: f32,
+ #[serde(default = "default_rope_theta")]
+ pub rope_theta: f32,
+ pub torch_dtype: String,
+ #[serde(default = "default_use_cache")]
+ pub use_cache: bool,
+ pub vocab_size: usize,
+}
+
+/// Serde default for `HFLLaVATextConfig::num_hidden_layers`.
+fn default_num_hidden_layers() -> usize {
+ 32
+}
+
+/// Serde default for `HFLLaVATextConfig::use_cache`.
+fn default_use_cache() -> bool {
+ true
+}
+
+/// Serde default for `HFLLaVATextConfig::hidden_size`.
+fn default_hidden_size() -> usize {
+ 4096
+}
+
+/// Serde default for `HFLLaVATextConfig::intermediate_size`.
+fn default_intermediate_size() -> usize {
+ 11008
+}
+
+/// Serde default for the `max_length` field of both `HFLLaVATextConfig` and
+/// `HFGenerationConfig`.
+fn default_max_length() -> usize {
+ 4096
+}
+
+/// Serde default for `HFLLaVATextConfig::num_attention_heads`.
+fn default_num_attention_heads() -> usize {
+ 32
+}
+
+/// Serde default for `HFLLaVATextConfig::num_key_value_heads`.
+fn default_num_key_value_heads() -> usize {
+ 32
+}
+
+/// Serde default for `HFLLaVATextConfig::rope_theta`.
+fn default_rope_theta() -> f32 {
+ 10000.0
+}
+
+/// The `vision_config` section of `HFLLaVAConfig`; its fields feed
+/// `HFLLaVAConfig::to_clip_vision_config`.
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct HFLLaVAVisionConfig {
+ pub hidden_size: usize,
+ pub image_size: usize,
+ pub intermediate_size: usize,
+ pub model_type: String,
+ pub num_attention_heads: usize,
+ pub num_hidden_layers: usize,
+ pub patch_size: usize,
+ pub projection_dim: usize,
+ pub vocab_size: usize,
+}
+
+// config from llava-v1.6-vicuna-7b-hf
+/// Top-level configuration in the HF layout, with nested text and vision
+/// sections. It can be converted with `to_llava_config` and
+/// `to_clip_vision_config`.
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct HFLLaVAConfig {
+ pub architectures: Vec<String>,
+ pub ignore_index: isize,
+ pub image_grid_pinpoints: Vec<(u32, u32)>,
+ pub image_token_index: isize,
+ pub model_type: String,
+ pub projector_hidden_act: String,
+ pub text_config: HFLLaVATextConfig,
+ pub torch_dtype: String,
+ pub use_image_newline_parameter: bool,
+ pub vision_config: HFLLaVAVisionConfig,
+ pub vision_feature_layer: isize,
+ pub vision_feature_select_strategy: String,
+ pub vocab_size: usize,
+}
+
+/// Generation settings; `HFLLaVAConfig::to_llava_config` reads `bos_token_id`
+/// and `eos_token_id` from it.
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct HFGenerationConfig {
+ pub bos_token_id: usize,
+ pub eos_token_id: usize,
+ // Falls back to 4096 when absent.
+ #[serde(default = "default_max_length")]
+ pub max_length: usize,
+ pub pad_token_id: usize,
+}
+
+/// Image preprocessor settings; `HFLLaVAConfig::to_llava_config` reads only
+/// `aspect_ratio_setting` from it.
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct HFPreProcessorConfig {
+ pub aspect_ratio_setting: String,
+ pub crop_size: HashMap<String, usize>,
+ pub do_center_crop: bool,
+ pub do_convert_rgb: bool,
+ pub do_normalize: bool,
+ pub do_rescale: bool,
+ pub do_resize: bool,
+ pub image_mean: Vec<f32>,
+ pub image_std: Vec<f32>,
+ pub resample: u32,
+ pub rescale_factor: f32,
+ pub size: HashMap<String, f32>,
+}
+
+impl HFLLaVAConfig {
+    /// Builds the CLIP vision configuration from `vision_config`.
+    ///
+    /// The activation is always `QuickGelu` and `num_channels` is always 3;
+    /// neither is read from the HF configuration.
+    pub fn to_clip_vision_config(&self) -> ClipVisionConfig {
+        let vision = &self.vision_config;
+        ClipVisionConfig {
+            embed_dim: vision.hidden_size,
+            activation: Activation::QuickGelu,
+            intermediate_size: vision.intermediate_size,
+            num_hidden_layers: vision.num_hidden_layers,
+            num_attention_heads: vision.num_attention_heads,
+            projection_dim: vision.projection_dim,
+            num_channels: 3,
+            image_size: vision.image_size,
+            patch_size: vision.patch_size,
+        }
+    }
+
+    /// Maps the HF `projector_hidden_act` value to an `mm_projector_type`:
+    /// `"gelu"` becomes `"mlp2x_gelu"`, anything else is passed through.
+    fn map_projector_type(s: &str) -> String {
+        match s {
+            "gelu" => "mlp2x_gelu".to_string(),
+            other => other.to_string(),
+        }
+    }
+
+    /// Maps the HF `vision_feature_select_strategy` value to an
+    /// `mm_vision_select_feature`: `"default"` becomes `"patch"`, every other
+    /// value becomes `"cls_patch"`.
+    fn map_select_feature(s: &str) -> String {
+        match s {
+            "default" => "patch".to_string(),
+            _ => "cls_patch".to_string(),
+        }
+    }
+
+    /// Converts the HF-layout configuration into a [`LLaVAConfig`] with
+    /// `hf: true`.
+    ///
+    /// Token ids come from `generation_config`, the aspect-ratio setting from
+    /// `preprocessor_config`, and the language-model dimensions from
+    /// `text_config`. The projector input width (`mm_hidden_size`) is taken from
+    /// `vision_config.hidden_size` so it follows the configured vision tower
+    /// instead of being fixed at 1024. The crop/split resolutions (224), the
+    /// patch merge type (`"spatial_unpad"`) and the tokenizer maximum length
+    /// (4096) remain fixed values.
+    pub fn to_llava_config(
+        &self,
+        generation_config: &HFGenerationConfig,
+        preprocessor_config: &HFPreProcessorConfig,
+    ) -> LLaVAConfig {
+        let text = &self.text_config;
+        LLaVAConfig {
+            hf: true,
+            architectures: self.architectures.clone(),
+            bos_token_id: generation_config.bos_token_id,
+            eos_token_id: generation_config.eos_token_id,
+            hidden_size: text.hidden_size,
+            image_aspect_ratio: preprocessor_config.aspect_ratio_setting.clone(),
+            image_crop_resolution: 224,
+            image_grid_pinpoints: self.image_grid_pinpoints.clone(),
+            image_split_resolution: 224,
+            intermediate_size: text.intermediate_size,
+            max_position_embeddings: text.max_position_embeddings,
+            // The projector consumes the vision tower's hidden states.
+            mm_hidden_size: self.vision_config.hidden_size,
+            mm_patch_merge_type: "spatial_unpad".to_string(),
+            mm_projector_type: Self::map_projector_type(&self.projector_hidden_act),
+            mm_use_im_start_end: false,
+            mm_vision_select_feature: Self::map_select_feature(
+                &self.vision_feature_select_strategy,
+            ),
+            mm_vision_select_layer: self.vision_feature_layer,
+            mm_vision_tower: None,
+            model_type: self.model_type.clone(),
+            num_attention_heads: text.num_attention_heads,
+            num_hidden_layers: text.num_hidden_layers,
+            num_key_value_heads: text.num_key_value_heads,
+            pad_token_id: text.pad_token_id,
+            rms_norm_eps: text.rms_norm_eps,
+            rope_theta: text.rope_theta,
+            tokenizer_model_max_length: Some(4096),
+            torch_dtype: self.torch_dtype.clone(),
+            use_cache: text.use_cache,
+            vocab_size: self.vocab_size,
+            image_token_index: self.image_token_index,
+        }
+    }
+}
diff --git a/candle-transformers/src/models/llava/mod.rs b/candle-transformers/src/models/llava/mod.rs
new file mode 100644
index 00000000..caa8737a
--- /dev/null
+++ b/candle-transformers/src/models/llava/mod.rs
@@ -0,0 +1,407 @@
+pub mod config;
+pub mod utils;
+
+use crate::models::clip::vision_model::{ClipVisionConfig, ClipVisionTransformer};
+use crate::models::llama::{Cache, Llama};
+use crate::models::with_tracing::linear;
+
+use candle::{bail, Device, IndexOp, Result, Tensor};
+use candle_nn::{seq, Activation, Module, Sequential, VarBuilder};
+use fancy_regex::Regex;
+use utils::get_anyres_image_grid_shape;
+
+use config::LLaVAConfig;
+
+/// Parses a projector type of the form `mlp<N>x_gelu` and returns `N`.
+///
+/// Returns `None` when the string does not match that form, when matching
+/// fails, or when the digits do not fit in a `usize`.
+fn mlp_gelu_match(mm_projector_type: &str) -> Option<usize> {
+    let pattern = Regex::new(r"^mlp(\d+)x_gelu$").unwrap();
+    let captures = pattern.captures(mm_projector_type).ok().flatten()?;
+    let depth_digits = captures.get(1)?.as_str();
+    depth_digits.parse::<usize>().ok()
+}
+
+/// Crops the padding from a rank-3 feature tensor so that its last two
+/// dimensions (treated as height, then width) match the aspect ratio of
+/// `original_size`, given as `(width, height)`.
+///
+/// When the original image is relatively wider than the tensor, rows are
+/// removed symmetrically from top and bottom; otherwise columns are removed
+/// symmetrically from left and right.
+///
+/// # Errors
+/// Fails if `tensor` is not rank 3 or if the slicing operation fails.
+fn unpad_image(tensor: &Tensor, original_size: &(u32, u32)) -> Result<Tensor> {
+    let (_, current_height, current_width) = tensor.dims3()?;
+    let (original_width, original_height) = *original_size;
+    let original_aspect_ratio = (original_width as f32) / (original_height as f32);
+    let current_aspect_ratio = (current_width as f32) / (current_height as f32);
+    if original_aspect_ratio > current_aspect_ratio {
+        let scale_factor = (current_width as f32) / (original_width as f32);
+        let new_height = (original_height as f32 * scale_factor).floor() as usize;
+        // Saturate so that float rounding can never underflow the subtraction.
+        let padding = current_height.saturating_sub(new_height) / 2;
+        // The height axis must be bounded by the height, not the width.
+        tensor.i((.., padding..current_height - padding, ..))
+    } else {
+        let scale_factor = (current_height as f32) / (original_height as f32);
+        let new_width = (original_width as f32 * scale_factor).floor() as usize;
+        let padding = current_width.saturating_sub(new_width) / 2;
+        tensor.i((.., .., padding..current_width - padding))
+    }
+}
+
+/// Module that returns its input unchanged; used for the `"identity"`
+/// projector type in `MMProjector::load`.
+pub struct IdentityMap {}
+
+impl Module for IdentityMap {
+ /// Returns a clone of `x`.
+ fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ Ok(x.clone())
+ }
+}
+
+/// Multimodal projector: a sequence of modules that `LLaVA::encode_images`
+/// applies to the output of the vision tower.
+pub struct MMProjector {
+ pub modules: Sequential,
+}
+
+impl MMProjector {
+    /// Builds the projector selected by `config.mm_projector_type`.
+    ///
+    /// * `"linear"`: a single linear layer.
+    /// * `"mlp<N>x_gelu"`: `N` linear layers with a GELU between consecutive
+    ///   layers.
+    /// * `"identity"`: passes the input through unchanged.
+    ///
+    /// Weight names depend on `config.hf`: `multi_modal_projector.linear_<k>`
+    /// (1-based) when it is set, `model.mm_projector.<2k>` (0-based) otherwise.
+    ///
+    /// # Errors
+    /// Fails on an unknown projector type or when a weight cannot be loaded.
+    pub fn load(vb: &VarBuilder, config: &LLaVAConfig) -> Result<Self> {
+        let (input_dim, output_dim) = (config.mm_hidden_size, config.hidden_size);
+        // Name of the `layer_idx`-th (0-based) linear layer in the checkpoint.
+        let layer_name = |layer_idx: usize| {
+            if config.hf {
+                format!("multi_modal_projector.linear_{}", layer_idx + 1)
+            } else {
+                format!("model.mm_projector.{}", layer_idx * 2)
+            }
+        };
+
+        if config.mm_projector_type == "linear" {
+            let projection = linear(input_dim, output_dim, vb.pp(layer_name(0)))?;
+            return Ok(Self {
+                modules: seq().add(projection),
+            });
+        }
+
+        if let Some(mlp_depth) = mlp_gelu_match(&config.mm_projector_type) {
+            let mut modules = seq().add(linear(input_dim, output_dim, vb.pp(layer_name(0)))?);
+            for layer_idx in 1..mlp_depth {
+                modules = modules.add(Activation::Gelu).add(linear(
+                    output_dim,
+                    output_dim,
+                    vb.pp(layer_name(layer_idx)),
+                )?);
+            }
+            return Ok(Self { modules });
+        }
+
+        if config.mm_projector_type == "identity" {
+            return Ok(Self {
+                modules: seq().add(IdentityMap {}),
+            });
+        }
+
+        bail!(
+            "Unsupported MM projector type: {}",
+            config.mm_projector_type
+        )
+    }
+
+    /// Applies the projector modules to `x` in order.
+    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        self.modules.forward(x)
+    }
+}
+
+/// CLIP vision transformer together with the rule for picking one of its
+/// hidden states as the image feature.
+pub struct ClipVisionTower {
+ model: ClipVisionTransformer,
+ // Negative offset from the end of the hidden-state list; `new` accepts only -1 or -2.
+ select_layer: isize,
+ // "cls_patch" keeps the selected state as is; any other value drops index 0 of dimension 1.
+ select_feature_method: String,
+ pub config: ClipVisionConfig,
+}
+
+impl ClipVisionTower {
+    /// Creates the vision tower.
+    ///
+    /// When `config` is `None`, `ClipVisionConfig::clip_vit_large_patch14_336()`
+    /// is used. `select_layer` must be `-1` or `-2`.
+    ///
+    /// # Errors
+    /// Fails on an unsupported `select_layer` or when the model weights cannot
+    /// be loaded.
+    pub fn new(
+        vb: VarBuilder,
+        select_layer: isize,
+        select_feature_method: &str,
+        config: &Option<ClipVisionConfig>,
+    ) -> Result<Self> {
+        let config = config
+            .clone()
+            .unwrap_or_else(ClipVisionConfig::clip_vit_large_patch14_336);
+        if !matches!(select_layer, -1 | -2) {
+            bail!("Unsupported select layer: {}", select_layer)
+        }
+        let model = ClipVisionTransformer::new(vb, &config)?;
+        Ok(Self {
+            model,
+            select_layer,
+            select_feature_method: select_feature_method.to_string(),
+            config,
+        })
+    }
+
+    /// Runs the vision model on `x` and returns the hidden state chosen by
+    /// `select_layer`, counted from the end of the list returned by
+    /// `output_hidden_states`.
+    ///
+    /// With the `"cls_patch"` feature method the state is returned unchanged;
+    /// otherwise index 0 along dimension 1 is dropped.
+    ///
+    /// NOTE(review): which entry `-1` / `-2` denotes depends on the layout of the
+    /// list returned by `ClipVisionTransformer::output_hidden_states`; confirm
+    /// that it matches the reference implementation's indexing.
+    ///
+    /// # Errors
+    /// Fails if the model fails or if the list is too short for `select_layer`
+    /// (previously an out-of-bounds panic).
+    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let hidden_states = self.model.output_hidden_states(x)?;
+        let selected = hidden_states
+            .len()
+            .checked_add_signed(self.select_layer)
+            .and_then(|index| hidden_states.get(index));
+        let selected = match selected {
+            Some(state) => state,
+            None => bail!(
+                "select layer {} is out of range for {} hidden states",
+                self.select_layer,
+                hidden_states.len()
+            ),
+        };
+        if self.select_feature_method == "cls_patch" {
+            Ok(selected.clone())
+        } else {
+            selected.i((.., 1..))
+        }
+    }
+
+    /// Number of patches along one side of the image:
+    /// `image_size / patch_size` (integer division).
+    pub fn num_patches_per_side(&self) -> usize {
+        self.config.image_size / self.config.patch_size
+    }
+}
+
+/// LLaVA model: a CLIP vision tower, a multimodal projector and a llama
+/// language model, plus the `image_newline` tensor used when merging image
+/// patches.
+pub struct LLaVA {
+ pub clip_vision_tower: ClipVisionTower,
+ // Loaded with shape `[hidden_size]` in `load`.
+ pub image_newline: Tensor,
+ pub mm_projector: MMProjector,
+ pub llama: Llama,
+ config: LLaVAConfig,
+ // Device of the `VarBuilder` passed to `load`.
+ device: Device,
+}
+
+impl LLaVA {
+    /// Loads the vision tower, the `image_newline` tensor, the projector and the
+    /// language model from `vb`.
+    ///
+    /// Weight prefixes depend on `config.hf`. `clip_vision_config` is forwarded
+    /// to [`ClipVisionTower::new`], which falls back to a preset when it is
+    /// `None`.
+    ///
+    /// # Errors
+    /// Fails when any component cannot be loaded.
+    pub fn load(
+        vb: VarBuilder,
+        config: &LLaVAConfig,
+        clip_vision_config: Option<ClipVisionConfig>,
+    ) -> Result<Self> {
+        let device = vb.device().clone();
+        let llama_config = config.to_llama_config();
+        let mm_projector = MMProjector::load(&vb, config)?;
+        let (vision_prefix, newline_name) = if config.hf {
+            ("vision_tower.vision_model", "image_newline")
+        } else {
+            (
+                "model.vision_tower.vision_tower.vision_model",
+                "model.image_newline",
+            )
+        };
+        let clip_vision_tower = ClipVisionTower::new(
+            vb.pp(vision_prefix),
+            config.mm_vision_select_layer,
+            &config.mm_vision_select_feature,
+            &clip_vision_config,
+        )?;
+        let image_newline = vb
+            .get(&[config.hidden_size], newline_name)?
+            .to_device(&device)?;
+        // The language model lives under a prefix only in the HF layout.
+        let llama = if config.hf {
+            Llama::load(vb.pp("language_model"), &llama_config)?
+        } else {
+            Llama::load(vb, &llama_config)?
+        };
+        Ok(Self {
+            clip_vision_tower,
+            image_newline,
+            mm_projector,
+            llama,
+            config: config.clone(),
+            device,
+        })
+    }
+
+    /// Runs the vision tower on `x` and projects the result with the multimodal
+    /// projector.
+    pub fn encode_images(&self, x: &Tensor) -> Result<Tensor> {
+        let image_features = self.clip_vision_tower.forward(x)?;
+        self.mm_projector.forward(&image_features)
+    }
+
+    /// Builds the input embeddings for the language model by replacing every
+    /// image token in `input_ids` with the features of the corresponding image.
+    ///
+    /// * `input_ids`: token ids readable as `i64`, with a leading dimension of
+    ///   size 1 that is squeezed away.
+    /// * `images`: one tensor per image; they are concatenated along dimension 0
+    ///   for encoding and split back afterwards.
+    /// * `image_sizes`: original `(width, height)` of each image, used by the
+    ///   `"spatial"` merge types.
+    ///
+    /// If `input_ids` contains no image token, the plain token embeddings are
+    /// returned. Otherwise the text embeddings and image features are
+    /// interleaved, truncated to `tokenizer_model_max_length` when set, and
+    /// returned with a leading dimension of size 1.
+    ///
+    /// # Errors
+    /// Fails on any tensor error, on an unknown `mm_patch_merge_type`, on an
+    /// `image_aspect_ratio` other than `"anyres"` for multi-patch spatial
+    /// merging, on a missing entry in `image_sizes`, on a patch-count mismatch,
+    /// or when the number of images differs from the number of image tokens.
+    // currently only exercised with a single image, 4 dim tensor
+    pub fn prepare_inputs_labels_for_multimodal(
+        &self,
+        input_ids: &Tensor,
+        images: &[Tensor],
+        image_sizes: &[(u32, u32)],
+    ) -> Result<Tensor> {
+        //TODO: process of multiple images/ new line
+        // 576: 336(input size)/14(patch size)=24 24*24+1(class)=577 577-1=576
+        let concat_images = Tensor::cat(images, 0)?;
+        let image_features_together = self.encode_images(&concat_images)?;
+        // Split the jointly encoded features back into one tensor per image.
+        // can be replaced by split
+        let mut image_features = Vec::with_capacity(images.len());
+        let mut index_pos = 0;
+        for image in images {
+            let split_size = image.dim(0)?;
+            image_features.push(image_features_together.i(index_pos..index_pos + split_size)?);
+            index_pos += split_size;
+        }
+        let mm_patch_merge_type = self.config.mm_patch_merge_type.as_str();
+        let image_aspect_ratio = self.config.image_aspect_ratio.as_str();
+
+        let image_features = if mm_patch_merge_type == "flat" {
+            image_features
+                .iter()
+                .map(|x| x.flatten(0, 1))
+                .collect::<Result<Vec<Tensor>>>()?
+        } else if mm_patch_merge_type.starts_with("spatial") {
+            let unpad = mm_patch_merge_type.contains("unpad");
+            let mut new_image_features = Vec::with_capacity(image_features.len());
+            for (image_idx, image_feature) in image_features.iter().enumerate() {
+                let new_image_feature = if image_feature.dim(0)? > 1 {
+                    // First entry is the base image, the rest are the grid patches.
+                    let base_image_feature = image_feature.get(0)?;
+                    let patch_image_feature = image_feature.i(1..)?;
+                    let height = self.clip_vision_tower.num_patches_per_side();
+                    let width = height;
+                    let base_len = base_image_feature.dim(0)?;
+                    if height * width != base_len {
+                        bail!(
+                            "expected {} features per image patch, got {base_len}",
+                            height * width
+                        )
+                    }
+                    let image_size = match image_sizes.get(image_idx) {
+                        Some(size) => *size,
+                        None => bail!("missing image size for image {image_idx}"),
+                    };
+                    if image_aspect_ratio != "anyres" {
+                        // not implemented in original python LLaVA yet
+                        bail!(
+                            "image_aspect_ratio {image_aspect_ratio} is not supported with spatial patch merging"
+                        )
+                    }
+                    let (num_patch_width, num_patch_height) = get_anyres_image_grid_shape(
+                        image_size,
+                        &self.config.image_grid_pinpoints,
+                        self.clip_vision_tower.config.image_size as u32,
+                    );
+                    let new_image_feature = patch_image_feature.reshape((
+                        num_patch_height as usize,
+                        num_patch_width as usize,
+                        height,
+                        width,
+                        (),
+                    ))?;
+                    let new_image_feature = if unpad {
+                        let new_image_feature = new_image_feature
+                            .permute((4, 0, 2, 1, 3))?
+                            .flatten(1, 2)?
+                            .flatten(2, 3)?;
+                        let new_image_feature = unpad_image(&new_image_feature, &image_size)?;
+                        let (channels, rows, _) = new_image_feature.dims3()?;
+                        // One newline embedding is appended at the end of every row.
+                        let image_new_line = self
+                            .image_newline
+                            .reshape((self.config.hidden_size, 1, 1))?
+                            .broadcast_as((channels, rows, 1))?;
+                        let new_image_feature =
+                            Tensor::cat(&[new_image_feature, image_new_line], 2)?;
+                        new_image_feature.flatten(1, 2)?.transpose(0, 1)?
+                    } else {
+                        new_image_feature.permute((0, 2, 1, 3, 4))?.flatten(0, 3)?
+                    };
+                    Tensor::cat(&[base_image_feature, new_image_feature], 0)?
+                } else {
+                    let new_image_feature = image_feature.get(0)?;
+                    if unpad {
+                        Tensor::cat(&[new_image_feature, self.image_newline.unsqueeze(0)?], 0)?
+                    } else {
+                        new_image_feature
+                    }
+                };
+                new_image_features.push(new_image_feature);
+            }
+            new_image_features
+        } else {
+            bail!("Unexpected mm_patch_merge_type: {mm_patch_merge_type}")
+        };
+
+        // can easily be replaced by nonzero if it is implemented in candle
+        let input_ids_vec = input_ids.squeeze(0)?.to_vec1::<i64>()?;
+        let image_token = self.config.image_token_index as i64;
+        // `boundaries` holds the split points of the text segments, expressed in
+        // the sequence that has the image tokens removed. When an image token is
+        // met, the number of text tokens collected so far is exactly where the
+        // preceding text segment ends, which stays correct for several images.
+        let mut boundaries = vec![0_usize];
+        let mut input_ids_noim = Vec::with_capacity(input_ids_vec.len());
+        for &token in input_ids_vec.iter() {
+            if token == image_token {
+                boundaries.push(input_ids_noim.len());
+            } else {
+                input_ids_noim.push(token);
+            }
+        }
+        let image_token_count = boundaries.len() - 1;
+        if image_token_count == 0 {
+            // no image token, plain text embeddings
+            return self.llama.embed(input_ids);
+        }
+        if image_features.len() != image_token_count {
+            bail!(
+                "got {} images for {image_token_count} image tokens",
+                image_features.len()
+            )
+        }
+        let input_ids_noim_len = input_ids_noim.len();
+        boundaries.push(input_ids_noim_len);
+        let input_ids_noim = Tensor::from_vec(input_ids_noim, input_ids_noim_len, &self.device)?;
+        let cur_input_embeds = self.llama.embed(&input_ids_noim)?;
+        // can be replaced by split if it is implemented in candle
+        let input_embed_no_ims = boundaries
+            .windows(2)
+            .map(|bounds| cur_input_embeds.i((bounds[0]..bounds[1], ..)))
+            .collect::<Result<Vec<Tensor>>>()?;
+
+        // Interleave: text, image, text, image, ..., text.
+        let mut cur_new_input_embeds = Vec::with_capacity(2 * image_features.len() + 1);
+        let mut text_segments = input_embed_no_ims.into_iter();
+        for image_feature in image_features {
+            if let Some(text_segment) = text_segments.next() {
+                cur_new_input_embeds.push(text_segment);
+            }
+            cur_new_input_embeds.push(image_feature);
+        }
+        cur_new_input_embeds.extend(text_segments);
+        let new_input_embeds = Tensor::cat(&cur_new_input_embeds, 0)?;
+        // truncate
+        let (new_input_embeds_length, _) = new_input_embeds.dims2()?;
+        let new_input_embeds = match self.config.tokenizer_model_max_length {
+            Some(max_length) if new_input_embeds_length > max_length => {
+                new_input_embeds.i((..max_length, ..))?
+            }
+            _ => new_input_embeds,
+        };
+        new_input_embeds.unsqueeze(0)
+    }
+
+    /// Runs the language model on prepared input embeddings; see
+    /// `Llama::forward_input_embed`.
+    pub fn forward(
+        &self,
+        input_embeds: &Tensor,
+        position_id: usize,
+        cache: &mut Cache,
+    ) -> Result<Tensor> {
+        self.llama
+            .forward_input_embed(input_embeds, position_id, cache)
+    }
+}
diff --git a/candle-transformers/src/models/llava/utils.rs b/candle-transformers/src/models/llava/utils.rs
new file mode 100644
index 00000000..3b4c18bb
--- /dev/null
+++ b/candle-transformers/src/models/llava/utils.rs
@@ -0,0 +1,41 @@
+/// Returns how many `patch_size`-sized tiles fit along the width and the height
+/// of the resolution chosen by [`select_best_resolution`] for `image_size`.
+///
+/// The result is `(tiles_along_width, tiles_along_height)`, using integer
+/// division.
+pub fn get_anyres_image_grid_shape(
+    image_size: (u32, u32),
+    grid_pinpoints: &[(u32, u32)],
+    patch_size: u32,
+) -> (u32, u32) {
+    let best_resolution = select_best_resolution(image_size, grid_pinpoints);
+    let tiles_along_width = best_resolution.0 / patch_size;
+    let tiles_along_height = best_resolution.1 / patch_size;
+    (tiles_along_width, tiles_along_height)
+}
+
+/// Picks, among `possible_resolutions`, the `(width, height)` that best fits an
+/// image of `original_size` (also `(width, height)`).
+///
+/// For every candidate the image is scaled, preserving its aspect ratio, to fit
+/// inside the candidate. The candidate keeping the most image pixels wins; ties
+/// go to the candidate with the least unused area. Returns `(0, 0)` when
+/// `possible_resolutions` is empty.
+pub fn select_best_resolution(
+    original_size: (u32, u32),
+    possible_resolutions: &[(u32, u32)],
+) -> (u32, u32) {
+    let source_width = original_size.0 as f32;
+    let source_height = original_size.1 as f32;
+    let mut winner = (0, 0);
+    let mut best_effective = 0_u32;
+    let mut least_wasted = u32::MAX;
+    for &(candidate_width, candidate_height) in possible_resolutions.iter() {
+        // Largest scale at which the image still fits inside the candidate.
+        let scale = (candidate_width as f32 / source_width)
+            .min(candidate_height as f32 / source_height);
+        let scaled_width = (source_width * scale) as u32;
+        let scaled_height = (source_height * scale) as u32;
+        let candidate_area = candidate_width * candidate_height;
+        let effective = candidate_area.min(scaled_width * scaled_height);
+        let wasted = candidate_area - effective;
+        let improves = match effective.cmp(&best_effective) {
+            std::cmp::Ordering::Greater => true,
+            std::cmp::Ordering::Equal => wasted < least_wasted,
+            std::cmp::Ordering::Less => false,
+        };
+        if improves {
+            winner = (candidate_width, candidate_height);
+            best_effective = effective;
+            least_wasted = wasted;
+        }
+    }
+    winner
+}
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index de2430a2..4628a3de 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -17,6 +17,7 @@ pub mod jina_bert;
pub mod llama;
pub mod llama2_c;
pub mod llama2_c_weights;
+pub mod llava;
pub mod mamba;
pub mod marian;
pub mod metavoice;