diff options
Diffstat (limited to 'candle-wasm-examples/bert')
-rw-r--r-- | candle-wasm-examples/bert/src/bin/m.rs | 12 |
1 files changed, 11 insertions, 1 deletions
diff --git a/candle-wasm-examples/bert/src/bin/m.rs b/candle-wasm-examples/bert/src/bin/m.rs
index 92617f15..9e5cf913 100644
--- a/candle-wasm-examples/bert/src/bin/m.rs
+++ b/candle-wasm-examples/bert/src/bin/m.rs
@@ -55,11 +55,21 @@ impl Model {
                 Tensor::new(tokens.as_slice(), device)
             })
             .collect::<Result<Vec<_>, _>>()?;
+        let attention_mask: Vec<Tensor> = tokens
+            .iter()
+            .map(|tokens| {
+                let tokens = tokens.get_attention_mask().to_vec();
+                Tensor::new(tokens.as_slice(), device)
+            })
+            .collect::<Result<Vec<_>, _>>()?;
 
         let token_ids = Tensor::stack(&token_ids, 0)?;
+        let attention_mask = Tensor::stack(&attention_mask, 0)?;
         let token_type_ids = token_ids.zeros_like()?;
         console_log!("running inference on batch {:?}", token_ids.shape());
-        let embeddings = self.bert.forward(&token_ids, &token_type_ids)?;
+        let embeddings = self
+            .bert
+            .forward(&token_ids, &token_type_ids, Some(&attention_mask))?;
         console_log!("generated embeddings {:?}", embeddings.shape());
         // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
         let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;