diff options
-rw-r--r-- | candle-core/src/tensor.rs | 22 | ||||
-rw-r--r-- | candle-core/tests/tensor_tests.rs | 27 |
2 files changed, 49 insertions, 0 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index f6b1698c..2a0924b6 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -2474,6 +2474,28 @@ impl Tensor {
         let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
         t1.eq(&t2)?.to_dtype(dtype)
     }
+
+    /// Returns the cumulative sum of elements of the input tensor summed over the specified
+    /// dimension.
+    ///
+    /// This operation is most efficient when dim is the last dimension of the tensor.
+    pub fn cumsum<D: Dim>(&self, dim: D) -> Result<Self> {
+        let dim = dim.to_index(self.shape(), "cumsum")?;
+        let rank = self.rank();
+        if rank == 0 {
+            return Ok(self.clone());
+        }
+        let n_axis = self.dim(dim)?;
+        let triu = Tensor::triu2(n_axis, self.dtype(), self.device())?;
+        if rank == 1 {
+            self.unsqueeze(0)?.matmul(&triu)?.squeeze(0)
+        } else {
+            let last = rank - 1;
+            let t = self.transpose(dim, last)?;
+            let t = t.broadcast_matmul(&triu)?;
+            t.transpose(dim, last)
+        }
+    }
 }
 
 macro_rules! bin_trait {
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index c8b255dd..f565972a 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -1169,3 +1169,30 @@ fn tril_triu_eye() -> Result<()> {
     );
     Ok(())
 }
+
+#[test]
+fn cumsum() -> Result<()> {
+    let t = &[3f32, 1., 4., 1., 5.];
+    let t = Tensor::new(t, &Device::Cpu)?;
+    assert_eq!(t.cumsum(0)?.to_vec1::<f32>()?, [3., 4., 8., 9., 14.]);
+    let t = t.unsqueeze(1)?;
+    assert_eq!(
+        t.cumsum(0)?.to_vec2::<f32>()?,
+        [[3.0], [4.0], [8.0], [9.0], [14.0]]
+    );
+    assert_eq!(
+        t.cumsum(1)?.to_vec2::<f32>()?,
+        [[3.0], [1.0], [4.0], [1.0], [5.0]]
+    );
+    let t = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
+    let t = Tensor::new(t, &Device::Cpu)?;
+    assert_eq!(
+        t.cumsum(1)?.to_vec2::<f32>()?,
+        [[3.0, 4.0, 8.0, 9.0, 14.0], [2.0, 3.0, 10.0, 18.0, 20.0]],
+    );
+    assert_eq!(
+        t.cumsum(0)?.to_vec2::<f32>()?,
+        [[3.0, 1.0, 4.0, 1.0, 5.0], [5.0, 2.0, 11.0, 9.0, 7.0]]
+    );
+    Ok(())
+}