summaryrefslogtreecommitdiff
path: root/candle-transformers/src/models/pixtral/llava.rs
blob: 4aff26a784cdade9adf1a9d7087fabe970da3f20 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
use candle::{Module, Result, Tensor};
use candle_nn::{linear, Linear, VarBuilder};

use super::vision_model;
use crate::models::mistral;

/// Top-level configuration for the Pixtral/LLaVA-style model, deserialized with serde.
///
/// Combines the text (Mistral) config, the vision-tower config, and the
/// settings for the projector that maps vision features into the text model.
#[derive(serde::Deserialize, Debug, Clone)]
pub struct Config {
    /// Activation applied between the projector's two linear layers.
    pub projector_hidden_act: candle_nn::Activation,
    /// Configuration used to build the Mistral language model; its `hidden_size`
    /// is also the projector's output width.
    pub text_config: mistral::Config,
    /// Configuration used to build the vision tower; its `hidden_size` is the
    /// projector's input width and its `patch_size` is copied onto `Model`.
    pub vision_config: vision_model::Config,
    /// Not read in this file. Presumably the token id that marks image positions
    /// in the prompt; TODO confirm against callers.
    pub image_token_index: usize,
    /// Not read in this file. Presumably the number of sequence positions an
    /// image occupies; TODO confirm against callers.
    pub image_seq_length: usize,
}

/// Two-layer MLP that maps vision-tower features into the language model's hidden size.
///
/// Computes `linear_2(act(linear_1(x)))`.
#[derive(Debug, Clone)]
pub struct MultiModalProjector {
    /// Projects from the vision hidden size to the text hidden size.
    linear_1: Linear,
    /// Activation taken from `Config::projector_hidden_act`.
    act: candle_nn::Activation,
    /// Text hidden size to text hidden size.
    linear_2: Linear,
}

impl MultiModalProjector {
    /// Builds the projector from `cfg`, loading weights under the `linear_1` and
    /// `linear_2` prefixes of `vb`.
    ///
    /// # Errors
    /// Propagates any error from loading either linear layer's weights.
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let vision_dim = cfg.vision_config.hidden_size;
        let text_dim = cfg.text_config.hidden_size;
        // Struct-literal fields evaluate top to bottom, so linear_1 loads before linear_2.
        Ok(Self {
            linear_1: linear(vision_dim, text_dim, vb.pp("linear_1"))?,
            act: cfg.projector_hidden_act,
            linear_2: linear(text_dim, text_dim, vb.pp("linear_2"))?,
        })
    }
}

impl Module for MultiModalProjector {
    /// Applies `linear_1`, then the configured activation, then `linear_2`.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let projected = self.linear_1.forward(xs)?;
        let activated = self.act.forward(&projected)?;
        self.linear_2.forward(&activated)
    }
}

/// Vision-language model: a vision tower, a multi-modal projector, and a Mistral
/// language model, plus the running position used when calling the language model.
#[derive(Debug, Clone)]
pub struct Model {
    /// Maps vision-tower outputs to the language model's hidden size (weights loaded as F32).
    pub multi_modal_projector: MultiModalProjector,
    /// Mistral text decoder, loaded with the `VarBuilder`'s dtype.
    pub language_model: mistral::Model,
    /// Image encoder (weights loaded as F32).
    pub vision_tower: vision_model::Model,
    /// Copied from `vision_config.patch_size`; not used within this file.
    pub patch_size: usize,
    /// Dtype of the `VarBuilder` passed to `Model::new`.
    pub dtype: candle::DType,
    /// Offset passed to the language model on the next forward call; advanced by the
    /// processed sequence length after each successful forward, reset to 0 by
    /// `clear_kv_cache`.
    pub pos: usize,
}

impl Model {
    /// Loads the language model, vision tower and projector from `vb`.
    ///
    /// The language model uses `vb`'s dtype, while the vision tower and projector
    /// are always loaded as `F32`. `dtype` records `vb`'s dtype and `pos` starts at 0.
    ///
    /// # Errors
    /// Propagates any weight-loading error from the three sub-models.
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let f32_dtype = candle::DType::F32;

        let language_model = mistral::Model::new(&cfg.text_config, vb.pp("language_model"))?;

        let tower_vb = vb.pp("vision_tower").to_dtype(f32_dtype);
        let vision_tower = vision_model::Model::new(&cfg.vision_config, tower_vb)?;

        let projector_vb = vb.pp("multi_modal_projector").to_dtype(f32_dtype);
        let multi_modal_projector = MultiModalProjector::new(cfg, projector_vb)?;

        let dtype = vb.dtype();
        let patch_size = cfg.vision_config.patch_size;
        Ok(Self {
            multi_modal_projector,
            language_model,
            vision_tower,
            patch_size,
            dtype,
            pos: 0,
        })
    }

    /// Clears the language model's KV cache and rewinds the position counter to 0.
    pub fn clear_kv_cache(&mut self) {
        self.pos = 0;
        self.language_model.clear_kv_cache();
    }

    /// Encodes `image` with the vision tower and projects the result with the
    /// multi-modal projector.
    ///
    /// NOTE(review): both sub-models hold F32 weights, so the output is presumably
    /// F32 and may need casting to `self.dtype` before use with the language model.
    pub fn encode_image(&self, image: &Tensor) -> Result<Tensor> {
        let features = self.vision_tower.forward(image)?;
        let projected = self.multi_modal_projector.forward(&features)?;
        Ok(projected)
    }

    /// Runs the language model on rank-2 `input_ids` starting at `self.pos`, then
    /// advances `self.pos` by the size of the second dimension.
    ///
    /// # Errors
    /// Fails if `input_ids` is not rank 2 or the language model forward fails;
    /// `self.pos` is left unchanged in that case.
    pub fn lm_forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
        let seq_len = input_ids.dims2()?.1;
        let offset = self.pos;
        let logits = self.language_model.forward(input_ids, offset)?;
        self.pos = offset + seq_len;
        Ok(logits)
    }

    /// Runs the language model directly on rank-3 embeddings `xs` (no attention
    /// mask passed) starting at `self.pos`, then advances `self.pos` by the size
    /// of the second dimension.
    ///
    /// # Errors
    /// Fails if `xs` is not rank 3 or the language model forward fails;
    /// `self.pos` is left unchanged in that case.
    pub fn lm_forward_embeds(&mut self, xs: &Tensor) -> Result<Tensor> {
        let seq_len = xs.dims3()?.1;
        let offset = self.pos;
        let logits = self.language_model.forward_embeds(xs, None, offset)?;
        self.pos = offset + seq_len;
        Ok(logits)
    }
}