summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/tensor.rs12
-rw-r--r--candle-core/tests/tensor_tests.rs16
-rw-r--r--candle-onnx/src/eval.rs6
3 files changed, 31 insertions, 3 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 54f9fa2b..3100c6e8 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -2578,11 +2578,21 @@ impl Tensor {
}
/// Returns log(sum(exp(tensor), dim)).
- pub fn logsumexp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
+ pub fn log_sum_exp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
let exp = self.exp()?;
let sum = exp.sum(sum_dims)?;
sum.log()
}
+
+ /// Pointwise pow operation.
+ pub fn pow(&self, rhs: &Tensor) -> Result<Self> {
+ rhs.mul(&self.log()?)?.exp()
+ }
+
+ /// Broadcasting version of `pow`.
+ pub fn broadcast_pow(&self, rhs: &Tensor) -> Result<Self> {
+ rhs.broadcast_mul(&self.log()?)?.exp()
+ }
}
macro_rules! bin_trait {
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index e83fb55b..33bab1b6 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -1245,11 +1245,23 @@ fn assert_close(a: &Tensor, b: &Tensor, epsilon: f64) -> Result<()> {
}
#[test]
-fn logsumexp() -> Result<()> {
+fn log_sum_exp() -> Result<()> {
let input = Tensor::new(&[[1f64, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
- let output = input.logsumexp(D::Minus1)?;
+ let output = input.log_sum_exp(D::Minus1)?;
// The expectations obtained from pytorch.
let expected = Tensor::new(&[3.4076, 6.4076], &Device::Cpu)?;
assert_close(&output, &expected, 0.00001)?;
Ok(())
}
+
+#[test]
+fn pow() -> Result<()> {
+ let lhs = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
+ let rhs = (&lhs - 2.)?;
+ let res = lhs.pow(&rhs)?;
+ assert_eq!(
+ test_utils::to_vec2_round(&res, 4)?,
+ [[1.0, 1.0, 3.0], [16.0, 125.0, 1296.0001]]
+ );
+ Ok(())
+}
diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs
index 684776c2..c0ad8668 100644
--- a/candle-onnx/src/eval.rs
+++ b/candle-onnx/src/eval.rs
@@ -254,6 +254,12 @@ pub fn simple_eval(
let output = input0.broadcast_div(input1)?;
values.insert(node.output[0].clone(), output);
}
+ "Pow" => {
+ let input0 = get(&node.input[0])?;
+ let input1 = get(&node.input[1])?;
+ let output = input0.broadcast_pow(input1)?;
+ values.insert(node.output[0].clone(), output);
+ }
"Equal" => {
let input0 = get(&node.input[0])?;
let input1 = get(&node.input[1])?;