diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-01-13 20:24:06 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-13 20:24:06 +0100 |
commit | e6d86b081980196745e5f0b0eda8ce5334c0ff67 (patch) | |
tree | f2680645ff85136d8504bde6f75e2a61cbee22f6 /candle-core/src/tensor.rs | |
parent | 88618255cb3c20b511a2f0e6db35d84081ce3c4a (diff) | |
download | candle-e6d86b081980196745e5f0b0eda8ce5334c0ff67.tar.gz candle-e6d86b081980196745e5f0b0eda8ce5334c0ff67.tar.bz2 candle-e6d86b081980196745e5f0b0eda8ce5334c0ff67.zip |
Add the pow operator. (#1583)
* Add the pow operator.
* Support the pow operation in onnx.
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r-- | candle-core/src/tensor.rs | 12 |
1 files changed, 11 insertions, 1 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 54f9fa2b..3100c6e8 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -2578,11 +2578,21 @@ impl Tensor { } /// Returns log(sum(exp(tensor), dim)). - pub fn logsumexp<D: Dims>(&self, sum_dims: D) -> Result<Self> { + pub fn log_sum_exp<D: Dims>(&self, sum_dims: D) -> Result<Self> { let exp = self.exp()?; let sum = exp.sum(sum_dims)?; sum.log() } + + /// Pointwise pow operation. + pub fn pow(&self, rhs: &Tensor) -> Result<Self> { + rhs.mul(&self.log()?)?.exp() + } + + /// Broadcasting version of `pow`. + pub fn broadcast_pow(&self, rhs: &Tensor) -> Result<Self> { + rhs.broadcast_mul(&self.log()?)?.exp() + } } macro_rules! bin_trait { |