diff options
author | Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> | 2024-01-14 17:10:54 +0100 |
---|---|---|
committer | Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> | 2024-01-14 17:10:54 +0100 |
commit | ecf88a6d381e40c8db1c643dff2753fd877fae92 (patch) | |
tree | 9b6db2ec9a37a48185f323ab4c5e8b0baaa20221 /candle-core/src/tensor.rs | |
parent | e06e8d0dbea3a052195f4ca27fb5ddcdbf1cd30c (diff) | |
parent | e6d86b081980196745e5f0b0eda8ce5334c0ff67 (diff) | |
download | candle-ecf88a6d381e40c8db1c643dff2753fd877fae92.tar.gz candle-ecf88a6d381e40c8db1c643dff2753fd877fae92.tar.bz2 candle-ecf88a6d381e40c8db1c643dff2753fd877fae92.zip |
Merge branch 'main' into ivarflakstad/metal-prng
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r-- | candle-core/src/tensor.rs | 12 |
1 files changed, 11 insertions, 1 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 54f9fa2b..3100c6e8 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -2578,11 +2578,21 @@ impl Tensor { } /// Returns log(sum(exp(tensor), dim)). - pub fn logsumexp<D: Dims>(&self, sum_dims: D) -> Result<Self> { + pub fn log_sum_exp<D: Dims>(&self, sum_dims: D) -> Result<Self> { let exp = self.exp()?; let sum = exp.sum(sum_dims)?; sum.log() } + + /// Pointwise pow operation. + pub fn pow(&self, rhs: &Tensor) -> Result<Self> { + rhs.mul(&self.log()?)?.exp() + } + + /// Broadcasting version of `pow`. + pub fn broadcast_pow(&self, rhs: &Tensor) -> Result<Self> { + rhs.broadcast_mul(&self.log()?)?.exp() + } } macro_rules! bin_trait { |