diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-10-21 15:01:38 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-21 15:01:38 +0100 |
commit | e8f760ee44ad4b1f9f3606e36a1966df8509203b (patch) | |
tree | 7e111718cd8c4ea432a6ae4ef28c141be2f2df9e | |
parent | 94e3373883caaa7442201dac25abe16b4469f9bd (diff) | |
download | candle-e8f760ee44ad4b1f9f3606e36a1966df8509203b.tar.gz candle-e8f760ee44ad4b1f9f3606e36a1966df8509203b.tar.bz2 candle-e8f760ee44ad4b1f9f3606e36a1966df8509203b.zip |
Add get_on_dim. (#1142)
-rw-r--r-- | candle-core/src/tensor.rs | 18 |
1 files changed, 18 insertions, 0 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index da47d180..0ffed2fe 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -1598,6 +1598,24 @@ impl Tensor { } } + /// Returns the sub-tensor fixing the index at `index` on the dimension `dim`. + /// + /// ```rust + /// use candle_core::{Tensor, Device}; + /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?; + /// let t = tensor.get_on_dim(1, 0)?; + /// assert_eq!(t.to_vec1::<f32>()?, &[0., 2., 4.]); + /// let t = tensor.get_on_dim(1, 1)?; + /// assert_eq!(t.to_vec1::<f32>()?, &[1., 3., 5.]); + /// let t = tensor.get_on_dim(0, 1)?; + /// assert_eq!(t.to_vec1::<f32>()?, &[2., 3.]); + /// # Ok::<(), candle_core::Error>(()) + /// ``` + pub fn get_on_dim<D: Dim>(&self, dim: D, index: usize) -> Result<Tensor> { + let dim = dim.to_index(self.shape(), "get_on_dim")?; + self.narrow(dim, index, 1)?.squeeze(dim) + } + /// Returns a tensor that is a transposed version of the input, the two last dimensions of the /// input are swapped. /// |