author     Laurent Mazare <laurent.mazare@gmail.com>   2023-10-21 22:44:13 +0100
committer  GitHub <noreply@github.com>                 2023-10-21 22:44:13 +0100
commit     3115fe42e4b203b02219eaf85b749f6710d0de3e (patch)
tree       2fc86982d579213ef53e6c61ea17f1f8773c6bd4 /candle-transformers/src/models/blip_text.rs
parent     2531b13bf85a69058e8ed1b30c683d19d036df14 (diff)
Blip attention mask + readme (#1146)
* Add the attention mask to the blip model.
* Add a readme.
Diffstat (limited to 'candle-transformers/src/models/blip_text.rs')
-rw-r--r--  candle-transformers/src/models/blip_text.rs | 62
1 file changed, 49 insertions, 13 deletions
diff --git a/candle-transformers/src/models/blip_text.rs b/candle-transformers/src/models/blip_text.rs
index 8d0712c0..6db2b9d8 100644
--- a/candle-transformers/src/models/blip_text.rs
+++ b/candle-transformers/src/models/blip_text.rs
@@ -105,7 +105,12 @@ impl TextSelfAttention {
.permute((0, 2, 1, 3))
}
- fn forward(&self, xs: &Tensor, encoder_hidden_states: Option<&Tensor>) -> Result<Tensor> {
+ fn forward(
+ &self,
+ xs: &Tensor,
+ encoder_hidden_states: Option<&Tensor>,
+ attention_mask: Option<&Tensor>,
+ ) -> Result<Tensor> {
let query = self
.transpose_for_scores(&self.query.forward(xs)?)?
.contiguous()?;
@@ -127,6 +132,10 @@ impl TextSelfAttention {
let value = value.contiguous()?;
let attention_scores = query.matmul(&key.t()?)?;
let attention_scores = (attention_scores * self.attention_scale)?;
+ let attention_scores = match attention_mask {
+ Some(mask) => attention_scores.broadcast_add(mask)?,
+ None => attention_scores,
+ };
let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
attention_probs
.matmul(&value)?
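For illustration only (not part of the patch): a minimal, self-contained sketch of the masking step added above, assuming the crate names used in this repo (candle / candle_nn, published as candle-core / candle-nn). A (seq_len, seq_len) additive mask broadcasts over the batch and head dimensions of the attention scores, and the softmax then sends the -inf entries to zero.

use candle::{DType, Device, Result, Tensor};

fn masked_softmax_sketch() -> Result<()> {
    let dev = Device::Cpu;
    // Uniform scores of shape (batch=1, heads=1, seq_len=2, seq_len=2).
    let scores = Tensor::zeros((1, 1, 2, 2), DType::F32, &dev)?;
    // Additive causal mask: 0 where attention is allowed, -inf where it is not.
    let mask = Tensor::from_vec(vec![0f32, f32::NEG_INFINITY, 0f32, 0f32], (2, 2), &dev)?;
    // The 2-D mask broadcasts over the leading batch/head dimensions.
    let masked = scores.broadcast_add(&mask)?;
    let probs = candle_nn::ops::softmax_last_dim(&masked)?;
    // Row 0 becomes [1.0, 0.0], row 1 stays [0.5, 0.5].
    println!("{:?}", probs.squeeze(0)?.squeeze(0)?.to_vec2::<f32>()?);
    Ok(())
}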
@@ -166,8 +175,15 @@ impl TextAttention {
Ok(Self { self_, output })
}
- fn forward(&self, xs: &Tensor, encoder_hidden_states: Option<&Tensor>) -> Result<Tensor> {
- let self_outputs = self.self_.forward(xs, encoder_hidden_states)?;
+ fn forward(
+ &self,
+ xs: &Tensor,
+ encoder_hidden_states: Option<&Tensor>,
+ attention_mask: Option<&Tensor>,
+ ) -> Result<Tensor> {
+ let self_outputs = self
+ .self_
+ .forward(xs, encoder_hidden_states, attention_mask)?;
self.output.forward(&self_outputs, xs)
}
}
@@ -238,10 +254,15 @@ impl TextLayer {
})
}
- fn forward(&self, xs: &Tensor, encoder_hidden_states: &Tensor) -> Result<Tensor> {
- let attention_output = self.attention.forward(xs, None)?;
+ fn forward(
+ &self,
+ xs: &Tensor,
+ encoder_hidden_states: &Tensor,
+ attention_mask: &Tensor,
+ ) -> Result<Tensor> {
+ let attention_output = self.attention.forward(xs, None, Some(attention_mask))?;
let attention_output = match &self.cross_attention {
- Some(ca) => ca.forward(&attention_output, Some(encoder_hidden_states))?,
+ Some(ca) => ca.forward(&attention_output, Some(encoder_hidden_states), None)?,
None => candle::bail!("expected some cross-attn"),
};
let intermediate_output = self.intermediate.forward(&attention_output)?;
@@ -265,10 +286,15 @@ impl TextEncoder {
Ok(Self { layers })
}
- fn forward(&self, xs: &Tensor, encoder_hidden_states: &Tensor) -> Result<Tensor> {
+ fn forward(
+ &self,
+ xs: &Tensor,
+ encoder_hidden_states: &Tensor,
+ attention_mask: &Tensor,
+ ) -> Result<Tensor> {
let mut xs = xs.clone();
for layer in self.layers.iter() {
- xs = layer.forward(&xs, encoder_hidden_states)?
+ xs = layer.forward(&xs, encoder_hidden_states, attention_mask)?
}
Ok(xs)
}
@@ -384,11 +410,16 @@ impl TextModel {
})
}
- fn forward(&self, input_ids: &Tensor, encoder_hidden_states: &Tensor) -> Result<Tensor> {
+ fn forward(
+ &self,
+ input_ids: &Tensor,
+ encoder_hidden_states: &Tensor,
+ attention_mask: &Tensor,
+ ) -> Result<Tensor> {
let embedding_output = self.embeddings.forward(input_ids)?;
- let sequence_output = self
- .encoder
- .forward(&embedding_output, encoder_hidden_states)?;
+ let sequence_output =
+ self.encoder
+ .forward(&embedding_output, encoder_hidden_states, attention_mask)?;
// We're interested in the sequence-output rather than the pooled-output.
Ok(sequence_output)
}
@@ -408,7 +439,12 @@ impl TextLMHeadModel {
}
pub fn forward(&self, input_ids: &Tensor, encoder_hidden_states: &Tensor) -> Result<Tensor> {
- let sequence_output = self.bert.forward(input_ids, encoder_hidden_states)?;
+ let seq_len = input_ids.dim(1)?;
+ let mask: Vec<_> = (0..seq_len)
+ .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
+ .collect();
+ let mask = Tensor::from_vec(mask, (seq_len, seq_len), input_ids.device())?;
+ let sequence_output = self.bert.forward(input_ids, encoder_hidden_states, &mask)?;
let prediction_scores = self.cls.forward(&sequence_output)?;
// return_logits is false so we don't discard the last sequence element.
Ok(prediction_scores)
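As a standalone check (again an illustration, not code from this file), the mask built in TextLMHeadModel::forward is a lower-triangular 0 / -inf matrix: row i leaves positions 0..=i untouched and pushes every later position's score to -inf before the softmax.

use candle::{Device, Result, Tensor};

// Same construction as in TextLMHeadModel::forward, pulled out as a helper.
fn causal_mask(seq_len: usize, dev: &Device) -> Result<Tensor> {
    let mask: Vec<f32> = (0..seq_len)
        .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
        .collect();
    Tensor::from_vec(mask, (seq_len, seq_len), dev)
}

fn main() -> Result<()> {
    let mask = causal_mask(3, &Device::Cpu)?;
    // Expected rows: [0, -inf, -inf], [0, 0, -inf], [0, 0, 0].
    println!("{:?}", mask.to_vec2::<f32>()?);
    Ok(())
}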