path: root/candle-wasm-examples/bert
author    Zheng Li <875543533@qq.com>   2024-08-01 14:26:19 +0800
committer GitHub <noreply@github.com>   2024-08-01 08:26:19 +0200
commit    4a52aeb4372ae42d469b0c477be12d7d79c28bf5 (patch)
tree      42f17fd26419f6b45b6df82bfb68797e3a07f0f9 /candle-wasm-examples/bert
parent    24d54d0ff90ecc701f4a41770482a2611da05d15 (diff)
bert attention mask (#1934)
* bert attention mask

* Allow for using None as a mask.

* Revert part of the changes so that the proper default mask applies.

* Cosmetic change.

* Another cosmetic tweak.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-wasm-examples/bert')
-rw-r--r--  candle-wasm-examples/bert/src/bin/m.rs | 12 +++++++++++-
1 file changed, 11 insertions(+), 1 deletion(-)
diff --git a/candle-wasm-examples/bert/src/bin/m.rs b/candle-wasm-examples/bert/src/bin/m.rs
index 92617f15..9e5cf913 100644
--- a/candle-wasm-examples/bert/src/bin/m.rs
+++ b/candle-wasm-examples/bert/src/bin/m.rs
@@ -55,11 +55,21 @@ impl Model {
                 Tensor::new(tokens.as_slice(), device)
             })
             .collect::<Result<Vec<_>, _>>()?;
+        let attention_mask: Vec<Tensor> = tokens
+            .iter()
+            .map(|tokens| {
+                let tokens = tokens.get_attention_mask().to_vec();
+                Tensor::new(tokens.as_slice(), device)
+            })
+            .collect::<Result<Vec<_>, _>>()?;
         let token_ids = Tensor::stack(&token_ids, 0)?;
+        let attention_mask = Tensor::stack(&attention_mask, 0)?;
         let token_type_ids = token_ids.zeros_like()?;
         console_log!("running inference on batch {:?}", token_ids.shape());
-        let embeddings = self.bert.forward(&token_ids, &token_type_ids)?;
+        let embeddings = self
+            .bert
+            .forward(&token_ids, &token_type_ids, Some(&attention_mask))?;
         console_log!("generated embeddings {:?}", embeddings.shape());
         // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
         let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
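
For context, below is a minimal sketch of how a caller outside the wasm example might build the batched attention mask and invoke the post-patch forward signature. The `embed_batch` helper and its setup are hypothetical illustrations; only the three-argument `BertModel::forward(input_ids, token_type_ids, attention_mask)` shape comes from this patch, and the encodings are assumed to be padded to a common length upstream.

```rust
use candle_core::{Device, Result, Tensor};
use candle_transformers::models::bert::BertModel;
use tokenizers::Encoding;

// Hypothetical helper: embed a batch of already-padded encodings using the
// mask-aware forward signature. Loading config/weights into `BertModel`
// is elided.
fn embed_batch(model: &BertModel, encodings: &[Encoding], device: &Device) -> Result<Tensor> {
    // One row of token ids per sentence.
    let token_ids = encodings
        .iter()
        .map(|e| Tensor::new(e.get_ids(), device))
        .collect::<Result<Vec<_>>>()?;
    // Matching 0/1 rows: 1 for real tokens, 0 for padding.
    let attention_mask = encodings
        .iter()
        .map(|e| Tensor::new(e.get_attention_mask(), device))
        .collect::<Result<Vec<_>>>()?;
    let token_ids = Tensor::stack(&token_ids, 0)?;
    let attention_mask = Tensor::stack(&attention_mask, 0)?;
    let token_type_ids = token_ids.zeros_like()?;
    // Passing `None` instead would fall back to the model's default
    // all-ones mask, per the "Allow for using None as a mask" change.
    model.forward(&token_ids, &token_type_ids, Some(&attention_mask))
}
```

Supplying the real padding mask here matters because, as the diff's pooling comment notes, the example averages over all token positions including padding; masking at the attention level keeps padded positions from influencing the contextual embeddings themselves.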