diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-11-14 21:08:04 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-11-14 21:08:04 +0100 |
commit | 0ed24b9852ccc7dfb92d555afba3d56c2a3f3224 (patch) | |
tree | a87ee556f325584284579c22f9c8ae293a7e08ad | |
parent | 06350c31c780d6ea485f506032aea6ff8809e38a (diff) | |
download | candle-0ed24b9852ccc7dfb92d555afba3d56c2a3f3224.tar.gz candle-0ed24b9852ccc7dfb92d555afba3d56c2a3f3224.tar.bz2 candle-0ed24b9852ccc7dfb92d555afba3d56c2a3f3224.zip |
Add max-all/min-all. (#2616)
-rw-r--r-- | candle-core/src/tensor.rs | 36 |
1 files changed, 36 insertions, 0 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index e7355aad..75dc1c8a 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -1760,6 +1760,42 @@ impl Tensor { &self.op } + /// Computes the max of all the elements in this tensor and returns a tensor holding this + /// scalar with zero dimensions. + /// + /// ```rust + /// use candle_core::{Tensor, Device}; + /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?; + /// let tensor = tensor.max_all()?; + /// assert_eq!(tensor.to_scalar::<f32>()?, 5.); + /// # Ok::<(), candle_core::Error>(()) + /// ``` + pub fn max_all(&self) -> Result<Tensor> { + if self.rank() == 0 { + Ok(self.clone()) + } else { + self.flatten_all()?.max(0) + } + } + + /// Computes the min of all the elements in this tensor and returns a tensor holding this + /// scalar with zero dimensions. + /// + /// ```rust + /// use candle_core::{Tensor, Device}; + /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?; + /// let tensor = tensor.min_all()?; + /// assert_eq!(tensor.to_scalar::<f32>()?, 0.); + /// # Ok::<(), candle_core::Error>(()) + /// ``` + pub fn min_all(&self) -> Result<Tensor> { + if self.rank() == 0 { + Ok(self.clone()) + } else { + self.flatten_all()?.min(0) + } + } + /// Computes the sum of all the elements in this tensor and returns a tensor holding this /// scalar with zero dimensions. /// |