summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/tensor.rs36
1 files changed, 36 insertions, 0 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index e7355aad..75dc1c8a 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1760,6 +1760,42 @@ impl Tensor {
&self.op
}
+ /// Computes the max of all the elements in this tensor and returns a tensor holding this
+ /// scalar with zero dimensions.
+ ///
+ /// ```rust
+ /// use candle_core::{Tensor, Device};
+ /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
+ /// let tensor = tensor.max_all()?;
+ /// assert_eq!(tensor.to_scalar::<f32>()?, 5.);
+ /// # Ok::<(), candle_core::Error>(())
+ /// ```
+ pub fn max_all(&self) -> Result<Tensor> {
+ if self.rank() == 0 {
+ Ok(self.clone())
+ } else {
+ self.flatten_all()?.max(0)
+ }
+ }
+
+ /// Computes the min of all the elements in this tensor and returns a tensor holding this
+ /// scalar with zero dimensions.
+ ///
+ /// ```rust
+ /// use candle_core::{Tensor, Device};
+ /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
+ /// let tensor = tensor.min_all()?;
+ /// assert_eq!(tensor.to_scalar::<f32>()?, 0.);
+ /// # Ok::<(), candle_core::Error>(())
+ /// ```
+ pub fn min_all(&self) -> Result<Tensor> {
+ if self.rank() == 0 {
+ Ok(self.clone())
+ } else {
+ self.flatten_all()?.min(0)
+ }
+ }
+
/// Computes the sum of all the elements in this tensor and returns a tensor holding this
/// scalar with zero dimensions.
///