summaryrefslogtreecommitdiff
path: root/candle-transformers
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-09-30 21:23:54 +0200
committerGitHub <noreply@github.com>2024-09-30 21:23:54 +0200
commitdfe9a006834938a7d4dde6a6e3b81ed6e595bf99 (patch)
tree72a78f1f24cc3e362a657198b7cca1c90d738c51 /candle-transformers
parent683ab698def755c24cec9987069d25efcf831fc4 (diff)
downloadcandle-dfe9a006834938a7d4dde6a6e3b81ed6e595bf99.tar.gz
candle-dfe9a006834938a7d4dde6a6e3b81ed6e595bf99.tar.bz2
candle-dfe9a006834938a7d4dde6a6e3b81ed6e595bf99.zip
Pixtral polishing. (#2522)
* Pixtral polishing. * Clippy fix.
Diffstat (limited to 'candle-transformers')
-rw-r--r--candle-transformers/src/models/pixtral/llava.rs26
1 file changed, 26 insertions, 0 deletions
diff --git a/candle-transformers/src/models/pixtral/llava.rs b/candle-transformers/src/models/pixtral/llava.rs
index 33e0aca0..4aff26a7 100644
--- a/candle-transformers/src/models/pixtral/llava.rs
+++ b/candle-transformers/src/models/pixtral/llava.rs
@@ -48,6 +48,7 @@ pub struct Model {
pub vision_tower: vision_model::Model,
pub patch_size: usize,
pub dtype: candle::DType,
+ pub pos: usize,
}
impl Model {
@@ -67,6 +68,31 @@ impl Model {
vision_tower,
patch_size: cfg.vision_config.patch_size,
dtype: vb.dtype(),
+ pos: 0,
})
}
+
+ pub fn clear_kv_cache(&mut self) {
+ self.language_model.clear_kv_cache();
+ self.pos = 0;
+ }
+
+ pub fn encode_image(&self, image: &Tensor) -> Result<Tensor> {
+ let image_embeds = self.vision_tower.forward(image)?;
+ self.multi_modal_projector.forward(&image_embeds)
+ }
+
+ pub fn lm_forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
+ let (_, seq_len) = input_ids.dims2()?;
+ let logits = self.language_model.forward(input_ids, self.pos)?;
+ self.pos += seq_len;
+ Ok(logits)
+ }
+
+ pub fn lm_forward_embeds(&mut self, xs: &Tensor) -> Result<Tensor> {
+ let (_, seq_len, _) = xs.dims3()?;
+ let logits = self.language_model.forward_embeds(xs, None, self.pos)?;
+ self.pos += seq_len;
+ Ok(logits)
+ }
}