summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-11-14 21:08:04 +0100
committerGitHub <noreply@github.com>2024-11-14 21:08:04 +0100
commit0ed24b9852ccc7dfb92d555afba3d56c2a3f3224 (patch)
treea87ee556f325584284579c22f9c8ae293a7e08ad
parent06350c31c780d6ea485f506032aea6ff8809e38a (diff)
downloadcandle-0ed24b9852ccc7dfb92d555afba3d56c2a3f3224.tar.gz
candle-0ed24b9852ccc7dfb92d555afba3d56c2a3f3224.tar.bz2
candle-0ed24b9852ccc7dfb92d555afba3d56c2a3f3224.zip
Add max-all/min-all. (#2616)
-rw-r--r--candle-core/src/tensor.rs36
1 files changed, 36 insertions, 0 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index e7355aad..75dc1c8a 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1760,6 +1760,42 @@ impl Tensor {
&self.op
}
+ /// Computes the max of all the elements in this tensor and returns a tensor holding this
+ /// scalar with zero dimensions.
+ ///
+ /// ```rust
+ /// use candle_core::{Tensor, Device};
+ /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
+ /// let tensor = tensor.max_all()?;
+ /// assert_eq!(tensor.to_scalar::<f32>()?, 5.);
+ /// # Ok::<(), candle_core::Error>(())
+ /// ```
+ pub fn max_all(&self) -> Result<Tensor> {
+ if self.rank() == 0 {
+ Ok(self.clone())
+ } else {
+ self.flatten_all()?.max(0)
+ }
+ }
+
+ /// Computes the min of all the elements in this tensor and returns a tensor holding this
+ /// scalar with zero dimensions.
+ ///
+ /// ```rust
+ /// use candle_core::{Tensor, Device};
+ /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
+ /// let tensor = tensor.min_all()?;
+ /// assert_eq!(tensor.to_scalar::<f32>()?, 0.);
+ /// # Ok::<(), candle_core::Error>(())
+ /// ```
+ pub fn min_all(&self) -> Result<Tensor> {
+ if self.rank() == 0 {
+ Ok(self.clone())
+ } else {
+ self.flatten_all()?.min(0)
+ }
+ }
+
/// Computes the sum of all the elements in this tensor and returns a tensor holding this
/// scalar with zero dimensions.
///