diff options
Diffstat (limited to 'candle-transformers')
-rw-r--r-- | candle-transformers/src/models/pixtral/llava.rs | 26 |
1 files changed, 26 insertions, 0 deletions
diff --git a/candle-transformers/src/models/pixtral/llava.rs b/candle-transformers/src/models/pixtral/llava.rs index 33e0aca0..4aff26a7 100644 --- a/candle-transformers/src/models/pixtral/llava.rs +++ b/candle-transformers/src/models/pixtral/llava.rs @@ -48,6 +48,7 @@ pub struct Model { pub vision_tower: vision_model::Model, pub patch_size: usize, pub dtype: candle::DType, + pub pos: usize, } impl Model { @@ -67,6 +68,31 @@ impl Model { vision_tower, patch_size: cfg.vision_config.patch_size, dtype: vb.dtype(), + pos: 0, }) } + + pub fn clear_kv_cache(&mut self) { + self.language_model.clear_kv_cache(); + self.pos = 0; + } + + pub fn encode_image(&self, image: &Tensor) -> Result<Tensor> { + let image_embeds = self.vision_tower.forward(image)?; + self.multi_modal_projector.forward(&image_embeds) + } + + pub fn lm_forward(&mut self, input_ids: &Tensor) -> Result<Tensor> { + let (_, seq_len) = input_ids.dims2()?; + let logits = self.language_model.forward(input_ids, self.pos)?; + self.pos += seq_len; + Ok(logits) + } + + pub fn lm_forward_embeds(&mut self, xs: &Tensor) -> Result<Tensor> { + let (_, seq_len, _) = xs.dims3()?; + let logits = self.language_model.forward_embeds(xs, None, self.pos)?; + self.pos += seq_len; + Ok(logits) + } } |