author    | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-13 21:32:32 +0100
committer | GitHub <noreply@github.com>               | 2023-07-13 21:32:32 +0100
commit    | 2bfa791336b320b96d392aba83cbd4cee87173e3 (patch)
tree      | a3127719a64cf5cfbf38f5f8be859afd2dc6118e
parent    | 57be3638d8c10304629f6859d183fb192858f3a3 (diff)
download  | candle-2bfa791336b320b96d392aba83cbd4cee87173e3.tar.gz
          | candle-2bfa791336b320b96d392aba83cbd4cee87173e3.tar.bz2
          | candle-2bfa791336b320b96d392aba83cbd4cee87173e3.zip
Use the same default as pytorch for sum. (#164)
-rw-r--r-- | candle-core/examples/cuda_basics.rs           |  4
-rw-r--r-- | candle-core/examples/cuda_sum_benchmark.rs    | 16
-rw-r--r-- | candle-core/src/backprop.rs                   |  6
-rw-r--r-- | candle-core/src/tensor.rs                     | 40
-rw-r--r-- | candle-core/tests/grad_tests.rs               |  4
-rw-r--r-- | candle-core/tests/tensor_tests.rs             | 81
-rw-r--r-- | candle-examples/examples/bert/main.rs         |  8
-rw-r--r-- | candle-examples/examples/llama/model.rs       |  2
-rw-r--r-- | candle-examples/examples/musicgen/nn.rs       |  2
-rw-r--r-- | candle-examples/examples/musicgen/t5_model.rs |  2
-rw-r--r-- | candle-nn/src/layer_norm.rs                   |  4
-rw-r--r-- | candle-nn/tests/layer_norm.rs                 |  4
-rw-r--r-- | candle-pyo3/src/lib.rs                        |  6
13 files changed, 123 insertions, 56 deletions
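In short: the previous `sum` kept the reduced dimensions with size 1; this commit renames that behaviour to `sum_keepdim` and makes `sum` squeeze the reduced dimensions, matching PyTorch's `keepdim=False` default. Below is a minimal before/after sketch; the `sum_keepdim` values come from the doc-test updated in this diff, and the rank-1 result of the new `sum` follows from the squeezing implementation added in candle-core/src/tensor.rs.

    use candle::{Device, Tensor};

    fn main() -> Result<(), candle::Error> {
        let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;

        // Old behaviour, now named `sum_keepdim`: the summed dimension is kept with size 1.
        let s = a.sum_keepdim(&[0])?;
        assert_eq!(s.to_vec2::<f32>()?, &[[2., 4.]]); // shape (1, 2)

        // New `sum`: the summed dimension is squeezed away, like PyTorch's keepdim=False default.
        let s = a.sum(&[0])?;
        assert_eq!(s.to_vec1::<f32>()?, &[2., 4.]); // shape (2,)

        Ok(())
    }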
diff --git a/candle-core/examples/cuda_basics.rs b/candle-core/examples/cuda_basics.rs
index aeee541a..37a66cb5 100644
--- a/candle-core/examples/cuda_basics.rs
+++ b/candle-core/examples/cuda_basics.rs
@@ -7,9 +7,9 @@ use candle::{Device, Tensor};
 fn main() -> Result<()> {
     let device = Device::new_cuda(0)?;
     let t = Tensor::new(&[[1f32, 2., 3., 4.2]], &device)?;
-    let sum = t.sum(&[0])?;
+    let sum = t.sum_keepdim(&[0])?;
     println!("{sum}");
-    let sum = t.sum(&[1])?;
+    let sum = t.sum_keepdim(&[1])?;
     println!("{sum}");
     Ok(())
 }
diff --git a/candle-core/examples/cuda_sum_benchmark.rs b/candle-core/examples/cuda_sum_benchmark.rs
index 09d0099d..86a1691d 100644
--- a/candle-core/examples/cuda_sum_benchmark.rs
+++ b/candle-core/examples/cuda_sum_benchmark.rs
@@ -27,18 +27,18 @@ fn main() -> Result<()> {
     let xys_cpu = cos_sin(n, &Device::Cpu)?;
     let xys = cos_sin(n, &device)?;
     println!("{xys_cpu:?} {xys:?}");
-    let sum_cpu = xys_cpu.sum(&[1])?;
-    println!("{sum_cpu}");
-    let sum = xys.sum(&[1])?;
-    println!("{sum}");
+    let sum_keepdim_cpu = xys_cpu.sum_keepdim(&[1])?;
+    println!("{sum_keepdim_cpu}");
+    let sum_keepdim = xys.sum_keepdim(&[1])?;
+    println!("{sum_keepdim}");
     let start = std::time::Instant::now();
     let n_iters = 100;
     let mut v = 0f32;
     for _i in 0..n_iters {
-        let sum = xys.sum(&[1])?;
-        let sum = sum.sum(&[0])?;
-        let sum: f32 = sum.reshape(&[])?.to_scalar()?;
-        v += sum;
+        let sum_keepdim = xys.sum_keepdim(&[1])?;
+        let sum_keepdim = sum_keepdim.sum_keepdim(&[0])?;
+        let sum_keepdim: f32 = sum_keepdim.reshape(&[])?.to_scalar()?;
+        v += sum_keepdim;
     }
     let elapsed = start.elapsed();
     if v > 0. {
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 2711da85..c72f603f 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -195,11 +195,7 @@ impl Tensor {
                     }
                 }
-                let mut arg_grad = grad.sum(sum_dims.as_slice())?;
-                // sum_dims has increasing values.
-                for &dim in sum_dims.iter().rev() {
-                    arg_grad = arg_grad.squeeze(dim)?
-                }
+                let arg_grad = grad.sum(sum_dims.as_slice())?;
                 let sum_grad = grads.or_insert(arg)?;
                 *sum_grad = sum_grad.broadcast_add(&arg_grad)?
             }
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 481a6851..af3675cc 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -572,7 +572,7 @@ impl Tensor {
             // We do not have a cuda kernel for divide_by_sum_over_dim so split
             // the operation.
             let exp = self.exp()?;
-            let sum_exp = exp.sum(&[dim])?;
+            let sum_exp = exp.sum_keepdim(&[dim])?;
             exp.broadcast_div(&sum_exp)
         } else {
             let shape = self.shape();
@@ -591,21 +591,21 @@ impl Tensor {
     /// Returns the sum of all elements in the input tensor. The sum is performed over all the
     /// input dimensions.
     ///
-    /// The resulting tensor as a shape that is similar to the shape of the input tensor, except
+    /// The resulting tensor has a shape that is similar to the shape of the input tensor, except
     /// that the number of elements for each dimension index in `sum_dims` is 1.
     ///
     /// ```rust
     /// use candle::{Tensor, Device};
     /// let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
-    /// let s = a.sum(&[0])?;
+    /// let s = a.sum_keepdim(&[0])?;
     /// assert_eq!(s.to_vec2::<f32>()?, &[[2., 4.]]);
-    /// let s = a.sum(&[1])?;
+    /// let s = a.sum_keepdim(&[1])?;
     /// assert_eq!(s.to_vec2::<f32>()?, &[[1.], [5.]]);
-    /// let s = a.sum(&[0, 1])?;
+    /// let s = a.sum_keepdim(&[0, 1])?;
     /// assert_eq!(s.to_vec2::<f32>()?, &[[6.]]);
     /// # Ok::<(), candle::Error>(())
     /// ```
-    pub fn sum(&self, sum_dims: &[usize]) -> Result<Self> {
+    pub fn sum_keepdim(&self, sum_dims: &[usize]) -> Result<Self> {
         for &dim in sum_dims {
             self.check_dim(dim, "sum")?;
         }
@@ -622,6 +622,32 @@ impl Tensor {
         Ok(from_storage(storage, dims, op, false))
     }

+    /// Returns the sum of all elements in the input tensor. The sum is performed over all the
+    /// input dimensions and compared to `sum_keepdim` these dimensions are squeezed rather than
+    /// kept.
+    pub fn sum(&self, sum_dims: &[usize]) -> Result<Self> {
+        let sum = self.sum_keepdim(sum_dims)?;
+        match sum_dims {
+            [] => Ok(sum),
+            [i] => sum.squeeze(*i),
+            sum_dims => {
+                let dims = sum
+                    .dims()
+                    .iter()
+                    .enumerate()
+                    .filter_map(|(dim_idx, &v)| {
+                        if sum_dims.contains(&dim_idx) {
+                            None
+                        } else {
+                            Some(v)
+                        }
+                    })
+                    .collect::<Vec<_>>();
+                sum.reshape(dims)
+            }
+        }
+    }
+
     /// Applies a 1D convolution over the input tensor.
     pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
         let (c_out, c_in_k, k_size) = kernel.shape().r3()?;
@@ -936,7 +962,7 @@ impl Tensor {
     /// ```
     pub fn sum_all(&self) -> Result<Tensor> {
         let dims: Vec<_> = (0..self.rank()).collect();
-        self.sum(&dims)?.reshape(())
+        self.sum_keepdim(&dims)?.reshape(())
     }

     fn flatten_<D1: Dim, D2: Dim>(
diff --git a/candle-core/tests/grad_tests.rs b/candle-core/tests/grad_tests.rs
index 501b0d44..241efc69 100644
--- a/candle-core/tests/grad_tests.rs
+++ b/candle-core/tests/grad_tests.rs
@@ -19,7 +19,7 @@ fn simple_grad(device: &Device) -> Result<()> {
 fn sum_grad(device: &Device) -> Result<()> {
     let x = Var::new(&[3f32, 1., 4.], device)?;
     let x = x.as_tensor();
-    let y = (x.sqr()?.sum(&[0])? * 2.)?;
+    let y = (x.sqr()?.sum_keepdim(&[0])? * 2.)?;
     let grads = y.backward()?;
     let grad_x = grads.get(x).context("no grad for x")?;
     assert_eq!(y.to_vec1::<f32>()?, [52.]);
@@ -27,7 +27,7 @@ fn sum_grad(device: &Device) -> Result<()> {
     assert_eq!(grad_x.to_vec1::<f32>()?, &[12., 4., 16.]);

     // Same test as before but squeezing on the last dimension.
-    let y = (x.sqr()?.sum(&[0])? * 2.)?.squeeze(0)?;
+    let y = (x.sqr()?.sum_keepdim(&[0])? * 2.)?.squeeze(0)?;
     let grads = y.backward()?;
     let grad_x = grads.get(x).context("no grad for x")?;
     assert_eq!(y.to_scalar::<f32>()?, 52.);
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index 7e4467d1..b9e8a982 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -108,56 +108,99 @@ fn sum(device: &Device) -> Result<()> {
     let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
     let tensor = Tensor::new(data, device)?;
     assert_eq!(
-        tensor.sum(&[2])?.to_vec3::<u32>()?,
+        tensor.sum_keepdim(&[2])?.to_vec3::<u32>()?,
         &[[[8], [15]], [[10], [18]]]
     );
     assert_eq!(
-        tensor.sum(&[0])?.to_vec3::<u32>()?,
+        tensor.sum_keepdim(&[0])?.to_vec3::<u32>()?,
         &[[[5, 2, 11], [9, 7, 17]]],
     );
-    assert_eq!(tensor.sum(&[0, 2, 1])?.to_vec3::<u32>()?, &[[[51]]],);
+    assert_eq!(tensor.sum_keepdim(&[0, 2, 1])?.to_vec3::<u32>()?, &[[[51]]],);
     assert_eq!(
-        tensor.t()?.sum(&[1])?.t()?.to_vec3::<u32>()?,
+        tensor.t()?.sum_keepdim(&[1])?.t()?.to_vec3::<u32>()?,
         &[[[8], [15]], [[10], [18]]]
     );
     assert_eq!(
-        tensor.sum(&[2, 1])?.to_vec3::<u32>()?,
+        tensor.sum_keepdim(&[2, 1])?.to_vec3::<u32>()?,
         &[[[8 + 15]], [[10 + 18]]]
     );
     let data: Vec<u32> = (0..4000u32).collect();
     let tensor = Tensor::new(data.as_slice(), device)?;
-    assert_eq!(tensor.sum(&[0])?.to_vec1::<u32>()?, &[7998000]);
+    assert_eq!(tensor.sum_keepdim(&[0])?.to_vec1::<u32>()?, &[7998000]);
     let tensor = tensor.reshape((2000, 2))?;
-    assert_eq!(tensor.sum(&[0, 1])?.to_vec2::<u32>()?, &[[7998000]]);
-    assert_eq!(tensor.sum(&[0])?.sum(&[1])?.to_vec2::<u32>()?, &[[7998000]]);
-    assert_eq!(tensor.sum(&[1])?.sum(&[0])?.to_vec2::<u32>()?, &[[7998000]]);
-    assert_eq!(tensor.sum(&[0])?.to_vec2::<u32>()?, &[[3998000, 4000000]]);
+    assert_eq!(tensor.sum_keepdim(&[0, 1])?.to_vec2::<u32>()?, &[[7998000]]);
+    assert_eq!(
+        tensor
+            .sum_keepdim(&[0])?
+            .sum_keepdim(&[1])?
+            .to_vec2::<u32>()?,
+        &[[7998000]]
+    );
+    assert_eq!(
+        tensor
+            .sum_keepdim(&[1])?
+            .sum_keepdim(&[0])?
+            .to_vec2::<u32>()?,
+        &[[7998000]]
+    );
+    assert_eq!(
+        tensor.sum_keepdim(&[0])?.to_vec2::<u32>()?,
+        &[[3998000, 4000000]]
+    );
     // Make the tensor non contiguous.
     let tensor = tensor.t()?.contiguous()?.t()?;
-    assert_eq!(tensor.sum(&[0, 1])?.to_vec2::<u32>()?, &[[7998000]]);
-    assert_eq!(tensor.sum(&[0])?.sum(&[1])?.to_vec2::<u32>()?, &[[7998000]]);
-    assert_eq!(tensor.sum(&[1])?.sum(&[0])?.to_vec2::<u32>()?, &[[7998000]]);
-    assert_eq!(tensor.sum(&[0])?.to_vec2::<u32>()?, &[[3998000, 4000000]]);
+    assert_eq!(tensor.sum_keepdim(&[0, 1])?.to_vec2::<u32>()?, &[[7998000]]);
+    assert_eq!(
+        tensor
+            .sum_keepdim(&[0])?
+            .sum_keepdim(&[1])?
+            .to_vec2::<u32>()?,
+        &[[7998000]]
+    );
+    assert_eq!(
+        tensor
+            .sum_keepdim(&[1])?
+            .sum_keepdim(&[0])?
+            .to_vec2::<u32>()?,
+        &[[7998000]]
+    );
+    assert_eq!(
+        tensor.sum_keepdim(&[0])?.to_vec2::<u32>()?,
+        &[[3998000, 4000000]]
+    );
     let t1 = tensor.reshape((200, 5, 4))?;
     let t2 = t1.transpose(0, 2)?.contiguous()?.transpose(0, 2)?;
     for tensor in [t1, t2] {
-        assert_eq!(tensor.sum(&[0, 1, 2])?.to_vec3::<u32>()?, &[[[7998000]]]);
         assert_eq!(
-            tensor.sum(&[0])?.sum(&[2])?.sum(&[1])?.to_vec3::<u32>()?,
+            tensor.sum_keepdim(&[0, 1, 2])?.to_vec3::<u32>()?,
+            &[[[7998000]]]
+        );
+        assert_eq!(
+            tensor
+                .sum_keepdim(&[0])?
+                .sum_keepdim(&[2])?
+                .sum_keepdim(&[1])?
+                .to_vec3::<u32>()?,
             &[[[7998000]]]
         );
         assert_eq!(
-            tensor.sum(&[0])?.sum(&[1, 2])?.to_vec3::<u32>()?,
+            tensor
+                .sum_keepdim(&[0])?
+                .sum_keepdim(&[1, 2])?
+                .to_vec3::<u32>()?,
             &[[[7998000]]]
         );
         assert_eq!(
-            tensor.sum(&[1])?.sum(&[0, 2])?.to_vec3::<u32>()?,
+            tensor
+                .sum_keepdim(&[1])?
+                .sum_keepdim(&[0, 2])?
+                .to_vec3::<u32>()?,
             &[[[7998000]]]
         );
         assert_eq!(
-            tensor.sum(&[0])?.to_vec3::<u32>()?,
+            tensor.sum_keepdim(&[0])?.to_vec3::<u32>()?,
             &[[
                 [398000, 398200, 398400, 398600],
                 [398800, 399000, 399200, 399400],
diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs
index 1c3c429b..d7df5ae3 100644
--- a/candle-examples/examples/bert/main.rs
+++ b/candle-examples/examples/bert/main.rs
@@ -604,16 +604,16 @@ fn main() -> Result<()> {
     println!("generated embeddings {:?}", embeddings.shape());
     // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
     let (_n_sentence, n_tokens, _hidden_size) = embeddings.shape().r3()?;
-    let embeddings = (embeddings.sum(&[1])? / (n_tokens as f64))?.squeeze(1)?;
+    let embeddings = (embeddings.sum_keepdim(&[1])? / (n_tokens as f64))?.squeeze(1)?;
     println!("pooled embeddings {:?}", embeddings.shape());
     let mut similarities = vec![];
     for i in 0..n_sentences {
         let e_i = embeddings.get(i)?;
         for j in (i + 1)..n_sentences {
             let e_j = embeddings.get(j)?;
-            let sum_ij = (&e_i * &e_j)?.sum_all()?.reshape(())?.to_scalar::<f32>()?;
-            let sum_i2 = (&e_i * &e_i)?.sum_all()?.reshape(())?.to_scalar::<f32>()?;
-            let sum_j2 = (&e_j * &e_j)?.sum_all()?.reshape(())?.to_scalar::<f32>()?;
+            let sum_ij = (&e_i * &e_j)?.sum_all()?.to_scalar::<f32>()?;
+            let sum_i2 = (&e_i * &e_i)?.sum_all()?.to_scalar::<f32>()?;
+            let sum_j2 = (&e_j * &e_j)?.sum_all()?.to_scalar::<f32>()?;
             let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt();
             similarities.push((cosine_similarity, i, j))
         }
diff --git a/candle-examples/examples/llama/model.rs b/candle-examples/examples/llama/model.rs
index 04397d1e..57f339b0 100644
--- a/candle-examples/examples/llama/model.rs
+++ b/candle-examples/examples/llama/model.rs
@@ -95,7 +95,7 @@ impl RmsNorm {
         // This is a no-op if x's dtype is already f32.
         let x = x.to_dtype(DType::F32)?;
         let (b_sz, seq_len, hidden_size) = x.shape().r3()?;
-        let norm_x = ((&x * &x)?.sum(&[2])? / hidden_size as f64)?;
+        let norm_x = (x.sqr()?.sum_keepdim(&[2])? / hidden_size as f64)?;
         let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
         let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
         let size = self.scale.shape().r1()?;
diff --git a/candle-examples/examples/musicgen/nn.rs b/candle-examples/examples/musicgen/nn.rs
index 5c90dd4e..652c47a7 100644
--- a/candle-examples/examples/musicgen/nn.rs
+++ b/candle-examples/examples/musicgen/nn.rs
@@ -70,7 +70,7 @@ pub fn conv1d_weight_norm(
 ) -> Result<Conv1d> {
     let weight_g = vb.get((out_c, 1, 1), "weight_g")?;
     let weight_v = vb.get((out_c, in_c, kernel_size), "weight_v")?;
-    let norm_v = (&weight_v * &weight_v)?.sum(&[1, 2])?.sqrt()?;
+    let norm_v = weight_v.sqr()?.sum_keepdim(&[1, 2])?.sqrt()?;
     let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
     let bias = vb.get(out_c, "bias")?;
     Ok(Conv1d::new(weight, Some(bias), config))
diff --git a/candle-examples/examples/musicgen/t5_model.rs b/candle-examples/examples/musicgen/t5_model.rs
index 0444f360..2119cf9b 100644
--- a/candle-examples/examples/musicgen/t5_model.rs
+++ b/candle-examples/examples/musicgen/t5_model.rs
@@ -98,7 +98,7 @@ impl T5LayerNorm {
         let dtype = xs.dtype();
         let xs_f32 = xs.to_dtype(DType::F32)?;
         let xs2_f32 = (&xs_f32 * &xs_f32)?;
-        let sum_xs2_f32 = xs2_f32.sum(&[xs.rank() - 1])?;
+        let sum_xs2_f32 = xs2_f32.sum_keepdim(&[xs.rank() - 1])?;
         let variance = xs2_f32.broadcast_div(&sum_xs2_f32)?;
         let xs = (xs / (variance + self.variance_epsilon)?.sqrt()?)?;
         let xs = xs.to_dtype(dtype)?;
diff --git a/candle-nn/src/layer_norm.rs b/candle-nn/src/layer_norm.rs
index 188a02bf..06f984f2 100644
--- a/candle-nn/src/layer_norm.rs
+++ b/candle-nn/src/layer_norm.rs
@@ -51,9 +51,9 @@ impl LayerNorm {
         };
         let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
         let x = x.to_dtype(internal_dtype)?;
-        let mean_x = (x.sum(&[2])? / hidden_size as f64)?;
+        let mean_x = (x.sum_keepdim(&[2])? / hidden_size as f64)?;
         let x = x.broadcast_sub(&mean_x)?;
-        let norm_x = ((&x * &x)?.sum(&[2])? / hidden_size as f64)?;
+        let norm_x = (x.sqr()?.sum_keepdim(&[2])? / hidden_size as f64)?;
         let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
         let x = x_normed
             .to_dtype(x_dtype)?
diff --git a/candle-nn/tests/layer_norm.rs b/candle-nn/tests/layer_norm.rs
index e4b962b4..fda7dbad 100644
--- a/candle-nn/tests/layer_norm.rs
+++ b/candle-nn/tests/layer_norm.rs
@@ -30,10 +30,10 @@ fn layer_norm() -> Result<()> {
             [4.1742344, 0.5, -3.1742344]
         ]]
     );
-    let mean = (res.sum(&[2])? / 3.0)?;
+    let mean = (res.sum_keepdim(&[2])? / 3.0)?;
     // The average value should be `b`.
     assert_eq!(mean.to_vec3::<f32>()?, [[[0.5], [0.5], [0.5]]]);
-    let std = (res.broadcast_sub(&mean)?.sqr()?.sum(&[2])?.sqrt()? / 3.0)?;
+    let std = (res.broadcast_sub(&mean)?.sqr()?.sum_keepdim(&[2])?.sqrt()? / 3.0)?;
     // The standard deviation should be sqrt(`w`).
     assert_eq!(
         std.to_vec3::<f32>()?,
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index 7cd361e4..136f8a4f 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -312,9 +312,11 @@ impl PyTensor {
         Ok(PyTensor(self.0.narrow(dim, start, len).map_err(wrap_err)?))
     }

-    fn sum(&self, dims: Vec<usize>) -> PyResult<Self> {
+    fn sum_keepdim(&self, dims: Vec<usize>) -> PyResult<Self> {
         // TODO: Support a single dim as input?
-        Ok(PyTensor(self.0.sum(dims.as_slice()).map_err(wrap_err)?))
+        Ok(PyTensor(
+            self.0.sum_keepdim(dims.as_slice()).map_err(wrap_err)?,
+        ))
     }

     fn sum_all(&self) -> PyResult<Self> {
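A side note on the bert example above: `sum_all` reduces over every dimension and reshapes to a scalar shape, so the extra `.reshape(())?` before `.to_scalar()` was redundant and is dropped in this commit. A small sketch of the resulting pattern, lifted from the cosine-similarity loop in that example (the free-standing helper function is illustrative, not part of the commit):

    use candle::Tensor;

    // Illustrative helper mirroring the updated bert example: `sum_all`
    // returns a scalar-shaped tensor, so `to_scalar` can be called directly.
    fn cosine_similarity(e_i: Tensor, e_j: Tensor) -> Result<f32, candle::Error> {
        let sum_ij = (&e_i * &e_j)?.sum_all()?.to_scalar::<f32>()?;
        let sum_i2 = (&e_i * &e_i)?.sum_all()?.to_scalar::<f32>()?;
        let sum_j2 = (&e_j * &e_j)?.sum_all()?.to_scalar::<f32>()?;
        Ok(sum_ij / (sum_i2 * sum_j2).sqrt())
    }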