summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-08-10 14:05:55 +0200
committerGitHub <noreply@github.com>2023-08-10 13:05:55 +0100
commit385f0d261c9199b83373fc10b7b52a1fef67a9d0 (patch)
tree283c4fe3c8e5950d38f769d3bbc81012d7f84151
parentb765f2c37fe3e8cde2a041c162cdcc2e9fbb5033 (diff)
downloadcandle-385f0d261c9199b83373fc10b7b52a1fef67a9d0.tar.gz
candle-385f0d261c9199b83373fc10b7b52a1fef67a9d0.tar.bz2
candle-385f0d261c9199b83373fc10b7b52a1fef67a9d0.zip
Normalize embeddings in the bert example. (#390)
-rw-r--r--candle-examples/examples/bert/main.rs14
1 files changed, 14 insertions, 0 deletions
diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs
index 79c78968..574755ed 100644
--- a/candle-examples/examples/bert/main.rs
+++ b/candle-examples/examples/bert/main.rs
@@ -39,6 +39,10 @@ struct Args {
/// The number of times to run the prompt.
#[arg(long, default_value = "1")]
n: usize,
+
+ /// L2 normalization for embeddings.
+ #[arg(long, default_value = "true")]
+ normalize_embeddings: bool,
}
impl Args {
@@ -164,7 +168,13 @@ fn main() -> Result<()> {
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
+ let embeddings = if args.normalize_embeddings {
+ normalize_l2(&embeddings)?
+ } else {
+ embeddings
+ };
println!("pooled embeddings {:?}", embeddings.shape());
+
let mut similarities = vec![];
for i in 0..n_sentences {
let e_i = embeddings.get(i)?;
@@ -184,3 +194,7 @@ fn main() -> Result<()> {
}
Ok(())
}
+
+pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
+ Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
+}