author    Lukas Kreussel <65088241+LLukas22@users.noreply.github.com> 2023-10-20 20:05:14 +0200
committer GitHub <noreply@github.com> 2023-10-20 19:05:14 +0100
commit    cfb423ab761fcb2ae3b9e36a18b0f6e5dd7cd253 (patch)
tree      6819191f120ca28ce22c417778ff2a3747af13df /candle-pyo3
parent    7366aeac21d2be65bddf8691223f654c0ed8fd0b (diff)
PyO3: Add CI (#1135)
* Add PyO3 CI
* Update python.yml
* Format `bert.py`
Diffstat (limited to 'candle-pyo3')
-rw-r--r--  candle-pyo3/py_src/candle/models/bert.py | 7
1 file changed, 4 insertions(+), 3 deletions(-)
diff --git a/candle-pyo3/py_src/candle/models/bert.py b/candle-pyo3/py_src/candle/models/bert.py
index 36e242ad..ecb238d8 100644
--- a/candle-pyo3/py_src/candle/models/bert.py
+++ b/candle-pyo3/py_src/candle/models/bert.py
@@ -59,8 +59,7 @@ class BertSelfAttention(Module):
         attention_scores = attention_scores / float(self.attention_head_size) ** 0.5
         if attention_mask is not None:
             b_size, _, _, last_dim = attention_scores.shape
-            attention_scores = attention_scores.broadcast_add(
-                attention_mask.reshape((b_size, 1, 1, last_dim)))
+            attention_scores = attention_scores.broadcast_add(attention_mask.reshape((b_size, 1, 1, last_dim)))
 
         attention_probs = F.softmax(attention_scores, dim=-1)
         context_layer = attention_probs.matmul(value)
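The reformatted call above is behavior-preserving: the padding mask of shape (b_size, seq_len) is reshaped to (b_size, 1, 1, last_dim) so that a single broadcast_add applies it across every head and every query position of the (batch, heads, seq, seq) score tensor. A minimal NumPy sketch of that broadcast pattern (illustrative only, with made-up shapes and values; this is not the candle Python API):

```python
import numpy as np

# Toy shapes: batch of 2, 4 heads, sequence length 5 (hypothetical values).
b_size, n_heads, seq_len = 2, 4, 5
scores = np.random.randn(b_size, n_heads, seq_len, seq_len)

# Additive mask: 0.0 keeps a key position, -inf removes it.
mask = np.array([
    [0.0, 0.0, 0.0, -np.inf, -np.inf],      # batch 0: last two tokens are padding
    [0.0, 0.0, -np.inf, -np.inf, -np.inf],  # batch 1: last three are padding
])

# Reshaping to (b_size, 1, 1, seq_len) broadcasts the mask over
# all heads and all query positions in one addition.
masked = scores + mask.reshape((b_size, 1, 1, seq_len))

# Softmax over the key axis: exp(-inf) = 0, so masked keys get zero weight.
probs = np.exp(masked - masked.max(axis=-1, keepdims=True))
probs /= probs.sum(axis=-1, keepdims=True)
assert np.allclose(probs[0, :, :, 3:], 0.0)
```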
@@ -198,7 +197,9 @@ class BertModel(Module):
         self.encoder = BertEncoder(config)
         self.pooler = BertPooler(config) if add_pooling_layer else None
 
-    def forward(self, input_ids: Tensor, token_type_ids: Tensor, attention_mask=None) -> Tuple[Tensor, Optional[Tensor]]:
+    def forward(
+        self, input_ids: Tensor, token_type_ids: Tensor, attention_mask=None
+    ) -> Tuple[Tensor, Optional[Tensor]]:
         if attention_mask is not None:
             # Replace 0s with -inf, and 1s with 0s.
             attention_mask = masked_fill(float("-inf"), attention_mask, 1.0)
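The comment in this hunk describes the conversion from a 0/1 padding mask into the additive mask consumed by BertSelfAttention above: 1 (keep) maps to 0.0 and 0 (padding) maps to -inf, so masked scores vanish under softmax. A NumPy sketch of that mapping (the masked_fill helper itself is defined elsewhere in bert.py; this shows only the behavior the comment describes, not the helper's signature):

```python
import numpy as np

# The mapping described by the comment: 1.0 -> 0.0 (keep), 0.0 -> -inf (drop).
attention_mask = np.array([[1.0, 1.0, 1.0, 0.0, 0.0]])
additive_mask = np.where(attention_mask == 1.0, 0.0, float("-inf"))
print(additive_mask)  # [[  0.   0.   0. -inf -inf]]
```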