summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/tests/tensor_tests.rs13
-rw-r--r--candle-metal-kernels/src/affine.metal2
2 files changed, 14 insertions, 1 deletions
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index 567b49f1..4a76035c 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -193,6 +193,19 @@ fn unary_op(device: &Device) -> Result<()> {
tensor.sign()?.to_vec1::<f32>()?,
[-1., -1., -1., 0., 0., 1., 1., 1., 1.]
);
+ let tensor = Tensor::new(&[-1.0f32, 0., -2., 3.], device)?;
+ let y = tensor.elu(2.)?;
+ assert_eq!(
+ test_utils::to_vec1_round(&y, 4)?,
+ [-1.2642, 0.0000, -1.7293, 3.0000]
+ );
+ // This test failed on metal prior to the following PR:
+ // https://github.com/huggingface/candle/pull/2490
+ let y = tensor.reshape((2, 2))?.t()?.elu(2.)?.flatten_all()?;
+ assert_eq!(
+ test_utils::to_vec1_round(&y, 4)?,
+ [-1.2642, -1.7293, 0.0000, 3.0000]
+ );
Ok(())
}
diff --git a/candle-metal-kernels/src/affine.metal b/candle-metal-kernels/src/affine.metal
index cbbb03e2..e5229f55 100644
--- a/candle-metal-kernels/src/affine.metal
+++ b/candle-metal-kernels/src/affine.metal
@@ -105,7 +105,7 @@ kernel void FN_NAME##_strided( \
return; \
} \
const TYPENAME x = input[get_strided_index(id, num_dims, dims, strides)]; \
- output[id] = TYPENAME((x > 0)?x: mul * exp(x - 1)); \
+ output[id] = TYPENAME((x > 0)?x: mul * (exp(x) - 1)); \
} \