diff options
Diffstat (limited to 'candle-examples/examples/bert/main.rs')
-rw-r--r-- | candle-examples/examples/bert/main.rs | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs index 1c3c429b..d7df5ae3 100644 --- a/candle-examples/examples/bert/main.rs +++ b/candle-examples/examples/bert/main.rs @@ -604,16 +604,16 @@ fn main() -> Result<()> { println!("generated embeddings {:?}", embeddings.shape()); // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding) let (_n_sentence, n_tokens, _hidden_size) = embeddings.shape().r3()?; - let embeddings = (embeddings.sum(&[1])? / (n_tokens as f64))?.squeeze(1)?; + let embeddings = (embeddings.sum_keepdim(&[1])? / (n_tokens as f64))?.squeeze(1)?; println!("pooled embeddings {:?}", embeddings.shape()); let mut similarities = vec![]; for i in 0..n_sentences { let e_i = embeddings.get(i)?; for j in (i + 1)..n_sentences { let e_j = embeddings.get(j)?; - let sum_ij = (&e_i * &e_j)?.sum_all()?.reshape(())?.to_scalar::<f32>()?; - let sum_i2 = (&e_i * &e_i)?.sum_all()?.reshape(())?.to_scalar::<f32>()?; - let sum_j2 = (&e_j * &e_j)?.sum_all()?.reshape(())?.to_scalar::<f32>()?; + let sum_ij = (&e_i * &e_j)?.sum_all()?.to_scalar::<f32>()?; + let sum_i2 = (&e_i * &e_i)?.sum_all()?.to_scalar::<f32>()?; + let sum_j2 = (&e_j * &e_j)?.sum_all()?.to_scalar::<f32>()?; let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt(); similarities.push((cosine_similarity, i, j)) } |