summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.github/workflows/python.yml62
-rw-r--r--candle-pyo3/py_src/candle/models/bert.py7
2 files changed, 66 insertions, 3 deletions
diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml
new file mode 100644
index 00000000..bf85f5e5
--- /dev/null
+++ b/.github/workflows/python.yml
@@ -0,0 +1,62 @@
+name: PyO3-CI
+
+on:
+ workflow_dispatch:
+ push:
+ branches:
+ - main
+ paths:
+ - candle-pyo3/**
+ pull_request:
+ paths:
+ - candle-pyo3/**
+
+jobs:
+ build_and_test:
+ name: Check everything builds & tests
+ runs-on: ${{ matrix.os }}
+ strategy:
+ matrix:
+ os: [ubuntu-latest] # For now, only test on Linux
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v2
+
+ - name: Install Rust
+ uses: actions-rs/toolchain@v1
+ with:
+ toolchain: stable
+
+ - name: Install Python
+ uses: actions/setup-python@v4
+ with:
+ python-version: 3.11
+ architecture: "x64"
+
+ - name: Cache Cargo Registry
+ uses: actions/cache@v1
+ with:
+ path: ~/.cargo/registry
+ key: ${{ runner.os }}-cargo-registry-${{ hashFiles('**/Cargo.lock') }}
+
+ - name: Install
+ working-directory: ./candle-pyo3
+ run: |
+ python -m venv .env
+ source .env/bin/activate
+ pip install -U pip
+ pip install pytest maturin black
+ python -m maturin develop -r
+
+ - name: Check style
+ working-directory: ./candle-pyo3
+ run: |
+ source .env/bin/activate
+ python stub.py --check
+ black --check .
+
+ - name: Run tests
+ working-directory: ./candle-pyo3
+ run: |
+ source .env/bin/activate
+ python -m pytest -s -v tests \ No newline at end of file
diff --git a/candle-pyo3/py_src/candle/models/bert.py b/candle-pyo3/py_src/candle/models/bert.py
index 36e242ad..ecb238d8 100644
--- a/candle-pyo3/py_src/candle/models/bert.py
+++ b/candle-pyo3/py_src/candle/models/bert.py
@@ -59,8 +59,7 @@ class BertSelfAttention(Module):
attention_scores = attention_scores / float(self.attention_head_size) ** 0.5
if attention_mask is not None:
b_size, _, _, last_dim = attention_scores.shape
- attention_scores = attention_scores.broadcast_add(
- attention_mask.reshape((b_size, 1, 1, last_dim)))
+ attention_scores = attention_scores.broadcast_add(attention_mask.reshape((b_size, 1, 1, last_dim)))
attention_probs = F.softmax(attention_scores, dim=-1)
context_layer = attention_probs.matmul(value)
@@ -198,7 +197,9 @@ class BertModel(Module):
self.encoder = BertEncoder(config)
self.pooler = BertPooler(config) if add_pooling_layer else None
- def forward(self, input_ids: Tensor, token_type_ids: Tensor, attention_mask=None) -> Tuple[Tensor, Optional[Tensor]]:
+ def forward(
+ self, input_ids: Tensor, token_type_ids: Tensor, attention_mask=None
+ ) -> Tuple[Tensor, Optional[Tensor]]:
if attention_mask is not None:
# Replace 0s with -inf, and 1s with 0s.
attention_mask = masked_fill(float("-inf"), attention_mask, 1.0)