diff options
author | Akshay Ballal <61191840+akshayballal95@users.noreply.github.com> | 2024-10-01 11:48:39 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-10-01 11:48:39 +0200 |
commit | 888d886dd8d5cac2558064060c59a4b51b6aa530 (patch) | |
tree | 7bf0848bc3211453b7e07b26edf5c108e45dc7cf /candle-transformers/src/models/colpali.rs | |
parent | 6110ad8d4ff8272bdd10687eae4edee59a07b517 (diff) | |
download | candle-888d886dd8d5cac2558064060c59a4b51b6aa530.tar.gz candle-888d886dd8d5cac2558064060c59a4b51b6aa530.tar.bz2 candle-888d886dd8d5cac2558064060c59a4b51b6aa530.zip |
Add ColPali (#2524)
* add colpali
* cleanup
* fix clippy
Diffstat (limited to 'candle-transformers/src/models/colpali.rs')
-rw-r--r-- | candle-transformers/src/models/colpali.rs | 42 |
1 files changed, 42 insertions, 0 deletions
diff --git a/candle-transformers/src/models/colpali.rs b/candle-transformers/src/models/colpali.rs new file mode 100644 index 00000000..1299b0a4 --- /dev/null +++ b/candle-transformers/src/models/colpali.rs @@ -0,0 +1,42 @@ +use candle::{Module, Result, Tensor}; +use candle_nn::VarBuilder; + +use super::paligemma; +use candle_nn::{linear, Linear}; + +pub struct Model { + pub model: paligemma::Model, + pub custom_text_projection: Linear, +} + +impl Model { + pub fn new(config: &paligemma::Config, vb: VarBuilder) -> Result<Self> { + let model = paligemma::Model::new(config, vb.pp("model"))?; + let custom_text_projection = linear( + config.text_config.hidden_size, + 128, + vb.pp("custom_text_proj"), + )?; + + Ok(Self { + model, + custom_text_projection, + }) + } + + pub fn forward_images(&mut self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<Tensor> { + let outputs = self + .model + .setup_without_projection(pixel_values, input_ids)?; + let outputs = self.custom_text_projection.forward(&outputs)?; + let outputs = outputs.broadcast_div(&outputs.sqr()?.sum_keepdim(2)?.sqrt()?)?; + Ok(outputs) + } + + pub fn forward_text(&mut self, input_ids: &Tensor) -> Result<Tensor> { + let outputs = self.model.forward_without_projection(input_ids)?; + let outputs = self.custom_text_projection.forward(&outputs)?; + let outputs = outputs.broadcast_div(&outputs.sqr()?.sum_keepdim(2)?.sqrt()?)?; + Ok(outputs) + } +} |