author     Laurent Mazare <laurent.mazare@gmail.com>  2023-10-13 19:53:40 +0200
committer  GitHub <noreply@github.com>                2023-10-13 18:53:40 +0100
commit     75989fc3b7ad06f6216b3aab62a2f3a7fcf4ebba (patch)
tree       ca9f0e48ae9d08fd9f61dfedf1cf0feb96f246b3 /candle-pyo3
parent     07af87a1d801852645966d89bd193808ff7c5b35 (diff)
Use an attention mask in the e5 padding case. (#1085)
Diffstat (limited to 'candle-pyo3')
-rw-r--r--  candle-pyo3/e5.py                          3
-rw-r--r--  candle-pyo3/py_src/candle/models/bert.py  34
2 files changed, 26 insertions, 11 deletions
diff --git a/candle-pyo3/e5.py b/candle-pyo3/e5.py
index 8ca48219..9b54ebcb 100644
--- a/candle-pyo3/e5.py
+++ b/candle-pyo3/e5.py
@@ -50,7 +50,8 @@ if __name__ == "__main__":
     tokenized = tokenizer(sentences, padding=True)
     tokens = Tensor(tokenized["input_ids"])
     token_type_ids = Tensor(tokenized["token_type_ids"])
-    encoder_out, _ = model.forward(tokens, token_type_ids)
+    attention_mask = Tensor(tokenized["attention_mask"])
+    encoder_out, _ = model.forward(tokens, token_type_ids, attention_mask=attention_mask)
 
     hf_tokenized = tokenizer(sentences, padding=True, return_tensors="pt")
     hf_result = hf_model(**hf_tokenized)["last_hidden_state"]
diff --git a/candle-pyo3/py_src/candle/models/bert.py b/candle-pyo3/py_src/candle/models/bert.py
index 0a773f93..36e242ad 100644
--- a/candle-pyo3/py_src/candle/models/bert.py
+++ b/candle-pyo3/py_src/candle/models/bert.py
@@ -46,7 +46,7 @@ class BertSelfAttention(Module):
         x = x.reshape(new_x_shape).transpose(1, 2)
         return x.contiguous()
 
-    def forward(self, hidden_states: Tensor) -> Tensor:
+    def forward(self, hidden_states: Tensor, attention_mask=None) -> Tensor:
         query = self.query.forward(hidden_states)
         key = self.key.forward(hidden_states)
         value = self.value.forward(hidden_states)
@@ -56,7 +56,11 @@ class BertSelfAttention(Module):
         value = self.transpose_for_scores(value)
 
         attention_scores = query.matmul(key.t())
-        attention_scores = attention_scores / (float(self.attention_head_size) ** 0.5)
+        attention_scores = attention_scores / float(self.attention_head_size) ** 0.5
+        if attention_mask is not None:
+            b_size, _, _, last_dim = attention_scores.shape
+            attention_scores = attention_scores.broadcast_add(
+                attention_mask.reshape((b_size, 1, 1, last_dim)))
         attention_probs = F.softmax(attention_scores, dim=-1)
 
         context_layer = attention_probs.matmul(value)
@@ -82,8 +86,8 @@ class BertAttention(Module):
         self.self = BertSelfAttention(config)
         self.output = BertSelfOutput(config)
 
-    def forward(self, hidden_states: Tensor) -> Tensor:
-        self_outputs = self.self.forward(hidden_states)
+    def forward(self, hidden_states: Tensor, attention_mask: None) -> Tensor:
+        self_outputs = self.self.forward(hidden_states, attention_mask=attention_mask)
         attention_output = self.output.forward(self_outputs, hidden_states)
         return attention_output
 
@@ -117,8 +121,8 @@ class BertLayer(Module):
         self.intermediate = BertIntermediate(config)
         self.output = BertOutput(config)
 
-    def forward(self, hidden_states: Tensor) -> Tensor:
-        attention_output = self.attention.forward(hidden_states)
+    def forward(self, hidden_states: Tensor, attention_mask=None) -> Tensor:
+        attention_output = self.attention.forward(hidden_states, attention_mask=attention_mask)
         # TODO: Support cross-attention?
         # https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523
         # TODO: Support something similar to `apply_chunking_to_forward`?
@@ -134,9 +138,9 @@ class BertEncoder(Module):
         for _ in range(config.num_hidden_layers):
             self.layer.append(BertLayer(config))
 
-    def forward(self, hidden_states: Tensor) -> Tensor:
+    def forward(self, hidden_states: Tensor, attention_mask=None) -> Tensor:
         for l in self.layer:
-            hidden_states = l.forward(hidden_states)
+            hidden_states = l.forward(hidden_states, attention_mask=attention_mask)
         return hidden_states
 
 
@@ -178,6 +182,13 @@ class BertPooler(Module):
         return pooled_output
 
 
+def masked_fill(on_false: float, mask: Tensor, on_true: float):
+    shape = mask.shape
+    on_true = candle.tensor(on_true).broadcast_as(shape)
+    on_false = candle.tensor(on_false).broadcast_as(shape)
+    return mask.where_cond(on_true, on_false)
+
+
 # https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L874
 class BertModel(Module):
     def __init__(self, config: Config, add_pooling_layer=True) -> None:
@@ -187,8 +198,11 @@ class BertModel(Module):
         self.encoder = BertEncoder(config)
         self.pooler = BertPooler(config) if add_pooling_layer else None
 
-    def forward(self, input_ids: Tensor, token_type_ids: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
+    def forward(self, input_ids: Tensor, token_type_ids: Tensor, attention_mask=None) -> Tuple[Tensor, Optional[Tensor]]:
+        if attention_mask is not None:
+            # Replace 0s with -inf, and 1s with 0s.
+            attention_mask = masked_fill(float("-inf"), attention_mask, 1.0)
         embeddings = self.embeddings.forward(input_ids, token_type_ids)
-        encoder_out = self.encoder.forward(embeddings)
+        encoder_out = self.encoder.forward(embeddings, attention_mask=attention_mask)
         pooled_output = self.pooler(encoder_out) if self.pooler is not None else None
         return encoder_out, pooled_output
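
For context: the change threads the tokenizer's 0/1 padding mask from BertModel down to BertSelfAttention, converts it into an additive mask via the new masked_fill helper, and broadcast-adds it to the raw attention scores before the softmax so that padded positions receive no attention. The snippet below is a minimal, dependency-free sketch of that effect; the scores, the example mask, and the small softmax helper are illustrative and not part of the commit.

# Illustration only: an additive mask with 0.0 at real tokens and -inf at padded
# positions leaves real tokens untouched and gives padded positions exactly zero
# softmax weight, which is what the broadcast_add in BertSelfAttention relies on.
import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

scores = [2.0, 1.0, 0.5, 0.1, 0.1]    # raw attention scores for one query position
padding_mask = [1, 1, 1, 0, 0]        # 1 = real token, 0 = padding
additive = [0.0 if m == 1 else float("-inf") for m in padding_mask]

print(softmax([s + a for s, a in zip(scores, additive)]))
# -> the last two weights are exactly 0.0; the first three renormalize among themselves.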