summaryrefslogtreecommitdiff
path: root/candle-core/src/tensor.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-01-13 20:24:06 +0100
committerGitHub <noreply@github.com>2024-01-13 20:24:06 +0100
commite6d86b081980196745e5f0b0eda8ce5334c0ff67 (patch)
treef2680645ff85136d8504bde6f75e2a61cbee22f6 /candle-core/src/tensor.rs
parent88618255cb3c20b511a2f0e6db35d84081ce3c4a (diff)
downloadcandle-e6d86b081980196745e5f0b0eda8ce5334c0ff67.tar.gz
candle-e6d86b081980196745e5f0b0eda8ce5334c0ff67.tar.bz2
candle-e6d86b081980196745e5f0b0eda8ce5334c0ff67.zip
Add the pow operator. (#1583)
* Add the pow operator. * Support the pow operation in onnx.
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r--candle-core/src/tensor.rs12
1 files changed, 11 insertions, 1 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 54f9fa2b..3100c6e8 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -2578,11 +2578,21 @@ impl Tensor {
}
/// Returns log(sum(exp(tensor), dim)).
- pub fn logsumexp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
+ pub fn log_sum_exp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
let exp = self.exp()?;
let sum = exp.sum(sum_dims)?;
sum.log()
}
+
+ /// Pointwise pow operation.
+ pub fn pow(&self, rhs: &Tensor) -> Result<Self> {
+ rhs.mul(&self.log()?)?.exp()
+ }
+
+ /// Broadcasting version of `pow`.
+ pub fn broadcast_pow(&self, rhs: &Tensor) -> Result<Self> {
+ rhs.broadcast_mul(&self.log()?)?.exp()
+ }
}
macro_rules! bin_trait {