author    Laurent Mazare <laurent.mazare@gmail.com>  2023-07-13 21:32:32 +0100
committer GitHub <noreply@github.com>                2023-07-13 21:32:32 +0100
commit    2bfa791336b320b96d392aba83cbd4cee87173e3
tree      a3127719a64cf5cfbf38f5f8be859afd2dc6118e
parent    57be3638d8c10304629f6859d183fb192858f3a3
Use the same default as pytorch for sum. (#164)
 candle-core/examples/cuda_basics.rs            |  4
 candle-core/examples/cuda_sum_benchmark.rs     | 16
 candle-core/src/backprop.rs                    |  6
 candle-core/src/tensor.rs                      | 40
 candle-core/tests/grad_tests.rs                |  4
 candle-core/tests/tensor_tests.rs              | 81
 candle-examples/examples/bert/main.rs          |  8
 candle-examples/examples/llama/model.rs        |  2
 candle-examples/examples/musicgen/nn.rs        |  2
 candle-examples/examples/musicgen/t5_model.rs  |  2
 candle-nn/src/layer_norm.rs                    |  4
 candle-nn/tests/layer_norm.rs                  |  4
 candle-pyo3/src/lib.rs                         |  6
 13 files changed, 123 insertions(+), 56 deletions(-)
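
The behavioural change in a nutshell: `Tensor::sum` now squeezes the reduced dimensions, matching PyTorch's `keepdim=False` default, while the previous keep-dims behaviour remains available as `sum_keepdim`. A minimal sketch of the new semantics, assuming the `candle` crate API at this revision:

    use candle::{Device, Result, Tensor};

    fn main() -> Result<()> {
        let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
        // `sum` squeezes the reduced dimensions, like torch.sum with
        // the default keepdim=False: shape (2, 2) -> (2,).
        let s = a.sum(&[0])?;
        assert_eq!(s.to_vec1::<f32>()?, &[2., 4.]);
        // `sum_keepdim` keeps the reduced dimensions with size 1: (2, 2) -> (1, 2).
        let s = a.sum_keepdim(&[0])?;
        assert_eq!(s.to_vec2::<f32>()?, &[[2., 4.]]);
        Ok(())
    }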
diff --git a/candle-core/examples/cuda_basics.rs b/candle-core/examples/cuda_basics.rs
index aeee541a..37a66cb5 100644
--- a/candle-core/examples/cuda_basics.rs
+++ b/candle-core/examples/cuda_basics.rs
@@ -7,9 +7,9 @@ use candle::{Device, Tensor};
fn main() -> Result<()> {
let device = Device::new_cuda(0)?;
let t = Tensor::new(&[[1f32, 2., 3., 4.2]], &device)?;
- let sum = t.sum(&[0])?;
+ let sum = t.sum_keepdim(&[0])?;
println!("{sum}");
- let sum = t.sum(&[1])?;
+ let sum = t.sum_keepdim(&[1])?;
println!("{sum}");
Ok(())
}
diff --git a/candle-core/examples/cuda_sum_benchmark.rs b/candle-core/examples/cuda_sum_benchmark.rs
index 09d0099d..86a1691d 100644
--- a/candle-core/examples/cuda_sum_benchmark.rs
+++ b/candle-core/examples/cuda_sum_benchmark.rs
@@ -27,18 +27,18 @@ fn main() -> Result<()> {
let xys_cpu = cos_sin(n, &Device::Cpu)?;
let xys = cos_sin(n, &device)?;
println!("{xys_cpu:?} {xys:?}");
- let sum_cpu = xys_cpu.sum(&[1])?;
- println!("{sum_cpu}");
- let sum = xys.sum(&[1])?;
- println!("{sum}");
+ let sum_keepdim_cpu = xys_cpu.sum_keepdim(&[1])?;
+ println!("{sum_keepdim_cpu}");
+ let sum_keepdim = xys.sum_keepdim(&[1])?;
+ println!("{sum_keepdim}");
let start = std::time::Instant::now();
let n_iters = 100;
let mut v = 0f32;
for _i in 0..n_iters {
- let sum = xys.sum(&[1])?;
- let sum = sum.sum(&[0])?;
- let sum: f32 = sum.reshape(&[])?.to_scalar()?;
- v += sum;
+ let sum_keepdim = xys.sum_keepdim(&[1])?;
+ let sum_keepdim = sum_keepdim.sum_keepdim(&[0])?;
+ let sum_keepdim: f32 = sum_keepdim.reshape(&[])?.to_scalar()?;
+ v += sum_keepdim;
}
let elapsed = start.elapsed();
if v > 0. {
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 2711da85..c72f603f 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -195,11 +195,7 @@ impl Tensor {
}
}
- let mut arg_grad = grad.sum(sum_dims.as_slice())?;
- // sum_dims has increasing values.
- for &dim in sum_dims.iter().rev() {
- arg_grad = arg_grad.squeeze(dim)?
- }
+ let arg_grad = grad.sum(sum_dims.as_slice())?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.broadcast_add(&arg_grad)?
}
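
The broadcast backward pass above no longer needs the explicit `squeeze` loop: summing the incoming gradient over the broadcast dimensions now yields the squeezed shape directly. A small illustrative sketch (not part of the commit), assuming the API at this revision:

    use candle::{Device, Result, Tensor};

    fn main() -> Result<()> {
        // Incoming gradient with the broadcast shape (2, 3); the original
        // argument had shape (3,), so dimension 0 was introduced by broadcasting.
        let grad = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
        // With the new default, summing over the broadcast dimension already
        // drops it, so the result directly has the argument's shape (3,).
        let arg_grad = grad.sum(&[0])?;
        assert_eq!(arg_grad.to_vec1::<f32>()?, &[5., 7., 9.]);
        Ok(())
    }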
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 481a6851..af3675cc 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -572,7 +572,7 @@ impl Tensor {
// We do not have a cuda kernel for divide_by_sum_over_dim so split
// the operation.
let exp = self.exp()?;
- let sum_exp = exp.sum(&[dim])?;
+ let sum_exp = exp.sum_keepdim(&[dim])?;
exp.broadcast_div(&sum_exp)
} else {
let shape = self.shape();
@@ -591,21 +591,21 @@ impl Tensor {
/// Returns the sum of all elements in the input tensor. The sum is performed over all the
/// input dimensions.
///
- /// The resulting tensor as a shape that is similar to the shape of the input tensor, except
+ /// The resulting tensor has a shape that is similar to the shape of the input tensor, except
/// that the number of elements for each dimension index in `sum_dims` is 1.
///
/// ```rust
/// use candle::{Tensor, Device};
/// let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
- /// let s = a.sum(&[0])?;
+ /// let s = a.sum_keepdim(&[0])?;
/// assert_eq!(s.to_vec2::<f32>()?, &[[2., 4.]]);
- /// let s = a.sum(&[1])?;
+ /// let s = a.sum_keepdim(&[1])?;
/// assert_eq!(s.to_vec2::<f32>()?, &[[1.], [5.]]);
- /// let s = a.sum(&[0, 1])?;
+ /// let s = a.sum_keepdim(&[0, 1])?;
/// assert_eq!(s.to_vec2::<f32>()?, &[[6.]]);
/// # Ok::<(), candle::Error>(())
/// ```
- pub fn sum(&self, sum_dims: &[usize]) -> Result<Self> {
+ pub fn sum_keepdim(&self, sum_dims: &[usize]) -> Result<Self> {
for &dim in sum_dims {
self.check_dim(dim, "sum")?;
}
@@ -622,6 +622,32 @@ impl Tensor {
Ok(from_storage(storage, dims, op, false))
}
+ /// Returns the sum of all elements in the input tensor. The sum is performed over all the
+ /// input dimensions and compared to `sum_keepdim` these dimensions are squeezed rather than
+ /// kept.
+ pub fn sum(&self, sum_dims: &[usize]) -> Result<Self> {
+ let sum = self.sum_keepdim(sum_dims)?;
+ match sum_dims {
+ [] => Ok(sum),
+ [i] => sum.squeeze(*i),
+ sum_dims => {
+ let dims = sum
+ .dims()
+ .iter()
+ .enumerate()
+ .filter_map(|(dim_idx, &v)| {
+ if sum_dims.contains(&dim_idx) {
+ None
+ } else {
+ Some(v)
+ }
+ })
+ .collect::<Vec<_>>();
+ sum.reshape(dims)
+ }
+ }
+ }
+
/// Applies a 1D convolution over the input tensor.
pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
let (c_out, c_in_k, k_size) = kernel.shape().r3()?;
@@ -936,7 +962,7 @@ impl Tensor {
/// ```
pub fn sum_all(&self) -> Result<Tensor> {
let dims: Vec<_> = (0..self.rank()).collect();
- self.sum(&dims)?.reshape(())
+ self.sum_keepdim(&dims)?.reshape(())
}
fn flatten_<D1: Dim, D2: Dim>(
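
When more than one dimension is reduced, the new `sum` goes through the reshape branch shown above, keeping only the non-reduced dimensions. A quick sketch (values taken from the test data in this commit), assuming the API at this revision:

    use candle::{Device, Result, Tensor};

    fn main() -> Result<()> {
        let t = Tensor::new(&[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]], &Device::Cpu)?;
        // Reducing dims 1 and 2: sum_keepdim gives shape (2, 1, 1), which the
        // reshape branch collapses to (2,).
        let s = t.sum(&[1, 2])?;
        assert_eq!(s.to_vec1::<u32>()?, &[23, 28]);
        Ok(())
    }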
diff --git a/candle-core/tests/grad_tests.rs b/candle-core/tests/grad_tests.rs
index 501b0d44..241efc69 100644
--- a/candle-core/tests/grad_tests.rs
+++ b/candle-core/tests/grad_tests.rs
@@ -19,7 +19,7 @@ fn simple_grad(device: &Device) -> Result<()> {
fn sum_grad(device: &Device) -> Result<()> {
let x = Var::new(&[3f32, 1., 4.], device)?;
let x = x.as_tensor();
- let y = (x.sqr()?.sum(&[0])? * 2.)?;
+ let y = (x.sqr()?.sum_keepdim(&[0])? * 2.)?;
let grads = y.backward()?;
let grad_x = grads.get(x).context("no grad for x")?;
assert_eq!(y.to_vec1::<f32>()?, [52.]);
@@ -27,7 +27,7 @@ fn sum_grad(device: &Device) -> Result<()> {
assert_eq!(grad_x.to_vec1::<f32>()?, &[12., 4., 16.]);
// Same test as before but squeezing on the last dimension.
- let y = (x.sqr()?.sum(&[0])? * 2.)?.squeeze(0)?;
+ let y = (x.sqr()?.sum_keepdim(&[0])? * 2.)?.squeeze(0)?;
let grads = y.backward()?;
let grad_x = grads.get(x).context("no grad for x")?;
assert_eq!(y.to_scalar::<f32>()?, 52.);
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index 7e4467d1..b9e8a982 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -108,56 +108,99 @@ fn sum(device: &Device) -> Result<()> {
let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
let tensor = Tensor::new(data, device)?;
assert_eq!(
- tensor.sum(&[2])?.to_vec3::<u32>()?,
+ tensor.sum_keepdim(&[2])?.to_vec3::<u32>()?,
&[[[8], [15]], [[10], [18]]]
);
assert_eq!(
- tensor.sum(&[0])?.to_vec3::<u32>()?,
+ tensor.sum_keepdim(&[0])?.to_vec3::<u32>()?,
&[[[5, 2, 11], [9, 7, 17]]],
);
- assert_eq!(tensor.sum(&[0, 2, 1])?.to_vec3::<u32>()?, &[[[51]]],);
+ assert_eq!(tensor.sum_keepdim(&[0, 2, 1])?.to_vec3::<u32>()?, &[[[51]]],);
assert_eq!(
- tensor.t()?.sum(&[1])?.t()?.to_vec3::<u32>()?,
+ tensor.t()?.sum_keepdim(&[1])?.t()?.to_vec3::<u32>()?,
&[[[8], [15]], [[10], [18]]]
);
assert_eq!(
- tensor.sum(&[2, 1])?.to_vec3::<u32>()?,
+ tensor.sum_keepdim(&[2, 1])?.to_vec3::<u32>()?,
&[[[8 + 15]], [[10 + 18]]]
);
let data: Vec<u32> = (0..4000u32).collect();
let tensor = Tensor::new(data.as_slice(), device)?;
- assert_eq!(tensor.sum(&[0])?.to_vec1::<u32>()?, &[7998000]);
+ assert_eq!(tensor.sum_keepdim(&[0])?.to_vec1::<u32>()?, &[7998000]);
let tensor = tensor.reshape((2000, 2))?;
- assert_eq!(tensor.sum(&[0, 1])?.to_vec2::<u32>()?, &[[7998000]]);
- assert_eq!(tensor.sum(&[0])?.sum(&[1])?.to_vec2::<u32>()?, &[[7998000]]);
- assert_eq!(tensor.sum(&[1])?.sum(&[0])?.to_vec2::<u32>()?, &[[7998000]]);
- assert_eq!(tensor.sum(&[0])?.to_vec2::<u32>()?, &[[3998000, 4000000]]);
+ assert_eq!(tensor.sum_keepdim(&[0, 1])?.to_vec2::<u32>()?, &[[7998000]]);
+ assert_eq!(
+ tensor
+ .sum_keepdim(&[0])?
+ .sum_keepdim(&[1])?
+ .to_vec2::<u32>()?,
+ &[[7998000]]
+ );
+ assert_eq!(
+ tensor
+ .sum_keepdim(&[1])?
+ .sum_keepdim(&[0])?
+ .to_vec2::<u32>()?,
+ &[[7998000]]
+ );
+ assert_eq!(
+ tensor.sum_keepdim(&[0])?.to_vec2::<u32>()?,
+ &[[3998000, 4000000]]
+ );
// Make the tensor non contiguous.
let tensor = tensor.t()?.contiguous()?.t()?;
- assert_eq!(tensor.sum(&[0, 1])?.to_vec2::<u32>()?, &[[7998000]]);
- assert_eq!(tensor.sum(&[0])?.sum(&[1])?.to_vec2::<u32>()?, &[[7998000]]);
- assert_eq!(tensor.sum(&[1])?.sum(&[0])?.to_vec2::<u32>()?, &[[7998000]]);
- assert_eq!(tensor.sum(&[0])?.to_vec2::<u32>()?, &[[3998000, 4000000]]);
+ assert_eq!(tensor.sum_keepdim(&[0, 1])?.to_vec2::<u32>()?, &[[7998000]]);
+ assert_eq!(
+ tensor
+ .sum_keepdim(&[0])?
+ .sum_keepdim(&[1])?
+ .to_vec2::<u32>()?,
+ &[[7998000]]
+ );
+ assert_eq!(
+ tensor
+ .sum_keepdim(&[1])?
+ .sum_keepdim(&[0])?
+ .to_vec2::<u32>()?,
+ &[[7998000]]
+ );
+ assert_eq!(
+ tensor.sum_keepdim(&[0])?.to_vec2::<u32>()?,
+ &[[3998000, 4000000]]
+ );
let t1 = tensor.reshape((200, 5, 4))?;
let t2 = t1.transpose(0, 2)?.contiguous()?.transpose(0, 2)?;
for tensor in [t1, t2] {
- assert_eq!(tensor.sum(&[0, 1, 2])?.to_vec3::<u32>()?, &[[[7998000]]]);
assert_eq!(
- tensor.sum(&[0])?.sum(&[2])?.sum(&[1])?.to_vec3::<u32>()?,
+ tensor.sum_keepdim(&[0, 1, 2])?.to_vec3::<u32>()?,
+ &[[[7998000]]]
+ );
+ assert_eq!(
+ tensor
+ .sum_keepdim(&[0])?
+ .sum_keepdim(&[2])?
+ .sum_keepdim(&[1])?
+ .to_vec3::<u32>()?,
&[[[7998000]]]
);
assert_eq!(
- tensor.sum(&[0])?.sum(&[1, 2])?.to_vec3::<u32>()?,
+ tensor
+ .sum_keepdim(&[0])?
+ .sum_keepdim(&[1, 2])?
+ .to_vec3::<u32>()?,
&[[[7998000]]]
);
assert_eq!(
- tensor.sum(&[1])?.sum(&[0, 2])?.to_vec3::<u32>()?,
+ tensor
+ .sum_keepdim(&[1])?
+ .sum_keepdim(&[0, 2])?
+ .to_vec3::<u32>()?,
&[[[7998000]]]
);
assert_eq!(
- tensor.sum(&[0])?.to_vec3::<u32>()?,
+ tensor.sum_keepdim(&[0])?.to_vec3::<u32>()?,
&[[
[398000, 398200, 398400, 398600],
[398800, 399000, 399200, 399400],
diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs
index 1c3c429b..d7df5ae3 100644
--- a/candle-examples/examples/bert/main.rs
+++ b/candle-examples/examples/bert/main.rs
@@ -604,16 +604,16 @@ fn main() -> Result<()> {
println!("generated embeddings {:?}", embeddings.shape());
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
let (_n_sentence, n_tokens, _hidden_size) = embeddings.shape().r3()?;
- let embeddings = (embeddings.sum(&[1])? / (n_tokens as f64))?.squeeze(1)?;
+ let embeddings = (embeddings.sum_keepdim(&[1])? / (n_tokens as f64))?.squeeze(1)?;
println!("pooled embeddings {:?}", embeddings.shape());
let mut similarities = vec![];
for i in 0..n_sentences {
let e_i = embeddings.get(i)?;
for j in (i + 1)..n_sentences {
let e_j = embeddings.get(j)?;
- let sum_ij = (&e_i * &e_j)?.sum_all()?.reshape(())?.to_scalar::<f32>()?;
- let sum_i2 = (&e_i * &e_i)?.sum_all()?.reshape(())?.to_scalar::<f32>()?;
- let sum_j2 = (&e_j * &e_j)?.sum_all()?.reshape(())?.to_scalar::<f32>()?;
+ let sum_ij = (&e_i * &e_j)?.sum_all()?.to_scalar::<f32>()?;
+ let sum_i2 = (&e_i * &e_i)?.sum_all()?.to_scalar::<f32>()?;
+ let sum_j2 = (&e_j * &e_j)?.sum_all()?.to_scalar::<f32>()?;
let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt();
similarities.push((cosine_similarity, i, j))
}
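
In the bert example, the intermediate `reshape(())` calls can be dropped because `sum_all` (now built on `sum_keepdim` plus a reshape, see the tensor.rs hunk) already yields a rank-0 tensor. A small sketch of the dot-product step, assuming the API at this revision:

    use candle::{Device, Result, Tensor};

    fn main() -> Result<()> {
        let e_i = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
        let e_j = Tensor::new(&[4f32, 5., 6.], &Device::Cpu)?;
        // `sum_all` returns a rank-0 tensor, so `to_scalar` works directly.
        let sum_ij = (&e_i * &e_j)?.sum_all()?.to_scalar::<f32>()?;
        assert_eq!(sum_ij, 32.);
        Ok(())
    }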
diff --git a/candle-examples/examples/llama/model.rs b/candle-examples/examples/llama/model.rs
index 04397d1e..57f339b0 100644
--- a/candle-examples/examples/llama/model.rs
+++ b/candle-examples/examples/llama/model.rs
@@ -95,7 +95,7 @@ impl RmsNorm {
// This is a no-op if x's dtype is already f32.
let x = x.to_dtype(DType::F32)?;
let (b_sz, seq_len, hidden_size) = x.shape().r3()?;
- let norm_x = ((&x * &x)?.sum(&[2])? / hidden_size as f64)?;
+ let norm_x = (x.sqr()?.sum_keepdim(&[2])? / hidden_size as f64)?;
let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
let size = self.scale.shape().r1()?;
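
In `RmsNorm`, the reduced hidden dimension has to be kept so that the per-token statistic lines up with `x` during normalization, hence `sum_keepdim` rather than the new squeezing `sum`. A simplified sketch (using `broadcast_div` in place of the explicit `broadcast_as` in the model code, with made-up input sizes), assuming the API at this revision:

    use candle::{Device, Result, Tensor};

    fn main() -> Result<()> {
        // Hypothetical input: batch 1, seq_len 2, hidden_size 4.
        let x = Tensor::new(&[[[1f32, 2., 3., 4.], [2., 2., 2., 2.]]], &Device::Cpu)?;
        let (_b_sz, _seq_len, hidden_size) = x.shape().r3()?;
        // Shape (1, 2, 1): the kept size-1 dimension broadcasts against the
        // hidden dimension of x when dividing.
        let norm_x = (x.sqr()?.sum_keepdim(&[2])? / hidden_size as f64)?;
        let x_normed = x.broadcast_div(&(norm_x + 1e-5)?.sqrt()?)?;
        println!("{x_normed}");
        Ok(())
    }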
diff --git a/candle-examples/examples/musicgen/nn.rs b/candle-examples/examples/musicgen/nn.rs
index 5c90dd4e..652c47a7 100644
--- a/candle-examples/examples/musicgen/nn.rs
+++ b/candle-examples/examples/musicgen/nn.rs
@@ -70,7 +70,7 @@ pub fn conv1d_weight_norm(
) -> Result<Conv1d> {
let weight_g = vb.get((out_c, 1, 1), "weight_g")?;
let weight_v = vb.get((out_c, in_c, kernel_size), "weight_v")?;
- let norm_v = (&weight_v * &weight_v)?.sum(&[1, 2])?.sqrt()?;
+ let norm_v = weight_v.sqr()?.sum_keepdim(&[1, 2])?.sqrt()?;
let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
let bias = vb.get(out_c, "bias")?;
Ok(Conv1d::new(weight, Some(bias), config))
diff --git a/candle-examples/examples/musicgen/t5_model.rs b/candle-examples/examples/musicgen/t5_model.rs
index 0444f360..2119cf9b 100644
--- a/candle-examples/examples/musicgen/t5_model.rs
+++ b/candle-examples/examples/musicgen/t5_model.rs
@@ -98,7 +98,7 @@ impl T5LayerNorm {
let dtype = xs.dtype();
let xs_f32 = xs.to_dtype(DType::F32)?;
let xs2_f32 = (&xs_f32 * &xs_f32)?;
- let sum_xs2_f32 = xs2_f32.sum(&[xs.rank() - 1])?;
+ let sum_xs2_f32 = xs2_f32.sum_keepdim(&[xs.rank() - 1])?;
let variance = xs2_f32.broadcast_div(&sum_xs2_f32)?;
let xs = (xs / (variance + self.variance_epsilon)?.sqrt()?)?;
let xs = xs.to_dtype(dtype)?;
diff --git a/candle-nn/src/layer_norm.rs b/candle-nn/src/layer_norm.rs
index 188a02bf..06f984f2 100644
--- a/candle-nn/src/layer_norm.rs
+++ b/candle-nn/src/layer_norm.rs
@@ -51,9 +51,9 @@ impl LayerNorm {
};
let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
let x = x.to_dtype(internal_dtype)?;
- let mean_x = (x.sum(&[2])? / hidden_size as f64)?;
+ let mean_x = (x.sum_keepdim(&[2])? / hidden_size as f64)?;
let x = x.broadcast_sub(&mean_x)?;
- let norm_x = ((&x * &x)?.sum(&[2])? / hidden_size as f64)?;
+ let norm_x = (x.sqr()?.sum_keepdim(&[2])? / hidden_size as f64)?;
let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
let x = x_normed
.to_dtype(x_dtype)?
diff --git a/candle-nn/tests/layer_norm.rs b/candle-nn/tests/layer_norm.rs
index e4b962b4..fda7dbad 100644
--- a/candle-nn/tests/layer_norm.rs
+++ b/candle-nn/tests/layer_norm.rs
@@ -30,10 +30,10 @@ fn layer_norm() -> Result<()> {
[4.1742344, 0.5, -3.1742344]
]]
);
- let mean = (res.sum(&[2])? / 3.0)?;
+ let mean = (res.sum_keepdim(&[2])? / 3.0)?;
// The average value should be `b`.
assert_eq!(mean.to_vec3::<f32>()?, [[[0.5], [0.5], [0.5]]]);
- let std = (res.broadcast_sub(&mean)?.sqr()?.sum(&[2])?.sqrt()? / 3.0)?;
+ let std = (res.broadcast_sub(&mean)?.sqr()?.sum_keepdim(&[2])?.sqrt()? / 3.0)?;
// The standard deviation should be sqrt(`w`).
assert_eq!(
std.to_vec3::<f32>()?,
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index 7cd361e4..136f8a4f 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -312,9 +312,11 @@ impl PyTensor {
Ok(PyTensor(self.0.narrow(dim, start, len).map_err(wrap_err)?))
}
- fn sum(&self, dims: Vec<usize>) -> PyResult<Self> {
+ fn sum_keepdim(&self, dims: Vec<usize>) -> PyResult<Self> {
// TODO: Support a single dim as input?
- Ok(PyTensor(self.0.sum(dims.as_slice()).map_err(wrap_err)?))
+ Ok(PyTensor(
+ self.0.sum_keepdim(dims.as_slice()).map_err(wrap_err)?,
+ ))
}
fn sum_all(&self) -> PyResult<Self> {