author     Laurent Mazare <laurent.mazare@gmail.com>  2023-10-13 19:53:40 +0200
committer  GitHub <noreply@github.com>                2023-10-13 18:53:40 +0100
commit     75989fc3b7ad06f6216b3aab62a2f3a7fcf4ebba (patch)
tree       ca9f0e48ae9d08fd9f61dfedf1cf0feb96f246b3 /candle-pyo3
parent     07af87a1d801852645966d89bd193808ff7c5b35 (diff)
Use an attention mask in the e5 padding case. (#1085)
Diffstat (limited to 'candle-pyo3')
-rw-r--r--  candle-pyo3/e5.py                           3
-rw-r--r--  candle-pyo3/py_src/candle/models/bert.py   34
2 files changed, 26 insertions(+), 11 deletions(-)
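The patch masks padding additively: positions where the tokenizer's attention_mask is 0 get -inf added to their attention scores before the softmax, so they end up with zero attention probability, while real tokens are left unchanged. A minimal stand-alone sketch of that mechanism in plain Python (values made up, no candle dependency):

import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

scores = [2.0, 1.0, 0.5, 0.3]              # raw attention scores for one query
attention_mask = [1, 1, 1, 0]              # last position is padding
additive = [0.0 if keep else float("-inf") for keep in attention_mask]
masked = [s + a for s, a in zip(scores, additive)]
print(softmax(masked))                     # padded position gets probability 0.0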
diff --git a/candle-pyo3/e5.py b/candle-pyo3/e5.py
index 8ca48219..9b54ebcb 100644
--- a/candle-pyo3/e5.py
+++ b/candle-pyo3/e5.py
@@ -50,7 +50,8 @@ if __name__ == "__main__":
tokenized = tokenizer(sentences, padding=True)
tokens = Tensor(tokenized["input_ids"])
token_type_ids = Tensor(tokenized["token_type_ids"])
- encoder_out, _ = model.forward(tokens, token_type_ids)
+ attention_mask = Tensor(tokenized["attention_mask"])
+ encoder_out, _ = model.forward(tokens, token_type_ids, attention_mask=attention_mask)
hf_tokenized = tokenizer(sentences, padding=True, return_tensors="pt")
hf_result = hf_model(**hf_tokenized)["last_hidden_state"]
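For context on the e5.py change above (not part of the patch): with padding=True, Hugging Face tokenizers also return an "attention_mask" of 1s for real tokens and 0s for padding, aligned with "input_ids"; that list is what gets wrapped in a Tensor here. Illustrative values only:

tokenized = {
    "input_ids":      [[101, 2460, 102,   0], [101, 1037, 2936, 102]],
    "attention_mask": [[  1,    1,   1,   0], [  1,    1,    1,   1]],
}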
diff --git a/candle-pyo3/py_src/candle/models/bert.py b/candle-pyo3/py_src/candle/models/bert.py
index 0a773f93..36e242ad 100644
--- a/candle-pyo3/py_src/candle/models/bert.py
+++ b/candle-pyo3/py_src/candle/models/bert.py
@@ -46,7 +46,7 @@ class BertSelfAttention(Module):
x = x.reshape(new_x_shape).transpose(1, 2)
return x.contiguous()
- def forward(self, hidden_states: Tensor) -> Tensor:
+ def forward(self, hidden_states: Tensor, attention_mask=None) -> Tensor:
query = self.query.forward(hidden_states)
key = self.key.forward(hidden_states)
value = self.value.forward(hidden_states)
@@ -56,7 +56,11 @@ class BertSelfAttention(Module):
value = self.transpose_for_scores(value)
attention_scores = query.matmul(key.t())
- attention_scores = attention_scores / (float(self.attention_head_size) ** 0.5)
+ attention_scores = attention_scores / float(self.attention_head_size) ** 0.5
+ if attention_mask is not None:
+ b_size, _, _, last_dim = attention_scores.shape
+ attention_scores = attention_scores.broadcast_add(
+ attention_mask.reshape((b_size, 1, 1, last_dim)))
attention_probs = F.softmax(attention_scores, dim=-1)
context_layer = attention_probs.matmul(value)
@@ -82,8 +86,8 @@ class BertAttention(Module):
self.self = BertSelfAttention(config)
self.output = BertSelfOutput(config)
- def forward(self, hidden_states: Tensor) -> Tensor:
- self_outputs = self.self.forward(hidden_states)
+ def forward(self, hidden_states: Tensor, attention_mask=None) -> Tensor:
+ self_outputs = self.self.forward(hidden_states, attention_mask=attention_mask)
attention_output = self.output.forward(self_outputs, hidden_states)
return attention_output
@@ -117,8 +121,8 @@ class BertLayer(Module):
self.intermediate = BertIntermediate(config)
self.output = BertOutput(config)
- def forward(self, hidden_states: Tensor) -> Tensor:
- attention_output = self.attention.forward(hidden_states)
+ def forward(self, hidden_states: Tensor, attention_mask=None) -> Tensor:
+ attention_output = self.attention.forward(hidden_states, attention_mask=attention_mask)
# TODO: Support cross-attention?
# https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523
# TODO: Support something similar to `apply_chunking_to_forward`?
@@ -134,9 +138,9 @@ class BertEncoder(Module):
for _ in range(config.num_hidden_layers):
self.layer.append(BertLayer(config))
- def forward(self, hidden_states: Tensor) -> Tensor:
+ def forward(self, hidden_states: Tensor, attention_mask=None) -> Tensor:
for l in self.layer:
- hidden_states = l.forward(hidden_states)
+ hidden_states = l.forward(hidden_states, attention_mask=attention_mask)
return hidden_states
@@ -178,6 +182,13 @@ class BertPooler(Module):
return pooled_output
+def masked_fill(on_false: float, mask: Tensor, on_true: float):
+ shape = mask.shape
+ on_true = candle.tensor(on_true).broadcast_as(shape)
+ on_false = candle.tensor(on_false).broadcast_as(shape)
+ return mask.where_cond(on_true, on_false)
+
+
# https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L874
class BertModel(Module):
def __init__(self, config: Config, add_pooling_layer=True) -> None:
@@ -187,8 +198,11 @@ class BertModel(Module):
self.encoder = BertEncoder(config)
self.pooler = BertPooler(config) if add_pooling_layer else None
- def forward(self, input_ids: Tensor, token_type_ids: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
+ def forward(self, input_ids: Tensor, token_type_ids: Tensor, attention_mask=None) -> Tuple[Tensor, Optional[Tensor]]:
+ if attention_mask is not None:
+ # Replace 0s with -inf, and 1s with 0s.
+ attention_mask = masked_fill(float("-inf"), attention_mask, 0.0)
embeddings = self.embeddings.forward(input_ids, token_type_ids)
- encoder_out = self.encoder.forward(embeddings)
+ encoder_out = self.encoder.forward(embeddings, attention_mask=attention_mask)
pooled_output = self.pooler(encoder_out) if self.pooler is not None else None
return encoder_out, pooled_output
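Not part of the patch: a small numpy sketch (shapes made up) of why BertSelfAttention reshapes the additive mask to (batch, 1, 1, seq_len) before broadcast_add. In that shape the mask broadcasts over every head and every query position of the (batch, num_heads, seq_len, seq_len) score tensor:

import numpy as np

batch, num_heads, seq_len = 2, 12, 4
attention_scores = np.zeros((batch, num_heads, seq_len, seq_len))
additive_mask = np.array([[0.0, 0.0, 0.0, -np.inf],
                          [0.0, 0.0, -np.inf, -np.inf]])   # 0.0 = keep, -inf = padding
masked = attention_scores + additive_mask.reshape(batch, 1, 1, seq_len)
print(masked.shape)       # (2, 12, 4, 4)
print(masked[0, 0, 0])    # [  0.   0.   0. -inf]: the padded key is hidden from every head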