summaryrefslogtreecommitdiff
path: root/candle-pyo3/e5.py
diff options
context:
space:
mode:
Diffstat (limited to 'candle-pyo3/e5.py')
-rw-r--r--candle-pyo3/e5.py3
1 files changed, 2 insertions, 1 deletions
diff --git a/candle-pyo3/e5.py b/candle-pyo3/e5.py
index 8ca48219..9b54ebcb 100644
--- a/candle-pyo3/e5.py
+++ b/candle-pyo3/e5.py
@@ -50,7 +50,8 @@ if __name__ == "__main__":
tokenized = tokenizer(sentences, padding=True)
tokens = Tensor(tokenized["input_ids"])
token_type_ids = Tensor(tokenized["token_type_ids"])
- encoder_out, _ = model.forward(tokens, token_type_ids)
+ attention_mask = Tensor(tokenized["attention_mask"])
+ encoder_out, _ = model.forward(tokens, token_type_ids, attention_mask=attention_mask)
hf_tokenized = tokenizer(sentences, padding=True, return_tensors="pt")
hf_result = hf_model(**hf_tokenized)["last_hidden_state"]