path: root/candle-transformers/src/models/colpali.rs
author     Akshay Ballal <61191840+akshayballal95@users.noreply.github.com>  2024-10-01 11:48:39 +0200
committer  GitHub <noreply@github.com>  2024-10-01 11:48:39 +0200
commit     888d886dd8d5cac2558064060c59a4b51b6aa530 (patch)
tree       7bf0848bc3211453b7e07b26edf5c108e45dc7cf  /candle-transformers/src/models/colpali.rs
parent     6110ad8d4ff8272bdd10687eae4edee59a07b517 (diff)
Add ColPali (#2524)

* add colpali
* cleanup
* fix clippy
Diffstat (limited to 'candle-transformers/src/models/colpali.rs')
-rw-r--r--  candle-transformers/src/models/colpali.rs  42
1 file changed, 42 insertions(+), 0 deletions(-)
diff --git a/candle-transformers/src/models/colpali.rs b/candle-transformers/src/models/colpali.rs
new file mode 100644
index 00000000..1299b0a4
--- /dev/null
+++ b/candle-transformers/src/models/colpali.rs
@@ -0,0 +1,42 @@
+use candle::{Module, Result, Tensor};
+use candle_nn::VarBuilder;
+
+use super::paligemma;
+use candle_nn::{linear, Linear};
+
+pub struct Model {
+ pub model: paligemma::Model,
+ pub custom_text_projection: Linear,
+}
+
+impl Model {
+ pub fn new(config: &paligemma::Config, vb: VarBuilder) -> Result<Self> {
+ let model = paligemma::Model::new(config, vb.pp("model"))?;
+ let custom_text_projection = linear(
+ config.text_config.hidden_size,
+ 128,
+ vb.pp("custom_text_proj"),
+ )?;
+
+ Ok(Self {
+ model,
+ custom_text_projection,
+ })
+ }
+
+ pub fn forward_images(&mut self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<Tensor> {
+ let outputs = self
+ .model
+ .setup_without_projection(pixel_values, input_ids)?;
+ let outputs = self.custom_text_projection.forward(&outputs)?;
+ let outputs = outputs.broadcast_div(&outputs.sqr()?.sum_keepdim(2)?.sqrt()?)?;
+ Ok(outputs)
+ }
+
+ pub fn forward_text(&mut self, input_ids: &Tensor) -> Result<Tensor> {
+ let outputs = self.model.forward_without_projection(input_ids)?;
+ let outputs = self.custom_text_projection.forward(&outputs)?;
+ let outputs = outputs.broadcast_div(&outputs.sqr()?.sum_keepdim(2)?.sqrt()?)?;
+ Ok(outputs)
+ }
+}
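
The new module wraps paligemma::Model, projects its final hidden states through the custom_text_proj linear layer to 128-dimensional per-token embeddings, and L2-normalizes them along the last dimension for both images and text. Below is a minimal usage sketch, not part of the commit: the weight file name, the config constructor, the sequence lengths, and the zero-filled placeholder inputs are all assumptions for illustration (a real pipeline would tokenize a prompt and preprocess actual image pixels).

use candle::{DType, Device, Result, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::{colpali, paligemma};

fn embed_example(device: &Device) -> Result<()> {
    // Assumed config constructor; check the paligemma module for the exact name/variant.
    let config = paligemma::Config::paligemma_3b_448();

    // Assumed weight file name; load whichever safetensors file holds the ColPali weights.
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["colpali.safetensors"], DType::F32, device)?
    };
    let mut model = colpali::Model::new(&config, vb)?;

    // Placeholder inputs for shape illustration only:
    // pixel_values is (batch, channels, height, width), input_ids is (batch, seq_len).
    let pixel_values = Tensor::zeros((1, 3, 448, 448), DType::F32, device)?;
    let image_input_ids = Tensor::zeros((1, 1030), DType::U32, device)?;
    let image_embeddings = model.forward_images(&pixel_values, &image_input_ids)?;

    let text_input_ids = Tensor::zeros((1, 16), DType::U32, device)?;
    let text_embeddings = model.forward_text(&text_input_ids)?;

    // Both outputs are (batch, seq_len, 128) with unit-norm rows; ColPali-style
    // late-interaction scoring compares query token embeddings against page token embeddings.
    println!(
        "image embeddings: {:?}, text embeddings: {:?}",
        image_embeddings.shape(),
        text_embeddings.shape()
    );
    Ok(())
}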