summaryrefslogtreecommitdiff
path: root/candle-transformers/src/models/colpali.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-transformers/src/models/colpali.rs')
-rw-r--r--candle-transformers/src/models/colpali.rs42
1 files changed, 42 insertions, 0 deletions
diff --git a/candle-transformers/src/models/colpali.rs b/candle-transformers/src/models/colpali.rs
new file mode 100644
index 00000000..1299b0a4
--- /dev/null
+++ b/candle-transformers/src/models/colpali.rs
@@ -0,0 +1,42 @@
+use candle::{Module, Result, Tensor};
+use candle_nn::VarBuilder;
+
+use super::paligemma;
+use candle_nn::{linear, Linear};
+
+pub struct Model {
+ pub model: paligemma::Model,
+ pub custom_text_projection: Linear,
+}
+
+impl Model {
+ pub fn new(config: &paligemma::Config, vb: VarBuilder) -> Result<Self> {
+ let model = paligemma::Model::new(config, vb.pp("model"))?;
+ let custom_text_projection = linear(
+ config.text_config.hidden_size,
+ 128,
+ vb.pp("custom_text_proj"),
+ )?;
+
+ Ok(Self {
+ model,
+ custom_text_projection,
+ })
+ }
+
+ pub fn forward_images(&mut self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<Tensor> {
+ let outputs = self
+ .model
+ .setup_without_projection(pixel_values, input_ids)?;
+ let outputs = self.custom_text_projection.forward(&outputs)?;
+ let outputs = outputs.broadcast_div(&outputs.sqr()?.sum_keepdim(2)?.sqrt()?)?;
+ Ok(outputs)
+ }
+
+ pub fn forward_text(&mut self, input_ids: &Tensor) -> Result<Tensor> {
+ let outputs = self.model.forward_without_projection(input_ids)?;
+ let outputs = self.custom_text_projection.forward(&outputs)?;
+ let outputs = outputs.broadcast_div(&outputs.sqr()?.sum_keepdim(2)?.sqrt()?)?;
+ Ok(outputs)
+ }
+}