Diffstat (limited to 'candle-transformers/src/models/gemma.rs')
-rw-r--r--  candle-transformers/src/models/gemma.rs | 16
1 file changed, 15 insertions(+), 1 deletion(-)
diff --git a/candle-transformers/src/models/gemma.rs b/candle-transformers/src/models/gemma.rs
index 69e22678..c22a3948 100644
--- a/candle-transformers/src/models/gemma.rs
+++ b/candle-transformers/src/models/gemma.rs
@@ -403,7 +403,6 @@ impl Model {
             .apply(&self.norm)?
             .apply(&self.lm_head)
     }
-
     pub fn forward_embeds(
         &mut self,
         xs: &Tensor,
@@ -420,6 +419,21 @@ impl Model {
             .apply(&self.lm_head)
     }
 
+    // Forward the model and return the hidden states without the lm_head
+    pub fn forward_embeds_without_projection(
+        &mut self,
+        xs: &Tensor,
+        attn_mask: Option<&Tensor>,
+        seqlen_offset: usize,
+    ) -> Result<Tensor> {
+        let (_, _, _) = xs.dims3()?;
+        let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
+        for layer in self.layers.iter_mut() {
+            xs = layer.forward(&xs, attn_mask, seqlen_offset)?
+        }
+        Ok(xs)
+    }
+
     pub fn clear_kv_cache(&mut self) {
         for layer in self.layers.iter_mut() {
             layer.clear_kv_cache()
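
For context, a minimal usage sketch of the new method, assuming a Model that has already been loaded and an embeds tensor of precomputed token embeddings with shape (batch, seq_len, hidden_size); the helper name and the mean-pooling step are illustrative assumptions and not part of this commit. Note that, as added here, the method skips the final norm as well as the lm_head.

use candle_core::{Result, Tensor};
use candle_transformers::models::gemma::Model;

// Hypothetical helper (not part of the diff): run the decoder layers on
// precomputed token embeddings and mean-pool the unprojected hidden states
// into one vector per batch element.
fn pooled_hidden_states(model: &mut Model, embeds: &Tensor) -> Result<Tensor> {
    // No attention mask and a seqlen_offset of 0 (empty KV cache).
    let hidden = model.forward_embeds_without_projection(embeds, None, 0)?;
    // hidden has shape (batch, seq_len, hidden_size); average over seq_len.
    hidden.mean(1)
}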