summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- candle-core/src/tensor.rs 22
-rw-r--r-- candle-core/tests/tensor_tests.rs 27
2 files changed, 49 insertions, 0 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index f6b1698c..2a0924b6 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -2474,6 +2474,28 @@ impl Tensor {
let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
t1.eq(&t2)?.to_dtype(dtype)
}
+
+ /// Returns the cumulative (running, inclusive) sum of the elements of `self` along `dim`.
+ /// The result is obtained by a matmul with the `n x n` matrix returned by `triu2`, where `n`
+ /// is the size of `dim`; a rank-0 tensor is returned as a clone of itself.
+ /// This operation is most efficient when dim is the last dimension of the tensor.
+ pub fn cumsum<D: Dim>(&self, dim: D) -> Result<Self> {
+ let dim = dim.to_index(self.shape(), "cumsum")?; // resolve `dim` against this tensor's shape
+ let rank = self.rank();
+ if rank == 0 {
+ return Ok(self.clone()); // no axis to accumulate over
+ }
+ let n_axis = self.dim(dim)?; // number of elements along the summed axis
+ let triu = Tensor::triu2(n_axis, self.dtype(), self.device())?; // NOTE(review): presumably upper-triangular ones incl. diagonal — confirm
+ if rank == 1 {
+ self.unsqueeze(0)?.matmul(&triu)?.squeeze(0) // add a leading axis for the matmul, then drop it again
+ } else {
+ let last = rank - 1;
+ let t = self.transpose(dim, last)?; // bring the summed axis to the last position
+ let t = t.broadcast_matmul(&triu)?;
+ t.transpose(dim, last) // swap the two axes back
+ }
+ }
}
macro_rules! bin_trait {
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index c8b255dd..f565972a 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -1169,3 +1169,30 @@ fn tril_triu_eye() -> Result<()> {
);
Ok(())
}
+
+#[test]
+fn cumsum() -> Result<()> { // exercises `Tensor::cumsum` on 1-D and 2-D CPU tensors, along each axis
+ let t = &[3f32, 1., 4., 1., 5.];
+ let t = Tensor::new(t, &Device::Cpu)?;
+ assert_eq!(t.cumsum(0)?.to_vec1::<f32>()?, [3., 4., 8., 9., 14.]); // running total of the vector
+ let t = t.unsqueeze(1)?; // same five values, now as five rows of one element each
+ assert_eq!(
+ t.cumsum(0)?.to_vec2::<f32>()?, // accumulating down the rows gives the same running total
+ [[3.0], [4.0], [8.0], [9.0], [14.0]]
+ );
+ assert_eq!(
+ t.cumsum(1)?.to_vec2::<f32>()?, // each row holds a single element, so values are unchanged
+ [[3.0], [1.0], [4.0], [1.0], [5.0]]
+ );
+ let t = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
+ let t = Tensor::new(t, &Device::Cpu)?;
+ assert_eq!(
+ t.cumsum(1)?.to_vec2::<f32>()?, // running total within each row independently
+ [[3.0, 4.0, 8.0, 9.0, 14.0], [2.0, 3.0, 10.0, 18.0, 20.0]],
+ );
+ assert_eq!(
+ t.cumsum(0)?.to_vec2::<f32>()?, // first row unchanged, second row is the sum of both rows
+ [[3.0, 1.0, 4.0, 1.0, 5.0], [5.0, 2.0, 11.0, 9.0, 7.0]]
+ );
+ Ok(())
+}