1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
|
use candle::{Module, Result, Tensor};
use candle_nn::{linear, Linear, VarBuilder};
use super::vision_model;
use crate::models::mistral;
/// Top-level configuration for the vision–language [`Model`], deserialized with serde.
#[derive(serde::Deserialize, Debug, Clone)]
pub struct Config {
    /// Activation applied between the two linear layers of [`MultiModalProjector`].
    pub projector_hidden_act: candle_nn::Activation,
    /// Configuration of the Mistral language model. Its `hidden_size` is also the output
    /// size of the projector.
    pub text_config: mistral::Config,
    /// Configuration of the vision tower. Its `hidden_size` is the projector's input size,
    /// and its `patch_size` is copied into [`Model`] at construction.
    pub vision_config: vision_model::Config,
    /// Not read by `MultiModalProjector` or `Model` here.
    /// NOTE(review): presumably the token id marking image placeholders in the prompt — confirm
    /// against the caller.
    pub image_token_index: usize,
    /// Not read by `MultiModalProjector` or `Model` here.
    /// NOTE(review): exact meaning unconfirmed from this file — check the upstream model config.
    pub image_seq_length: usize,
}
/// Two-layer MLP (linear → activation → linear) that maps vision-tower features to the
/// language model's hidden size.
#[derive(Debug, Clone)]
pub struct MultiModalProjector {
    /// Maps the vision hidden size to the text hidden size.
    linear_1: Linear,
    /// Activation applied between `linear_1` and `linear_2`.
    act: candle_nn::Activation,
    /// Maps the text hidden size to the text hidden size.
    linear_2: Linear,
}
impl MultiModalProjector {
    /// Builds the projector from `vb`.
    ///
    /// Weights are read from the `linear_1` sub-path (vision hidden size → text hidden size)
    /// and then the `linear_2` sub-path (text hidden size → text hidden size). The activation
    /// comes from `cfg.projector_hidden_act`.
    ///
    /// # Errors
    /// Propagates any error raised while loading either linear layer.
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let vision_dim = cfg.vision_config.hidden_size;
        let text_dim = cfg.text_config.hidden_size;
        // Struct-literal fields are evaluated in the order written, so linear_1 loads first.
        Ok(Self {
            linear_1: linear(vision_dim, text_dim, vb.pp("linear_1"))?,
            act: cfg.projector_hidden_act,
            linear_2: linear(text_dim, text_dim, vb.pp("linear_2"))?,
        })
    }
}
impl Module for MultiModalProjector {
    /// Applies `linear_1`, then the configured activation, then `linear_2` to `xs`.
    ///
    /// # Errors
    /// Propagates the first error raised by any of the three stages.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let projected = self.linear_1.forward(xs)?;
        let activated = self.act.forward(&projected)?;
        self.linear_2.forward(&activated)
    }
}
/// Vision–language model: a vision tower, a projector into the text hidden size, and a
/// Mistral language model. It also tracks the decoding position across calls.
#[derive(Debug, Clone)]
pub struct Model {
    /// Projects vision-tower outputs to the language model's hidden size.
    /// `Model::new` loads it as `F32`.
    pub multi_modal_projector: MultiModalProjector,
    /// Mistral language model. `Model::new` loads it with the `VarBuilder`'s own dtype.
    pub language_model: mistral::Model,
    /// Vision encoder. `Model::new` loads it as `F32`.
    pub vision_tower: vision_model::Model,
    /// Patch size copied from `vision_config` at construction.
    pub patch_size: usize,
    /// Dtype of the `VarBuilder` given to `Model::new`, i.e. the dtype the language model was
    /// loaded with.
    pub dtype: candle::DType,
    /// Position offset passed to the language model on the next forward call. It is advanced
    /// by the sequence length after each successful `lm_forward`/`lm_forward_embeds` and reset
    /// to 0 by `clear_kv_cache`.
    pub pos: usize,
}
impl Model {
    /// Loads the full model from `vb`.
    ///
    /// Components are loaded in this order:
    /// 1. The language model, under `language_model`, with `vb`'s dtype.
    /// 2. The vision tower, under `vision_tower`, forced to `F32`.
    /// 3. The projector, under `multi_modal_projector`, forced to `F32`.
    ///
    /// # Errors
    /// Propagates the first loading error from any component.
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let language_model = mistral::Model::new(&cfg.text_config, vb.pp("language_model"))?;
        // Every vision-side component is materialised in f32, whatever vb's dtype is.
        let f32_under = |prefix: &str| vb.pp(prefix).to_dtype(candle::DType::F32);
        let vision_tower = vision_model::Model::new(&cfg.vision_config, f32_under("vision_tower"))?;
        let multi_modal_projector =
            MultiModalProjector::new(cfg, f32_under("multi_modal_projector"))?;
        Ok(Self {
            multi_modal_projector,
            language_model,
            vision_tower,
            patch_size: cfg.vision_config.patch_size,
            dtype: vb.dtype(),
            pos: 0,
        })
    }

    /// Resets generation state.
    ///
    /// Rewinds the position counter to 0 and clears the language model's KV cache.
    pub fn clear_kv_cache(&mut self) {
        self.pos = 0;
        self.language_model.clear_kv_cache();
    }

    /// Encodes `image` with the vision tower, then projects the features with the
    /// multi-modal projector.
    ///
    /// No dtype conversion happens here. `Model::new` loads both vision-side components as
    /// `F32`; casting the result to `self.dtype` is left to the caller.
    ///
    /// # Errors
    /// Propagates errors from the vision tower or the projector.
    pub fn encode_image(&self, image: &Tensor) -> Result<Tensor> {
        self.vision_tower
            .forward(image)
            .and_then(|features| self.multi_modal_projector.forward(&features))
    }

    /// Runs the language model on rank-2 token ids starting at the current position.
    ///
    /// On success, the position advances by the second dimension (`seq_len`) and the logits
    /// are returned.
    ///
    /// # Errors
    /// Fails if `input_ids` is not rank 2 or the language model errors. In either case the
    /// position is left unchanged.
    pub fn lm_forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
        let seq_len = input_ids.dims2()?.1;
        let start = self.pos;
        let logits = self.language_model.forward(input_ids, start)?;
        self.pos = start + seq_len;
        Ok(logits)
    }

    /// Runs the language model on rank-3 input embeddings starting at the current position.
    ///
    /// On success, the position advances by the second dimension (`seq_len`) and the logits
    /// are returned.
    ///
    /// # Errors
    /// Fails if `xs` is not rank 3 or the language model errors. In either case the position
    /// is left unchanged.
    pub fn lm_forward_embeds(&mut self, xs: &Tensor) -> Result<Tensor> {
        let seq_len = xs.dims3()?.1;
        let start = self.pos;
        let logits = self.language_model.forward_embeds(xs, None, start)?;
        self.pos = start + seq_len;
        Ok(logits)
    }
}
|