-rw-r--r--  candle-core/src/tensor.rs          | 14
-rw-r--r--  candle-core/tests/tensor_tests.rs  | 17
2 files changed, 31 insertions(+), 0 deletions(-)
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index f032a896..d51a3db7 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -856,6 +856,20 @@ impl Tensor {
self.sum_impl(mean_dims, false)? * scale
}
+ /// Returns the unbiased variance over the selected dimension.
+ pub fn var_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
+ let dim = dim.to_index(self.shape(), "var")?;
+ let mean = self.mean_keepdim(dim)?;
+ let squares = self.broadcast_sub(&mean)?.sqr()?;
+ squares.sum_impl(dim, true)? / (self.dim(dim)? - 1) as f64
+ }
+
+ /// Returns the unbiased variance over the selected dimension.
+ pub fn var<D: Dim>(&self, dim: D) -> Result<Self> {
+ let dim = dim.to_index(self.shape(), "var")?;
+ self.var_keepdim(dim)?.squeeze(dim)
+ }
+
/// Gathers the maximum value across the selected dimension. The resulting shape has the same
/// number of dimensions as the original tensor and the select dimension has a single element.
pub fn max_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index 734cb7e8..cc44ce94 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -180,6 +180,22 @@ fn transpose(device: &Device) -> Result<()> {
Ok(())
}
+fn var(device: &Device) -> Result<()> {
+ // Values taken from https://pytorch.org/docs/stable/generated/torch.var.html
+ let data = &[
+ [0.2035f32, 1.2959, 1.8101, -0.4644],
+ [1.5027, -0.3270, 0.5905, 0.6538],
+ [-1.5745, 1.3330, -0.5596, -0.6548],
+ [0.1264, -0.5080, 1.6420, 0.1992],
+ ];
+ let tensor = Tensor::new(data, device)?;
+ assert_eq!(
+ test_utils::to_vec2_round(&tensor.var_keepdim(1)?, 4)?,
+ &[[1.0631], [0.559], [1.4893], [0.8258]]
+ );
+ Ok(())
+}
+
fn sum(device: &Device) -> Result<()> {
let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
let tensor = Tensor::new(data, device)?;
@@ -1082,6 +1098,7 @@ test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);
test_device!(slice_scatter, slice_scatter_cpu, slice_scatter_gpu);
test_device!(randn, randn_cpu, randn_gpu);
test_device!(clamp, clamp_cpu, clamp_gpu);
+test_device!(var, var_cpu, var_gpu);
// There was originally a bug on the CPU implementation for randn
// https://github.com/huggingface/candle/issues/381
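
For context, a minimal usage sketch of the new var / var_keepdim methods (not part of the commit; it assumes the candle-core API at this revision, a CPU device, and reuses the values from the test above):

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let data = &[
        [0.2035f32, 1.2959, 1.8101, -0.4644],
        [1.5027, -0.3270, 0.5905, 0.6538],
        [-1.5745, 1.3330, -0.5596, -0.6548],
        [0.1264, -0.5080, 1.6420, 0.1992],
    ];
    let tensor = Tensor::new(data, &Device::Cpu)?;

    // var_keepdim keeps the reduced dimension with size 1: (4, 4) -> (4, 1).
    let v_keep = tensor.var_keepdim(1)?;
    assert_eq!(v_keep.dims(), &[4, 1]);

    // var squeezes the reduced dimension away: (4, 4) -> (4,).
    let v = tensor.var(1)?;
    assert_eq!(v.dims(), &[4]);

    // Both use the unbiased (n - 1 denominator) estimator, matching torch.var.
    println!("{:?}", v.to_vec1::<f32>()?);
    Ok(())
}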