summaryrefslogtreecommitdiff
path: root/candle-core/src/tensor.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-11-15 21:11:15 +0000
committerGitHub <noreply@github.com>2023-11-15 21:11:15 +0000
commitc6763e3b41790bea0459d3838c3ed1d59d5174ef (patch)
treea1c255b0b8d318c138ae403113e66b4bf973a78f /candle-core/src/tensor.rs
parent347e31c9ff08f52574c0158e2a48cab52e224b4d (diff)
downloadcandle-c6763e3b41790bea0459d3838c3ed1d59d5174ef.tar.gz
candle-c6763e3b41790bea0459d3838c3ed1d59d5174ef.tar.bz2
candle-c6763e3b41790bea0459d3838c3ed1d59d5174ef.zip
Add a simple implementation of cumsum. (#1334)
* Add a simple implementation of cumsum. * Add another test.
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r--candle-core/src/tensor.rs22
1 files changed, 22 insertions, 0 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index f6b1698c..2a0924b6 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -2474,6 +2474,28 @@ impl Tensor {
let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
t1.eq(&t2)?.to_dtype(dtype)
}
+
+ /// Returns the cumulative sum of elements of the input tensor summed over the specified
+ /// dimension.
+ ///
+ /// This operation is most efficient when dim is the last dimension of the tensor.
+ pub fn cumsum<D: Dim>(&self, dim: D) -> Result<Self> {
+ let dim = dim.to_index(self.shape(), "cumsum")?;
+ let rank = self.rank();
+ if rank == 0 {
+ return Ok(self.clone());
+ }
+ let n_axis = self.dim(dim)?;
+ let triu = Tensor::triu2(n_axis, self.dtype(), self.device())?;
+ if rank == 1 {
+ self.unsqueeze(0)?.matmul(&triu)?.squeeze(0)
+ } else {
+ let last = rank - 1;
+ let t = self.transpose(dim, last)?;
+ let t = t.broadcast_matmul(&triu)?;
+ t.transpose(dim, last)
+ }
+ }
}
macro_rules! bin_trait {