author     Laurent Mazare <laurent.mazare@gmail.com>  2024-03-23 13:08:53 +0100
committer  GitHub <noreply@github.com>                2024-03-23 13:08:53 +0100
commit     6f877592a7d5b5023462e0b8d241a2ba5ad83648 (patch)
tree       67d0b787d25295bd67846d3425eb76df9aa60abb /candle-transformers
parent     cc856db9ce2541e09731165f88cdd7aae37f558e (diff)
download   candle-6f877592a7d5b5023462e0b8d241a2ba5ad83648.tar.gz
           candle-6f877592a7d5b5023462e0b8d241a2ba5ad83648.tar.bz2
           candle-6f877592a7d5b5023462e0b8d241a2ba5ad83648.zip
Avoid broadcasting on the batch dimension for the attention mask. (#1920)
Diffstat (limited to 'candle-transformers')
-rw-r--r--  candle-transformers/src/models/mistral.rs            | 7
-rw-r--r--  candle-transformers/src/models/quantized_mistral.rs   | 7
2 files changed, 6 insertions(+), 8 deletions(-)
diff --git a/candle-transformers/src/models/mistral.rs b/candle-transformers/src/models/mistral.rs
index be84f824..e40ae3ad 100644
--- a/candle-transformers/src/models/mistral.rs
+++ b/candle-transformers/src/models/mistral.rs
@@ -385,7 +385,6 @@ impl Model {
     fn prepare_decoder_attention_mask(
         &self,
-        b_size: usize,
         tgt_len: usize,
         seqlen_offset: usize,
     ) -> Result<Tensor> {
@@ -408,16 +407,16 @@ impl Model {
         } else {
             mask
         };
-        mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
+        mask.expand((1, 1, tgt_len, tgt_len + seqlen_offset))?
             .to_dtype(self.dtype)
     }
 
     pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
-        let (b_size, seq_len) = input_ids.dims2()?;
+        let (_b_size, seq_len) = input_ids.dims2()?;
         let attention_mask = if seq_len <= 1 {
             None
         } else {
-            let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
+            let mask = self.prepare_decoder_attention_mask(seq_len, seqlen_offset)?;
             Some(mask)
         };
         let mut xs = self.embed_tokens.forward(input_ids)?;
diff --git a/candle-transformers/src/models/quantized_mistral.rs b/candle-transformers/src/models/quantized_mistral.rs
index 77de7b75..5f026f2b 100644
--- a/candle-transformers/src/models/quantized_mistral.rs
+++ b/candle-transformers/src/models/quantized_mistral.rs
@@ -287,7 +287,6 @@ impl Model {
     fn prepare_decoder_attention_mask(
         &self,
-        b_size: usize,
         tgt_len: usize,
         seqlen_offset: usize,
     ) -> Result<Tensor> {
@@ -310,16 +309,16 @@ impl Model {
         } else {
             mask
         };
-        mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
+        mask.expand((1, 1, tgt_len, tgt_len + seqlen_offset))?
             .to_dtype(DType::F32)
     }
 
     pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
-        let (b_size, seq_len) = input_ids.dims2()?;
+        let (_b_size, seq_len) = input_ids.dims2()?;
        let attention_mask = if seq_len <= 1 {
             None
         } else {
-            let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
+            let mask = self.prepare_decoder_attention_mask(seq_len, seqlen_offset)?;
             Some(mask)
         };
         let mut xs = self.embed_tokens.forward(input_ids)?;
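
Note (not part of the commit itself): a minimal sketch of the behaviour the new (1, 1, tgt_len, src_len) mask shape relies on, assuming the attention layer applies the mask with a broadcasting op. candle's broadcast_* operations stretch size-1 dimensions to match the other operand, so a single mask covers attention scores of shape (b_size, num_heads, tgt_len, src_len) without materializing one copy per batch element. The sizes and variable names below are illustrative assumptions, not taken from the models.

use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Illustrative sizes: batch, heads, and sequence length are made up.
    let (b, h, t) = (4usize, 8usize, 16usize);

    // Placeholder attention scores of shape (batch, heads, tgt_len, src_len).
    let scores = Tensor::zeros((b, h, t, t), DType::F32, &dev)?;
    // Causal mask built once with a leading batch dimension of 1 instead of b.
    let mask = Tensor::zeros((1, 1, t, t), DType::F32, &dev)?;

    // broadcast_add expands the size-1 batch/head dims of the mask on the fly,
    // so the same mask tensor is reused for every batch element.
    let masked = scores.broadcast_add(&mask)?;
    assert_eq!(masked.dims(), &[b, h, t, t]);
    Ok(())
}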