| author    | Zheng Li <875543533@qq.com> | 2024-08-01 14:26:19 +0800 |
| committer | GitHub <noreply@github.com> | 2024-08-01 08:26:19 +0200 |
| commit    | 4a52aeb4372ae42d469b0c477be12d7d79c28bf5 (patch) | |
| tree      | 42f17fd26419f6b45b6df82bfb68797e3a07f0f9 /candle-transformers | |
| parent    | 24d54d0ff90ecc701f4a41770482a2611da05d15 (diff) | |
bert attention mask (#1934)
* bert attention mask
* Allow for using None as a mask.
* Revert part of the changes so that the proper default mask applies.
* Cosmetic change.
* Another cosmetic tweak.
---------
Co-authored-by: Laurent <laurent.mazare@gmail.com>
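
For context (not part of this commit): a minimal sketch of how a caller might use the new optional mask, assuming the `candle`/`candle_transformers` crate names used in this repo, an already-loaded `BertModel`, and made-up token ids. The helper name `encode_batch` is hypothetical.

```rust
use candle::{Device, Result, Tensor};
use candle_transformers::models::bert::BertModel;

// Hypothetical helper: `model` and `device` are assumed to come from the usual
// model-loading path; the token ids below are illustrative only.
fn encode_batch(model: &BertModel, device: &Device) -> Result<Tensor> {
    // Two sequences padded to length 4; the trailing 0 in the first row is padding.
    let input_ids = Tensor::new(&[[101u32, 2054, 102, 0], [101, 7592, 2088, 102]], device)?;
    let token_type_ids = input_ids.zeros_like()?;
    // 1 = real token, 0 = padding.
    let attention_mask = Tensor::new(&[[1u32, 1, 1, 0], [1, 1, 1, 1]], device)?;
    // Passing `Some(..)` masks the padded positions; passing `None` falls back
    // to the default all-ones mask introduced by this change.
    model.forward(&input_ids, &token_type_ids, Some(&attention_mask))
}
```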
Diffstat (limited to 'candle-transformers')
-rw-r--r-- | candle-transformers/src/models/bert.rs | 49 |
1 file changed, 32 insertions, 17 deletions
diff --git a/candle-transformers/src/models/bert.rs b/candle-transformers/src/models/bert.rs
index 810f2803..42486a2d 100644
--- a/candle-transformers/src/models/bert.rs
+++ b/candle-transformers/src/models/bert.rs
@@ -230,10 +230,8 @@ impl BertSelfAttention {
         let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?;
         xs.contiguous()
     }
-}
-
-impl Module for BertSelfAttention {
-    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+
+    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         let query_layer = self.query.forward(hidden_states)?;
         let key_layer = self.key.forward(hidden_states)?;
@@ -245,6 +243,7 @@ impl Module for BertSelfAttention {
 
         let attention_scores = query_layer.matmul(&key_layer.t()?)?;
         let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
+        let attention_scores = attention_scores.broadcast_add(attention_mask)?;
         let attention_probs = {
             let _enter_sm = self.span_softmax.enter();
             candle_nn::ops::softmax(&attention_scores, candle::D::Minus1)?
@@ -307,12 +306,10 @@ impl BertAttention {
             span: tracing::span!(tracing::Level::TRACE, "attn"),
         })
     }
-}
-
-impl Module for BertAttention {
-    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+
+    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
-        let self_outputs = self.self_attention.forward(hidden_states)?;
+        let self_outputs = self.self_attention.forward(hidden_states, attention_mask)?;
         let attention_output = self.self_output.forward(&self_outputs, hidden_states)?;
         Ok(attention_output)
     }
@@ -398,12 +395,10 @@ impl BertLayer {
             span: tracing::span!(tracing::Level::TRACE, "layer"),
         })
     }
-}
-
-impl Module for BertLayer {
-    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+
+    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
-        let attention_output = self.attention.forward(hidden_states)?;
+        let attention_output = self.attention.forward(hidden_states, attention_mask)?;
         // TODO: Support cross-attention?
         // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523
         // TODO: Support something similar to `apply_chunking_to_forward`?
@@ -429,15 +424,13 @@ impl BertEncoder {
         let span = tracing::span!(tracing::Level::TRACE, "encoder");
         Ok(BertEncoder { layers, span })
     }
-}
-
-impl Module for BertEncoder {
-    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+
+    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         let mut hidden_states = hidden_states.clone();
         // Use a loop rather than a fold as it's easier to modify when adding debug/...
         for layer in self.layers.iter() {
-            hidden_states = layer.forward(&hidden_states)?
+            hidden_states = layer.forward(&hidden_states, attention_mask)?
         }
         Ok(hidden_states)
     }
@@ -481,10 +474,32 @@ impl BertModel {
         })
     }
 
-    pub fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
+    pub fn forward(
+        &self,
+        input_ids: &Tensor,
+        token_type_ids: &Tensor,
+        attention_mask: Option<&Tensor>,
+    ) -> Result<Tensor> {
         let _enter = self.span.enter();
         let embedding_output = self.embeddings.forward(input_ids, token_type_ids)?;
-        let sequence_output = self.encoder.forward(&embedding_output)?;
+        let attention_mask = match attention_mask {
+            Some(attention_mask) => attention_mask.clone(),
+            None => input_ids.ones_like()?,
+        };
+        // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L995
+        let attention_mask = get_extended_attention_mask(&attention_mask, DType::F32)?;
+        let sequence_output = self.encoder.forward(&embedding_output, &attention_mask)?;
         Ok(sequence_output)
     }
 }
+
+fn get_extended_attention_mask(attention_mask: &Tensor, dtype: DType) -> Result<Tensor> {
+    let attention_mask = match attention_mask.rank() {
+        3 => attention_mask.unsqueeze(1)?,
+        2 => attention_mask.unsqueeze(1)?.unsqueeze(1)?,
+        _ => candle::bail!("Wrong shape for input_ids or attention_mask"),
+    };
+    let attention_mask = attention_mask.to_dtype(dtype)?;
+    // torch.finfo(dtype).min
+    (attention_mask.ones_like()? - attention_mask)?.broadcast_mul(&Tensor::try_from(f32::MIN)?)
+}
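
A standalone numeric sketch (not code from this crate) of what the extended mask does: mirroring `get_extended_attention_mask` above, a (batch, seq) 0/1 mask is reshaped to (batch, 1, 1, seq) and turned into an additive bias of `f32::MIN` on the masked positions, so that after `broadcast_add` and softmax those positions receive roughly zero attention weight. The shapes and values here are made up for illustration.

```rust
use candle::{DType, Device, Tensor};

fn main() -> candle::Result<()> {
    let device = Device::Cpu;
    // (batch = 1, seq = 4) mask, already f32: the last position is padding.
    let mask = Tensor::new(&[[1f32, 1., 1., 0.]], &device)?;
    // Mirror `get_extended_attention_mask`: (batch, seq) -> (batch, 1, 1, seq),
    // then turn the zeros into a very large negative additive bias.
    let mask = mask.unsqueeze(1)?.unsqueeze(1)?;
    let bias = (mask.ones_like()? - &mask)?.broadcast_mul(&Tensor::try_from(f32::MIN)?)?;
    // Dummy attention scores of shape (batch, heads = 1, seq, seq).
    let scores = Tensor::zeros((1, 1, 4, 4), DType::F32, &device)?;
    let probs = candle_nn::ops::softmax(&scores.broadcast_add(&bias)?, candle::D::Minus1)?;
    // Each row now gives ~1/3 to the three real tokens and ~0 to the padded one.
    println!("{probs}");
    Ok(())
}
```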