diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-11-15 21:11:15 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-11-15 21:11:15 +0000 |
commit | c6763e3b41790bea0459d3838c3ed1d59d5174ef (patch) | |
tree | a1c255b0b8d318c138ae403113e66b4bf973a78f /candle-core/src/tensor.rs | |
parent | 347e31c9ff08f52574c0158e2a48cab52e224b4d (diff) | |
download | candle-c6763e3b41790bea0459d3838c3ed1d59d5174ef.tar.gz candle-c6763e3b41790bea0459d3838c3ed1d59d5174ef.tar.bz2 candle-c6763e3b41790bea0459d3838c3ed1d59d5174ef.zip |
Add a simple implementation of cumsum. (#1334)
* Add a simple implementation of cumsum.
* Add another test.
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r-- | candle-core/src/tensor.rs | 22 |
1 files changed, 22 insertions, 0 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index f6b1698c..2a0924b6 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -2474,6 +2474,28 @@ impl Tensor { let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?; t1.eq(&t2)?.to_dtype(dtype) } + + /// Returns the cumulative sum of elements of the input tensor summed over the specified + /// dimension. + /// + /// This operation is most efficient when dim is the last dimension of the tensor. + pub fn cumsum<D: Dim>(&self, dim: D) -> Result<Self> { + let dim = dim.to_index(self.shape(), "cumsum")?; + let rank = self.rank(); + if rank == 0 { + return Ok(self.clone()); + } + let n_axis = self.dim(dim)?; + let triu = Tensor::triu2(n_axis, self.dtype(), self.device())?; + if rank == 1 { + self.unsqueeze(0)?.matmul(&triu)?.squeeze(0) + } else { + let last = rank - 1; + let t = self.transpose(dim, last)?; + let t = t.broadcast_matmul(&triu)?; + t.transpose(dim, last) + } + } } macro_rules! bin_trait { |