diff options
Diffstat (limited to 'candle-transformers/src/models/gemma.rs')
-rw-r--r-- | candle-transformers/src/models/gemma.rs | 16 |
1 files changed, 15 insertions, 1 deletions
diff --git a/candle-transformers/src/models/gemma.rs b/candle-transformers/src/models/gemma.rs index 69e22678..c22a3948 100644 --- a/candle-transformers/src/models/gemma.rs +++ b/candle-transformers/src/models/gemma.rs @@ -403,7 +403,6 @@ impl Model { .apply(&self.norm)? .apply(&self.lm_head) } - pub fn forward_embeds( &mut self, xs: &Tensor, @@ -420,6 +419,21 @@ impl Model { .apply(&self.lm_head) } + // Forward the model and return the hidden states without the lm_head + pub fn forward_embeds_without_projection( + &mut self, + xs: &Tensor, + attn_mask: Option<&Tensor>, + seqlen_offset: usize, + ) -> Result<Tensor> { + let (_, _, _) = xs.dims3()?; + let mut xs = (xs * (self.hidden_size as f64).sqrt())?; + for layer in self.layers.iter_mut() { + xs = layer.forward(&xs, attn_mask, seqlen_offset)? + } + Ok(xs) + } + pub fn clear_kv_cache(&mut self) { for layer in self.layers.iter_mut() { layer.clear_kv_cache() |