summaryrefslogtreecommitdiff
path: root/candle-transformers
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-03-23 13:08:53 +0100
committerGitHub <noreply@github.com>2024-03-23 13:08:53 +0100
commit6f877592a7d5b5023462e0b8d241a2ba5ad83648 (patch)
tree67d0b787d25295bd67846d3425eb76df9aa60abb /candle-transformers
parentcc856db9ce2541e09731165f88cdd7aae37f558e (diff)
downloadcandle-6f877592a7d5b5023462e0b8d241a2ba5ad83648.tar.gz
candle-6f877592a7d5b5023462e0b8d241a2ba5ad83648.tar.bz2
candle-6f877592a7d5b5023462e0b8d241a2ba5ad83648.zip
Avoid broadcasting on the batch dimension for the attention mask. (#1920)
Diffstat (limited to 'candle-transformers')
-rw-r--r--candle-transformers/src/models/mistral.rs7
-rw-r--r--candle-transformers/src/models/quantized_mistral.rs7
2 files changed, 6 insertions, 8 deletions
diff --git a/candle-transformers/src/models/mistral.rs b/candle-transformers/src/models/mistral.rs
index be84f824..e40ae3ad 100644
--- a/candle-transformers/src/models/mistral.rs
+++ b/candle-transformers/src/models/mistral.rs
@@ -385,7 +385,6 @@ impl Model {
fn prepare_decoder_attention_mask(
&self,
- b_size: usize,
tgt_len: usize,
seqlen_offset: usize,
) -> Result<Tensor> {
@@ -408,16 +407,16 @@ impl Model {
} else {
mask
};
- mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
+ mask.expand((1, 1, tgt_len, tgt_len + seqlen_offset))?
.to_dtype(self.dtype)
}
pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
- let (b_size, seq_len) = input_ids.dims2()?;
+ let (_b_size, seq_len) = input_ids.dims2()?;
let attention_mask = if seq_len <= 1 {
None
} else {
- let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
+ let mask = self.prepare_decoder_attention_mask(seq_len, seqlen_offset)?;
Some(mask)
};
let mut xs = self.embed_tokens.forward(input_ids)?;
diff --git a/candle-transformers/src/models/quantized_mistral.rs b/candle-transformers/src/models/quantized_mistral.rs
index 77de7b75..5f026f2b 100644
--- a/candle-transformers/src/models/quantized_mistral.rs
+++ b/candle-transformers/src/models/quantized_mistral.rs
@@ -287,7 +287,6 @@ impl Model {
fn prepare_decoder_attention_mask(
&self,
- b_size: usize,
tgt_len: usize,
seqlen_offset: usize,
) -> Result<Tensor> {
@@ -310,16 +309,16 @@ impl Model {
} else {
mask
};
- mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
+ mask.expand((1, 1, tgt_len, tgt_len + seqlen_offset))?
.to_dtype(DType::F32)
}
pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
- let (b_size, seq_len) = input_ids.dims2()?;
+ let (_b_size, seq_len) = input_ids.dims2()?;
let attention_mask = if seq_len <= 1 {
None
} else {
- let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
+ let mask = self.prepare_decoder_attention_mask(seq_len, seqlen_offset)?;
Some(mask)
};
let mut xs = self.embed_tokens.forward(input_ids)?;