diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-09-30 21:23:54 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-09-30 21:23:54 +0200 |
commit | dfe9a006834938a7d4dde6a6e3b81ed6e595bf99 (patch) | |
tree | 72a78f1f24cc3e362a657198b7cca1c90d738c51 /candle-transformers | |
parent | 683ab698def755c24cec9987069d25efcf831fc4 (diff) | |
download | candle-dfe9a006834938a7d4dde6a6e3b81ed6e595bf99.tar.gz candle-dfe9a006834938a7d4dde6a6e3b81ed6e595bf99.tar.bz2 candle-dfe9a006834938a7d4dde6a6e3b81ed6e595bf99.zip |
Pixtral polishing. (#2522)
* Pixtral polishing.
* Clippy fix.
Diffstat (limited to 'candle-transformers')
-rw-r--r-- | candle-transformers/src/models/pixtral/llava.rs | 26 |
1 files changed, 26 insertions, 0 deletions
diff --git a/candle-transformers/src/models/pixtral/llava.rs b/candle-transformers/src/models/pixtral/llava.rs
index 33e0aca0..4aff26a7 100644
--- a/candle-transformers/src/models/pixtral/llava.rs
+++ b/candle-transformers/src/models/pixtral/llava.rs
@@ -48,6 +48,7 @@ pub struct Model {
     pub vision_tower: vision_model::Model,
     pub patch_size: usize,
     pub dtype: candle::DType,
+    pub pos: usize,
 }
 
 impl Model {
@@ -67,6 +68,31 @@ impl Model {
             vision_tower,
             patch_size: cfg.vision_config.patch_size,
             dtype: vb.dtype(),
+            pos: 0,
         })
     }
+
+    pub fn clear_kv_cache(&mut self) {
+        self.language_model.clear_kv_cache();
+        self.pos = 0;
+    }
+
+    pub fn encode_image(&self, image: &Tensor) -> Result<Tensor> {
+        let image_embeds = self.vision_tower.forward(image)?;
+        self.multi_modal_projector.forward(&image_embeds)
+    }
+
+    pub fn lm_forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
+        let (_, seq_len) = input_ids.dims2()?;
+        let logits = self.language_model.forward(input_ids, self.pos)?;
+        self.pos += seq_len;
+        Ok(logits)
+    }
+
+    pub fn lm_forward_embeds(&mut self, xs: &Tensor) -> Result<Tensor> {
+        let (_, seq_len, _) = xs.dims3()?;
+        let logits = self.language_model.forward_embeds(xs, None, self.pos)?;
+        self.pos += seq_len;
+        Ok(logits)
+    }
 }