Diffstat (limited to 'candle-examples/examples')
-rw-r--r--  candle-examples/examples/bert/main.rs           | 8 ++++----
-rw-r--r--  candle-examples/examples/llama/model.rs         | 2 +-
-rw-r--r--  candle-examples/examples/musicgen/nn.rs         | 2 +-
-rw-r--r--  candle-examples/examples/musicgen/t5_model.rs   | 2 +-
4 files changed, 7 insertions(+), 7 deletions(-)
diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs
index 1c3c429b..d7df5ae3 100644
--- a/candle-examples/examples/bert/main.rs
+++ b/candle-examples/examples/bert/main.rs
@@ -604,16 +604,16 @@ fn main() -> Result<()> {
         println!("generated embeddings {:?}", embeddings.shape());
         // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
         let (_n_sentence, n_tokens, _hidden_size) = embeddings.shape().r3()?;
-        let embeddings = (embeddings.sum(&[1])? / (n_tokens as f64))?.squeeze(1)?;
+        let embeddings = (embeddings.sum_keepdim(&[1])? / (n_tokens as f64))?.squeeze(1)?;
         println!("pooled embeddings {:?}", embeddings.shape());
         let mut similarities = vec![];
         for i in 0..n_sentences {
             let e_i = embeddings.get(i)?;
             for j in (i + 1)..n_sentences {
                 let e_j = embeddings.get(j)?;
-                let sum_ij = (&e_i * &e_j)?.sum_all()?.reshape(())?.to_scalar::<f32>()?;
-                let sum_i2 = (&e_i * &e_i)?.sum_all()?.reshape(())?.to_scalar::<f32>()?;
-                let sum_j2 = (&e_j * &e_j)?.sum_all()?.reshape(())?.to_scalar::<f32>()?;
+                let sum_ij = (&e_i * &e_j)?.sum_all()?.to_scalar::<f32>()?;
+                let sum_i2 = (&e_i * &e_i)?.sum_all()?.to_scalar::<f32>()?;
+                let sum_j2 = (&e_j * &e_j)?.sum_all()?.to_scalar::<f32>()?;
                 let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt();
                 similarities.push((cosine_similarity, i, j))
             }
diff --git a/candle-examples/examples/llama/model.rs b/candle-examples/examples/llama/model.rs
index 04397d1e..57f339b0 100644
--- a/candle-examples/examples/llama/model.rs
+++ b/candle-examples/examples/llama/model.rs
@@ -95,7 +95,7 @@ impl RmsNorm {
         // This is a no-op if x's dtype is already f32.
         let x = x.to_dtype(DType::F32)?;
         let (b_sz, seq_len, hidden_size) = x.shape().r3()?;
-        let norm_x = ((&x * &x)?.sum(&[2])? / hidden_size as f64)?;
+        let norm_x = (x.sqr()?.sum_keepdim(&[2])? / hidden_size as f64)?;
         let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
         let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
         let size = self.scale.shape().r1()?;
diff --git a/candle-examples/examples/musicgen/nn.rs b/candle-examples/examples/musicgen/nn.rs
index 5c90dd4e..652c47a7 100644
--- a/candle-examples/examples/musicgen/nn.rs
+++ b/candle-examples/examples/musicgen/nn.rs
@@ -70,7 +70,7 @@ pub fn conv1d_weight_norm(
 ) -> Result<Conv1d> {
     let weight_g = vb.get((out_c, 1, 1), "weight_g")?;
     let weight_v = vb.get((out_c, in_c, kernel_size), "weight_v")?;
-    let norm_v = (&weight_v * &weight_v)?.sum(&[1, 2])?.sqrt()?;
+    let norm_v = weight_v.sqr()?.sum_keepdim(&[1, 2])?.sqrt()?;
     let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
     let bias = vb.get(out_c, "bias")?;
     Ok(Conv1d::new(weight, Some(bias), config))
diff --git a/candle-examples/examples/musicgen/t5_model.rs b/candle-examples/examples/musicgen/t5_model.rs
index 0444f360..2119cf9b 100644
--- a/candle-examples/examples/musicgen/t5_model.rs
+++ b/candle-examples/examples/musicgen/t5_model.rs
@@ -98,7 +98,7 @@ impl T5LayerNorm {
         let dtype = xs.dtype();
         let xs_f32 = xs.to_dtype(DType::F32)?;
         let xs2_f32 = (&xs_f32 * &xs_f32)?;
-        let sum_xs2_f32 = xs2_f32.sum(&[xs.rank() - 1])?;
+        let sum_xs2_f32 = xs2_f32.sum_keepdim(&[xs.rank() - 1])?;
         let variance = xs2_f32.broadcast_div(&sum_xs2_f32)?;
         let xs = (xs / (variance + self.variance_epsilon)?.sqrt()?)?;
         let xs = xs.to_dtype(dtype)?;
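The common thread in these hunks is the move to the dimension-keeping reduction sum_keepdim, plus dropping the reshape(()) before to_scalar on a full sum_all reduction. The sketch below is not part of the patch; it only illustrates the shape behaviour the hunks rely on. It assumes the core crate is used as candle_core (inside the candle workspace the examples import it as candle), and it mirrors the slice-of-dims argument form used in the hunks above; other candle versions may accept the dims differently.

// Minimal sketch (not from this patch) of the keepdim / sum_all behaviour used above.
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let x = Tensor::ones((2, 3, 4), DType::F32, &Device::Cpu)?;

    // sum_keepdim keeps the reduced axis with size 1: (2, 3, 4) -> (2, 1, 4).
    // This is why the bert example can still squeeze(1) the pooled embeddings,
    // and why the RmsNorm / T5LayerNorm results broadcast back against the input.
    let pooled = (x.sum_keepdim(&[1])? / 3f64)?;
    println!("pooled shape: {:?}", pooled.shape()); // [2, 1, 4]

    // sum_all reduces to a rank-0 tensor, so the extra reshape(()) before
    // to_scalar::<f32>() in the cosine-similarity loop is no longer needed.
    let total = x.sum_all()?.to_scalar::<f32>()?;
    println!("total: {total}"); // 24
    Ok(())
}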