path: root/candle-transformers
author    Zheng Li <875543533@qq.com>    2024-08-01 14:26:19 +0800
committer GitHub <noreply@github.com>    2024-08-01 08:26:19 +0200
commit    4a52aeb4372ae42d469b0c477be12d7d79c28bf5 (patch)
tree      42f17fd26419f6b45b6df82bfb68797e3a07f0f9 /candle-transformers
parent    24d54d0ff90ecc701f4a41770482a2611da05d15 (diff)
bert attention mask (#1934)
* bert attention mask

* Allow for using None as a mask.

* Revert part of the changes so that the proper default mask applies.

* Cosmetic change.

* Another cosmetic tweak.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-transformers')
-rw-r--r--  candle-transformers/src/models/bert.rs  49
1 file changed, 32 insertions, 17 deletions
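
As a usage sketch (not part of the commit): the new `attention_mask` parameter is an Option, so call sites can pass None to keep the old behavior. The names `model`, `input_ids` and `token_type_ids` below are hypothetical stand-ins for the usual candle BERT setup (VarBuilder + tokenizer), which is elided here.

use candle::{Result, Tensor};
use candle_transformers::models::bert::BertModel;

// The mask is a (batch, seq_len) tensor with 1 for real tokens and 0 for
// padding. Passing None makes the model derive an all-ones mask from
// input_ids, i.e. nothing is masked out (the pre-#1934 behavior).
fn encode(
    model: &BertModel,
    input_ids: &Tensor,
    token_type_ids: &Tensor,
    attention_mask: Option<&Tensor>,
) -> Result<Tensor> {
    model.forward(input_ids, token_type_ids, attention_mask)
}
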
diff --git a/candle-transformers/src/models/bert.rs b/candle-transformers/src/models/bert.rs
index 810f2803..42486a2d 100644
--- a/candle-transformers/src/models/bert.rs
+++ b/candle-transformers/src/models/bert.rs
@@ -230,10 +230,8 @@ impl BertSelfAttention {
         let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?;
         xs.contiguous()
     }
-}
-
-impl Module for BertSelfAttention {
-    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+
+    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         let query_layer = self.query.forward(hidden_states)?;
         let key_layer = self.key.forward(hidden_states)?;
@@ -245,6 +243,7 @@ impl Module for BertSelfAttention {
 
         let attention_scores = query_layer.matmul(&key_layer.t()?)?;
         let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
+        let attention_scores = attention_scores.broadcast_add(attention_mask)?;
         let attention_probs = {
             let _enter_sm = self.span_softmax.enter();
             candle_nn::ops::softmax(&attention_scores, candle::D::Minus1)?
@@ -307,12 +306,10 @@ impl BertAttention {
             span: tracing::span!(tracing::Level::TRACE, "attn"),
         })
     }
-}
-
-impl Module for BertAttention {
-    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+
+    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
-        let self_outputs = self.self_attention.forward(hidden_states)?;
+        let self_outputs = self.self_attention.forward(hidden_states, attention_mask)?;
         let attention_output = self.self_output.forward(&self_outputs, hidden_states)?;
         Ok(attention_output)
     }
@@ -398,12 +395,10 @@ impl BertLayer {
             span: tracing::span!(tracing::Level::TRACE, "layer"),
         })
     }
-}
-
-impl Module for BertLayer {
-    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+
+    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
-        let attention_output = self.attention.forward(hidden_states)?;
+        let attention_output = self.attention.forward(hidden_states, attention_mask)?;
         // TODO: Support cross-attention?
         // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523
         // TODO: Support something similar to `apply_chunking_to_forward`?
@@ -429,15 +424,13 @@ impl BertEncoder {
         let span = tracing::span!(tracing::Level::TRACE, "encoder");
         Ok(BertEncoder { layers, span })
     }
-}
-
-impl Module for BertEncoder {
-    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+
+    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         let mut hidden_states = hidden_states.clone();
         // Use a loop rather than a fold as it's easier to modify when adding debug/...
         for layer in self.layers.iter() {
-            hidden_states = layer.forward(&hidden_states)?
+            hidden_states = layer.forward(&hidden_states, attention_mask)?
         }
         Ok(hidden_states)
     }
@@ -481,10 +474,32 @@ impl BertModel {
         })
     }
 
-    pub fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
+    pub fn forward(
+        &self,
+        input_ids: &Tensor,
+        token_type_ids: &Tensor,
+        attention_mask: Option<&Tensor>,
+    ) -> Result<Tensor> {
         let _enter = self.span.enter();
         let embedding_output = self.embeddings.forward(input_ids, token_type_ids)?;
-        let sequence_output = self.encoder.forward(&embedding_output)?;
+        let attention_mask = match attention_mask {
+            Some(attention_mask) => attention_mask.clone(),
+            None => input_ids.ones_like()?,
+        };
+        // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L995
+        let attention_mask = get_extended_attention_mask(&attention_mask, DType::F32)?;
+        let sequence_output = self.encoder.forward(&embedding_output, &attention_mask)?;
         Ok(sequence_output)
     }
 }
+
+fn get_extended_attention_mask(attention_mask: &Tensor, dtype: DType) -> Result<Tensor> {
+    let attention_mask = match attention_mask.rank() {
+        3 => attention_mask.unsqueeze(1)?,
+        2 => attention_mask.unsqueeze(1)?.unsqueeze(1)?,
+        _ => candle::bail!("Wrong shape for input_ids or attention_mask"),
+    };
+    let attention_mask = attention_mask.to_dtype(dtype)?;
+    // torch.finfo(dtype).min
+    (attention_mask.ones_like()? - attention_mask)?.broadcast_mul(&Tensor::try_from(f32::MIN)?)
+}