diff options
author | Lukas Kreussel <65088241+LLukas22@users.noreply.github.com> | 2023-10-20 20:05:14 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-20 19:05:14 +0100 |
commit | cfb423ab761fcb2ae3b9e36a18b0f6e5dd7cd253 (patch) | |
tree | 6819191f120ca28ce22c417778ff2a3747af13df /candle-pyo3 | |
parent | 7366aeac21d2be65bddf8691223f654c0ed8fd0b (diff) | |
download | candle-cfb423ab761fcb2ae3b9e36a18b0f6e5dd7cd253.tar.gz candle-cfb423ab761fcb2ae3b9e36a18b0f6e5dd7cd253.tar.bz2 candle-cfb423ab761fcb2ae3b9e36a18b0f6e5dd7cd253.zip |
PyO3: Add CI (#1135)
* Add PyO3 ci
* Update python.yml
* Format `bert.py`
Diffstat (limited to 'candle-pyo3')
-rw-r--r-- | candle-pyo3/py_src/candle/models/bert.py | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/candle-pyo3/py_src/candle/models/bert.py b/candle-pyo3/py_src/candle/models/bert.py index 36e242ad..ecb238d8 100644 --- a/candle-pyo3/py_src/candle/models/bert.py +++ b/candle-pyo3/py_src/candle/models/bert.py @@ -59,8 +59,7 @@ class BertSelfAttention(Module): attention_scores = attention_scores / float(self.attention_head_size) ** 0.5 if attention_mask is not None: b_size, _, _, last_dim = attention_scores.shape - attention_scores = attention_scores.broadcast_add( - attention_mask.reshape((b_size, 1, 1, last_dim))) + attention_scores = attention_scores.broadcast_add(attention_mask.reshape((b_size, 1, 1, last_dim))) attention_probs = F.softmax(attention_scores, dim=-1) context_layer = attention_probs.matmul(value) @@ -198,7 +197,9 @@ class BertModel(Module): self.encoder = BertEncoder(config) self.pooler = BertPooler(config) if add_pooling_layer else None - def forward(self, input_ids: Tensor, token_type_ids: Tensor, attention_mask=None) -> Tuple[Tensor, Optional[Tensor]]: + def forward( + self, input_ids: Tensor, token_type_ids: Tensor, attention_mask=None + ) -> Tuple[Tensor, Optional[Tensor]]: if attention_mask is not None: # Replace 0s with -inf, and 1s with 0s. attention_mask = masked_fill(float("-inf"), attention_mask, 1.0) |