summaryrefslogtreecommitdiff
path: root/candle-transformers/src/models/llava/mod.rs
blob: c252dbed5695b352d7d82181a2f5ae5ed7f57de1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
//! The LLaVA (Large Language and Vision Assistant) model.
//!
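//! This provides the main model implementation combining a vision tower (CLIP) with
//! a language model (Llama) for multimodal capabilities. Features produced by the vision
//! tower are mapped into the language model's embedding space by a multimodal projector
//! whose weights are loaded from the checkpoint.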
//!
//! - 💻 [GH Link](https://github.com/haotian-liu/LLaVA/tree/main)
//! - 📝 [Paper](https://arxiv.org/abs/2304.08485) / Visual Instruction Tuning
//!

pub mod config;
pub mod utils;

use crate::models::clip::vision_model::{ClipVisionConfig, ClipVisionTransformer};
use crate::models::llama::{Cache, Llama};
use crate::models::with_tracing::linear;

use candle::{bail, Device, IndexOp, Result, Tensor};
use candle_nn::{seq, Activation, Module, Sequential, VarBuilder};
use fancy_regex::Regex;
use utils::get_anyres_image_grid_shape;

use config::LLaVAConfig;

/// Extracts the MLP depth from a projector type written as `mlp<N>x_gelu`.
///
/// Returns `Some(N)` when `mm_projector_type` matches `^mlp(\d+)x_gelu$` and the captured
/// digits parse as a `usize`. Returns `None` in every other case, including when the regex
/// engine reports an error while matching.
fn mlp_gelu_match(mm_projector_type: &str) -> Option<usize> {
    let pattern = Regex::new(r"^mlp(\d+)x_gelu$").unwrap();

    // `captures` yields `Result<Option<_>>`; both an engine error and a non-match map to `None`.
    let captures = pattern.captures(mm_projector_type).ok().flatten()?;
    let depth_digits = captures.get(1)?;
    depth_digits.as_str().parse::<usize>().ok()
}

/// Crops away the padding that was added to make an image fit the current feature grid.
///
/// `tensor` must have three dimensions; dim 1 is treated as the height and dim 2 as the
/// width of the padded grid. `original_size` is the `(width, height)` of the image before
/// padding. The aspect ratios of the original image and of the padded grid decide which
/// axis was padded, and the same amount is removed from both ends of that axis.
///
/// # Panics
///
/// Panics if `tensor` does not have exactly three dimensions.
///
/// # Errors
///
/// Returns any error raised by the tensor indexing operation.
fn unpad_image(tensor: &Tensor, original_size: &(u32, u32)) -> Result<Tensor> {
    assert_eq!(tensor.dims().len(), 3);
    let (original_width, original_height) = *original_size;
    let tensor_dims = tensor.dims();
    let current_height = tensor_dims[1];
    let current_width = tensor_dims[2];
    let original_aspect_ratio = (original_width as f32) / (original_height as f32);
    let current_aspect_ratio = (current_width as f32) / (current_height as f32);
    if original_aspect_ratio > current_aspect_ratio {
        // The original image is relatively wider than the grid, so the padding sits above
        // and below the content: crop along the height axis (dim 1).
        let scale_factor = (current_width as f32) / (original_width as f32);
        let new_height = (original_height as f32 * scale_factor).floor() as usize;
        let padding = (current_height - new_height) / 2;
        // The upper bound must be measured against the height; using the width here gives
        // a wrong (or out-of-range) crop whenever the grid is not square.
        tensor.i((.., padding..current_height - padding, ..))
    } else {
        // The padding sits left and right of the content: crop along the width axis (dim 2).
        let scale_factor = (current_height as f32) / (original_height as f32);
        let new_width = (original_width as f32 * scale_factor).floor() as usize;
        let padding = (current_width - new_width) / 2;
        tensor.i((.., .., padding..current_width - padding))
    }
}

pub struct IdentityMap {}

impl Module for IdentityMap {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        Ok(x.clone())
    }
}

/// Multimodal projector: the module stack applied to vision-tower features in
/// `LLaVA::encode_images`.
pub struct MMProjector {
    /// Layers executed in order by [`MMProjector::forward`].
    pub modules: Sequential,
}

impl MMProjector {
    /// Builds the projector described by `config.mm_projector_type`.
    ///
    /// Supported types:
    /// - `"linear"`: a single linear layer from `mm_hidden_size` to `hidden_size`.
    /// - `"mlp<N>x_gelu"`: that same first linear layer followed by `N - 1` pairs of GELU
    ///   and a `hidden_size` to `hidden_size` linear layer.
    /// - `"identity"`: a single [`IdentityMap`].
    ///
    /// Weight names depend on `config.hf`: `multi_modal_projector.linear_<k>` (1-based)
    /// when it is set, `model.mm_projector.<2k>` (0-based, even slots) otherwise.
    ///
    /// # Errors
    ///
    /// Fails when the projector type is not one of the above or when a weight cannot be
    /// loaded from `vb`.
    pub fn load(vb: &VarBuilder, config: &LLaVAConfig) -> Result<Self> {
        let projector_type: &str = &config.mm_projector_type;

        // Name of the weights backing the `layer_idx`-th linear layer (0-based).
        let layer_prefix = |layer_idx: usize| -> String {
            if config.hf {
                format!("multi_modal_projector.linear_{}", layer_idx + 1)
            } else {
                format!("model.mm_projector.{}", layer_idx * 2)
            }
        };
        // First layer, shared by the "linear" and "mlp<N>x_gelu" variants.
        let input_layer = || {
            linear(
                config.mm_hidden_size,
                config.hidden_size,
                vb.pp(layer_prefix(0)),
            )
        };

        let modules = match projector_type {
            "linear" => seq().add(input_layer()?),
            "identity" => seq().add(IdentityMap {}),
            other => match mlp_gelu_match(other) {
                Some(mlp_depth) => {
                    let mut stack = seq().add(input_layer()?);
                    for layer_idx in 1..mlp_depth {
                        let hidden_layer = linear(
                            config.hidden_size,
                            config.hidden_size,
                            vb.pp(layer_prefix(layer_idx)),
                        )?;
                        stack = stack.add(Activation::Gelu).add(hidden_layer);
                    }
                    stack
                }
                None => bail!(
                    "Unsupported MM projector type: {}",
                    config.mm_projector_type
                ),
            },
        };
        Ok(Self { modules })
    }

    /// Runs `x` through every layer of the projector, in order.
    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let Self { modules } = self;
        modules.forward(x)
    }
}

/// CLIP vision encoder together with the rule used to pick its output features.
pub struct ClipVisionTower {
    /// The underlying CLIP vision transformer.
    model: ClipVisionTransformer,
    /// Offset added to the number of hidden states to pick one of them; the constructor
    /// only accepts `-1` and `-2`.
    select_layer: isize,
    /// `"cls_patch"` keeps the selected hidden state as is; any other value drops the
    /// first entry along dim 1.
    select_feature_method: String,
    /// Configuration the vision transformer was built with.
    pub config: ClipVisionConfig,
}

impl ClipVisionTower {
    /// Creates the vision tower from the weights in `vb`.
    ///
    /// When `config` is `None`, `ClipVisionConfig::clip_vit_large_patch14_336()` is used.
    ///
    /// # Errors
    ///
    /// Fails when `select_layer` is neither `-1` nor `-2`, or when the vision transformer
    /// cannot be built.
    pub fn new(
        vb: VarBuilder,
        select_layer: isize,
        select_feature_method: &str,
        config: &Option<ClipVisionConfig>,
    ) -> Result<Self> {
        let config = match config {
            Some(custom) => custom.clone(),
            None => ClipVisionConfig::clip_vit_large_patch14_336(),
        };
        if !matches!(select_layer, -1 | -2) {
            bail!("Unsupported select layer: {}", select_layer);
        }
        let model = ClipVisionTransformer::new(vb, &config)?;
        Ok(Self {
            model,
            select_layer,
            select_feature_method: select_feature_method.to_owned(),
            config,
        })
    }

    /// Encodes `x` and returns the hidden state chosen by `select_layer`.
    ///
    /// With the `"cls_patch"` feature method the hidden state is returned whole; otherwise
    /// its first entry along dim 1 is dropped.
    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let hidden_states = self.model.output_hidden_states(x)?;
        // `select_layer` is negative, so this counts back from the end of the list.
        let layer_idx = (hidden_states.len() as isize + self.select_layer) as usize;
        let selected = &hidden_states[layer_idx];
        match self.select_feature_method.as_str() {
            "cls_patch" => Ok(selected.clone()),
            _ => selected.i((.., 1..)),
        }
    }

    /// Number of patches along one side of the image: `image_size / patch_size`.
    pub fn num_patches_per_side(&self) -> usize {
        let vision_config = &self.config;
        vision_config.image_size / vision_config.patch_size
    }
}

/// The complete LLaVA model: vision tower, multimodal projector and Llama language model.
pub struct LLaVA {
    /// Vision encoder producing image features.
    pub clip_vision_tower: ClipVisionTower,
    /// Tensor of `hidden_size` elements loaded from the checkpoint; appended to image
    /// features when the patch merge type contains `"unpad"`.
    pub image_newline: Tensor,
    /// Projector applied to the vision-tower output.
    pub mm_projector: MMProjector,
    /// Language model used for token embedding and for the forward pass.
    pub llama: Llama,
    /// Copy of the configuration the model was loaded with.
    config: LLaVAConfig,
    /// Device taken from the `VarBuilder` at load time.
    device: Device,
}

impl LLaVA {
    pub fn load(
        vb: VarBuilder,
        config: &LLaVAConfig,
        clip_vision_config: Option<ClipVisionConfig>,
    ) -> Result<Self> {
        let device = vb.device().clone();
        let llama_config = config.to_llama_config();
        let mm_projector = MMProjector::load(&vb, config)?;
        let (clip_vision_tower, image_newline, llama) = if config.hf {
            (
                ClipVisionTower::new(
                    vb.pp("vision_tower.vision_model"),
                    config.mm_vision_select_layer,
                    &config.mm_vision_select_feature,
                    &clip_vision_config,
                )?,
                vb.get(&[config.hidden_size], "image_newline")?
                    .to_device(&device)?,
                Llama::load(vb.pp("language_model"), &llama_config)?,
            )
        } else {
            (
                ClipVisionTower::new(
                    vb.pp("model.vision_tower.vision_tower.vision_model"),
                    config.mm_vision_select_layer,
                    &config.mm_vision_select_feature,
                    &clip_vision_config,
                )?,
                vb.get(&[config.hidden_size], "model.image_newline")?
                    .to_device(&device)?,
                Llama::load(vb, &llama_config)?,
            )
        };
        Ok(Self {
            clip_vision_tower,
            image_newline,
            mm_projector,
            llama,
            config: (*config).clone(),
            device,
        })
    }

    /// Runs `x` through the vision tower and then through the multimodal projector.
    pub fn encode_images(&self, x: &Tensor) -> Result<Tensor> {
        let vision_features = self.clip_vision_tower.forward(x)?;
        self.mm_projector.forward(&vision_features)
    }
    /// Builds the embedding sequence fed to the language model for a prompt with images.
    ///
    /// Each occurrence of `config.image_token_index` in `input_ids` is replaced, in order,
    /// by the projected features of the matching entry of `images`; all other tokens are
    /// embedded with the language model. If the prompt contains no image token, the plain
    /// token embedding is returned and the vision tower is not run.
    ///
    /// * `input_ids` - token ids; dim 0 is squeezed away and the values are read as `i64`
    ///   (presumably a `(1, seq_len)` tensor - TODO confirm against callers).
    /// * `images` - one tensor per image. The tensors are concatenated along dim 0 before
    ///   encoding, and dim 0 of each tensor is used to split the features back per image.
    /// * `image_sizes` - original `(width, height)` of each image; only consulted for
    ///   images with more than one entry along dim 0 in a `"spatial"` merge mode.
    ///
    /// The result has a leading dimension of size 1 and is cut to
    /// `config.tokenizer_model_max_length` rows when that limit is set.
    ///
    /// # Errors
    ///
    /// Fails when the number of image tokens differs from `images.len()`, when
    /// `image_sizes` has no entry for a multi-patch image, when the merge type or aspect
    /// ratio mode is unsupported, or when any tensor operation fails.
    pub fn prepare_inputs_labels_for_multimodal(
        &self,
        input_ids: &Tensor,
        images: &[Tensor],
        image_sizes: &[(u32, u32)],
    ) -> Result<Tensor> {
        // can easily be replaced by nonzero if it is implemented in candle
        let input_ids_vec = input_ids.squeeze(0)?.to_vec1::<i64>()?;
        let image_token_index = self.config.image_token_index as i64;
        let image_positions = input_ids_vec
            .iter()
            .enumerate()
            .filter(|(_, id)| **id == image_token_index)
            .map(|(position, _)| position)
            .collect::<Vec<usize>>();
        if image_positions.is_empty() {
            // Text-only prompt: nothing to splice in, so skip the vision tower entirely.
            return self.llama.embed(input_ids);
        }
        if image_positions.len() != images.len() {
            bail!(
                "prompt contains {} image token(s) but {} image(s) were provided",
                image_positions.len(),
                images.len()
            )
        }

        //TODO: process of multiple images/ new line
        // 576: 336(input size)/14(patch size)=24 24*24+1(class)=577 577-1=576
        let concat_images = Tensor::cat(images, 0)?;
        let image_features_together = self.encode_images(&concat_images)?;
        // can be replaced by split
        let mut index_pos = 0;
        let mut image_features = Vec::with_capacity(images.len());
        for image in images {
            let split_size = image.dims()[0];
            image_features.push(image_features_together.i(index_pos..index_pos + split_size)?);
            index_pos += split_size;
        }
        let mm_patch_merge_type = &self.config.mm_patch_merge_type;
        let image_aspect_ratio = &self.config.image_aspect_ratio;

        let image_features = if mm_patch_merge_type == "flat" {
            image_features
                .iter()
                .map(|x| x.flatten(0, 1))
                .collect::<Result<Vec<Tensor>>>()?
        } else if mm_patch_merge_type.starts_with("spatial") {
            let unpad = mm_patch_merge_type.contains("unpad");
            let mut new_image_features = Vec::with_capacity(image_features.len());
            for (image_idx, image_feature) in image_features.iter().enumerate() {
                let new_image_feature = if image_feature.dims()[0] > 1 {
                    // Entry 0 is the base image, the remaining entries are the grid patches.
                    let base_image_feature = image_feature.get(0)?;
                    let patch_image_feature = image_feature.i(1..)?;
                    let height = self.clip_vision_tower.num_patches_per_side();
                    let width = height;
                    if height * width != base_image_feature.dims()[0] {
                        bail!(
                            "image feature has {} tokens but the vision tower yields {}x{} patches",
                            base_image_feature.dims()[0],
                            height,
                            width
                        )
                    }
                    let image_size = match image_sizes.get(image_idx) {
                        Some(size) => *size,
                        None => bail!("missing image size for image {image_idx}"),
                    };
                    let new_image_feature = if image_aspect_ratio == "anyres" {
                        let (num_patch_width, num_patch_height) = get_anyres_image_grid_shape(
                            image_size,
                            &self.config.image_grid_pinpoints,
                            self.clip_vision_tower.config.image_size as u32,
                        );
                        patch_image_feature.reshape((
                            num_patch_height as usize,
                            num_patch_width as usize,
                            height,
                            width,
                            (),
                        ))?
                    } else {
                        bail!("not implemented in original python LLaVA yet")
                    };
                    let new_image_feature = if unpad {
                        // Bring the channel dim to the front and merge the patch grid into
                        // one (channel, rows, cols) map so the padding can be cropped.
                        let new_image_feature = new_image_feature
                            .permute((4, 0, 2, 1, 3))?
                            .flatten(1, 2)?
                            .flatten(2, 3)?;
                        let new_image_feature = unpad_image(&new_image_feature, &image_size)?;
                        let new_image_feature_dims = new_image_feature.dims();
                        // One newline embedding is appended at the end of every row.
                        let image_new_line = self
                            .image_newline
                            .reshape((self.config.hidden_size, 1, 1))?
                            .broadcast_as((
                                new_image_feature_dims[0],
                                new_image_feature_dims[1],
                                1,
                            ))?;
                        let new_image_feature =
                            Tensor::cat(&[new_image_feature, image_new_line], 2)?;
                        new_image_feature.flatten(1, 2)?.transpose(0, 1)?
                    } else {
                        new_image_feature.permute((0, 2, 1, 3, 4))?.flatten(0, 3)?
                    };
                    Tensor::cat(&[base_image_feature, new_image_feature], 0)?
                } else {
                    let new_image_feature = image_feature.get(0)?;
                    if unpad {
                        Tensor::cat(
                            &[new_image_feature, self.image_newline.unsqueeze(0)?],
                            0,
                        )?
                    } else {
                        new_image_feature
                    }
                };
                new_image_features.push(new_image_feature);
            }
            new_image_features
        } else {
            bail!("Unexpected mm_patch_merge_type: {mm_patch_merge_type}")
        };

        let input_ids_noim = input_ids_vec
            .iter()
            .copied()
            .filter(|id| *id != image_token_index)
            .collect::<Vec<i64>>();
        let input_ids_noim_len = input_ids_noim.len();
        let input_ids_noim = Tensor::from_vec(input_ids_noim, input_ids_noim_len, &self.device)?;
        let cur_input_embeds = self.llama.embed(&input_ids_noim)?;

        // Boundaries of the text segments inside the placeholder-free sequence. The k-th
        // placeholder sat at `position` in the prompt, but the k placeholders before it
        // have been removed, so its segment boundary is `position - k`.
        let mut boundaries = Vec::with_capacity(image_positions.len() + 2);
        boundaries.push(0);
        boundaries.extend(
            image_positions
                .iter()
                .enumerate()
                .map(|(k, position)| position - k),
        );
        boundaries.push(input_ids_noim_len);
        // can be replace by split if it is implemented in candle
        let input_embed_no_ims = boundaries
            .windows(2)
            .map(|bounds| cur_input_embeds.i((bounds[0]..bounds[1], ..)))
            .collect::<Result<Vec<Tensor>>>()?;

        // Interleave: text, image, text, image, ..., trailing text. There is exactly one
        // more text segment than there are images (checked above).
        let mut cur_new_input_embeds = Vec::with_capacity(2 * image_features.len() + 1);
        for (text_embed, image_feature) in input_embed_no_ims.iter().zip(image_features.iter()) {
            cur_new_input_embeds.push(text_embed.clone());
            cur_new_input_embeds.push(image_feature.clone());
        }
        if let Some(trailing_text) = input_embed_no_ims.last() {
            cur_new_input_embeds.push(trailing_text.clone());
        }
        let new_input_embeds = Tensor::cat(&cur_new_input_embeds, 0)?;
        // truncate
        let new_input_embeds = match self.config.tokenizer_model_max_length {
            Some(tokenizer_model_max_length) => {
                let (new_input_embeds_length, _) = new_input_embeds.shape().dims2()?;
                if new_input_embeds_length > tokenizer_model_max_length {
                    new_input_embeds.i((..tokenizer_model_max_length, ..))?
                } else {
                    new_input_embeds
                }
            }
            None => new_input_embeds,
        };
        new_input_embeds.unsqueeze(0)
    }

    /// Runs the language model on pre-computed input embeddings.
    ///
    /// All three arguments are passed unchanged to `Llama::forward_input_embed`.
    pub fn forward(
        &self,
        input_embeds: &Tensor,
        position_id: usize,
        cache: &mut Cache,
    ) -> Result<Tensor> {
        let language_model = &self.llama;
        language_model.forward_input_embed(input_embeds, position_id, cache)
    }
}