| author    | Laurent Mazare <laurent.mazare@gmail.com>        | 2023-11-15 21:11:15 +0000 |
|-----------|--------------------------------------------------|---------------------------|
| committer | GitHub <noreply@github.com>                      | 2023-11-15 21:11:15 +0000 |
| commit    | c6763e3b41790bea0459d3838c3ed1d59d5174ef (patch) |                           |
| tree      | a1c255b0b8d318c138ae403113e66b4bf973a78f         |                           |
| parent    | 347e31c9ff08f52574c0158e2a48cab52e224b4d (diff)  |                           |
Add a simple implementation of cumsum. (#1334)
* Add a simple implementation of cumsum.
* Add another test.
| -rw-r--r-- | candle-core/src/tensor.rs         | 22 |
| -rw-r--r-- | candle-core/tests/tensor_tests.rs | 27 |
2 files changed, 49 insertions, 0 deletions
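
The implementation below reduces the cumulative sum to a matrix product with the upper-triangular matrix of ones returned by the existing `Tensor::triu2` helper: for a row vector `x`, column `j` of that matrix selects elements `0..=j`, so `x · U` is exactly the vector of prefix sums. A minimal sketch of that identity, assuming the crate is consumed as `candle_core` and built at or after this commit:

```rust
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // 5x5 upper-triangular matrix of ones: column j has ones in rows 0..=j.
    let u = Tensor::triu2(5, DType::F32, &device)?;
    // Treat the input as a 1x5 row vector; the product accumulates prefix sums.
    let x = Tensor::new(&[[3f32, 1., 4., 1., 5.]], &device)?;
    println!("{:?}", x.matmul(&u)?.to_vec2::<f32>()?); // [[3, 4, 8, 9, 14]]
    Ok(())
}
```

For higher-rank tensors, the new method transposes the target dimension to the last position and uses `broadcast_matmul`, which is why the doc comment notes it is most efficient when `dim` is already the last dimension.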
```diff
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index f6b1698c..2a0924b6 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -2474,6 +2474,28 @@ impl Tensor {
         let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
         t1.eq(&t2)?.to_dtype(dtype)
     }
+
+    /// Returns the cumulative sum of elements of the input tensor summed over the specified
+    /// dimension.
+    ///
+    /// This operation is most efficient when dim is the last dimension of the tensor.
+    pub fn cumsum<D: Dim>(&self, dim: D) -> Result<Self> {
+        let dim = dim.to_index(self.shape(), "cumsum")?;
+        let rank = self.rank();
+        if rank == 0 {
+            return Ok(self.clone());
+        }
+        let n_axis = self.dim(dim)?;
+        let triu = Tensor::triu2(n_axis, self.dtype(), self.device())?;
+        if rank == 1 {
+            self.unsqueeze(0)?.matmul(&triu)?.squeeze(0)
+        } else {
+            let last = rank - 1;
+            let t = self.transpose(dim, last)?;
+            let t = t.broadcast_matmul(&triu)?;
+            t.transpose(dim, last)
+        }
+    }
 }
 
 macro_rules! bin_trait {
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index c8b255dd..f565972a 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -1169,3 +1169,30 @@ fn tril_triu_eye() -> Result<()> {
     );
     Ok(())
 }
+
+#[test]
+fn cumsum() -> Result<()> {
+    let t = &[3f32, 1., 4., 1., 5.];
+    let t = Tensor::new(t, &Device::Cpu)?;
+    assert_eq!(t.cumsum(0)?.to_vec1::<f32>()?, [3., 4., 8., 9., 14.]);
+    let t = t.unsqueeze(1)?;
+    assert_eq!(
+        t.cumsum(0)?.to_vec2::<f32>()?,
+        [[3.0], [4.0], [8.0], [9.0], [14.0]]
+    );
+    assert_eq!(
+        t.cumsum(1)?.to_vec2::<f32>()?,
+        [[3.0], [1.0], [4.0], [1.0], [5.0]]
+    );
+    let t = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
+    let t = Tensor::new(t, &Device::Cpu)?;
+    assert_eq!(
+        t.cumsum(1)?.to_vec2::<f32>()?,
+        [[3.0, 4.0, 8.0, 9.0, 14.0], [2.0, 3.0, 10.0, 18.0, 20.0]],
+    );
+    assert_eq!(
+        t.cumsum(0)?.to_vec2::<f32>()?,
+        [[3.0, 1.0, 4.0, 1.0, 5.0], [5.0, 2.0, 11.0, 9.0, 7.0]]
+    );
+    Ok(())
+}
```
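
For reference, a standalone usage sketch that mirrors the new tests, again assuming `candle_core` as the crate path and a build that includes this commit:

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;

    // 1D cumulative sum: [3, 1, 4, 1, 5] -> [3, 4, 8, 9, 14].
    let t = Tensor::new(&[3f32, 1., 4., 1., 5.], &device)?;
    println!("{:?}", t.cumsum(0)?.to_vec1::<f32>()?);

    // 2D: dim 1 accumulates along each row, dim 0 down each column.
    let t = Tensor::new(&[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]], &device)?;
    println!("{:?}", t.cumsum(1)?.to_vec2::<f32>()?); // [[3, 4, 8, 9, 14], [2, 3, 10, 18, 20]]
    println!("{:?}", t.cumsum(0)?.to_vec2::<f32>()?); // [[3, 1, 4, 1, 5], [5, 2, 11, 9, 7]]

    Ok(())
}
```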