summaryrefslogtreecommitdiff
path: root/candle-transformers/src/models/colpali.rs
blob: 1299b0a410dd5f31c65fb4aca31e52ebf951e95f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
use candle::{Module, Result, Tensor};
use candle_nn::VarBuilder;

use super::paligemma;
use candle_nn::{linear, Linear};

/// A PaliGemma backbone paired with an extra linear projection head.
///
/// The two `forward_*` methods run the backbone, apply the projection to its
/// output, and L2-normalise the result along dimension 2.
pub struct Model {
    /// The wrapped PaliGemma model; its weights are loaded under the `model` prefix.
    pub model: paligemma::Model,
    /// Linear layer mapping `text_config.hidden_size` features to 128 features;
    /// its weights are loaded under the `custom_text_proj` prefix.
    pub custom_text_projection: Linear,
}

impl Model {
    /// Number of output features produced by the custom projection head.
    const EMBEDDING_DIM: usize = 128;

    /// Builds the model from a PaliGemma configuration and a weight source.
    ///
    /// The backbone is loaded from the `model` prefix of `vb` and the
    /// projection head from the `custom_text_proj` prefix.
    ///
    /// # Errors
    ///
    /// Returns an error if constructing the backbone or the linear layer fails.
    pub fn new(config: &paligemma::Config, vb: VarBuilder) -> Result<Self> {
        let model = paligemma::Model::new(config, vb.pp("model"))?;
        let custom_text_projection = linear(
            config.text_config.hidden_size,
            Self::EMBEDDING_DIM,
            vb.pp("custom_text_proj"),
        )?;

        Ok(Self {
            model,
            custom_text_projection,
        })
    }

    /// Computes normalised embeddings for an image/prompt pair.
    ///
    /// Feeds `pixel_values` and `input_ids` to the backbone's
    /// `setup_without_projection`, then projects and L2-normalises the output.
    ///
    /// # Errors
    ///
    /// Propagates any error from the backbone, the projection, or the
    /// normalisation ops.
    pub fn forward_images(&mut self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<Tensor> {
        let hidden_states = self
            .model
            .setup_without_projection(pixel_values, input_ids)?;
        self.project_and_normalize(&hidden_states)
    }

    /// Computes normalised embeddings for text-only input.
    ///
    /// Feeds `input_ids` to the backbone's `forward_without_projection`, then
    /// projects and L2-normalises the output.
    ///
    /// # Errors
    ///
    /// Propagates any error from the backbone, the projection, or the
    /// normalisation ops.
    pub fn forward_text(&mut self, input_ids: &Tensor) -> Result<Tensor> {
        let hidden_states = self.model.forward_without_projection(input_ids)?;
        self.project_and_normalize(&hidden_states)
    }

    /// Applies the projection head and divides by the L2 norm along dimension 2.
    ///
    /// Shared by `forward_images` and `forward_text` so both paths stay in sync.
    /// Assumes `hidden_states` has at least three dimensions with features on
    /// dimension 2 (presumably `(batch, seq, hidden)` -- confirm against the
    /// backbone). No epsilon is added, so an all-zero projected vector divides
    /// by zero.
    fn project_and_normalize(&self, hidden_states: &Tensor) -> Result<Tensor> {
        let projected = self.custom_text_projection.forward(hidden_states)?;
        // Keep the reduced dimension so the norm broadcasts against `projected`.
        let l2_norm = projected.sqr()?.sum_keepdim(2)?.sqrt()?;
        projected.broadcast_div(&l2_norm)
    }
}