diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-10 14:05:55 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-10 13:05:55 +0100 |
commit | 385f0d261c9199b83373fc10b7b52a1fef67a9d0 (patch) | |
tree | 283c4fe3c8e5950d38f769d3bbc81012d7f84151 | |
parent | b765f2c37fe3e8cde2a041c162cdcc2e9fbb5033 (diff) | |
download | candle-385f0d261c9199b83373fc10b7b52a1fef67a9d0.tar.gz candle-385f0d261c9199b83373fc10b7b52a1fef67a9d0.tar.bz2 candle-385f0d261c9199b83373fc10b7b52a1fef67a9d0.zip |
Normalize embeddings in the bert example. (#390)
-rw-r--r-- | candle-examples/examples/bert/main.rs | 14 |
1 files changed, 14 insertions, 0 deletions
diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs index 79c78968..574755ed 100644 --- a/candle-examples/examples/bert/main.rs +++ b/candle-examples/examples/bert/main.rs @@ -39,6 +39,10 @@ struct Args { /// The number of times to run the prompt. #[arg(long, default_value = "1")] n: usize, + + /// L2 normalization for embeddings. + #[arg(long, default_value = "true")] + normalize_embeddings: bool, } impl Args { @@ -164,7 +168,13 @@ fn main() -> Result<()> { // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding) let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?; let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?; + let embeddings = if args.normalize_embeddings { + normalize_l2(&embeddings)? + } else { + embeddings + }; println!("pooled embeddings {:?}", embeddings.shape()); + let mut similarities = vec![]; for i in 0..n_sentences { let e_i = embeddings.get(i)?; @@ -184,3 +194,7 @@ fn main() -> Result<()> { } Ok(()) } + +pub fn normalize_l2(v: &Tensor) -> Result<Tensor> { + Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?) +} |