summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicolas Patry <patry.nicolas@protonmail.com>2023-12-09 19:46:36 +0100
committerGitHub <noreply@github.com>2023-12-09 19:46:36 +0100
commitda0af3cb3e58d38476a20f4465744093a3b75dd4 (patch)
treec31b8520814719563b3d50eda66637dbf4b2e785
parent803ac8405b49fbfc4e5aacca6e70f7955386df39 (diff)
parent6e25822d4fcd3321f1e078706683b990780ba1ae (diff)
downloadcandle-da0af3cb3e58d38476a20f4465744093a3b75dd4.tar.gz
candle-da0af3cb3e58d38476a20f4465744093a3b75dd4.tar.bz2
candle-da0af3cb3e58d38476a20f4465744093a3b75dd4.zip
Merge pull request #1408 from jbochi/metal_gelu2
Fix NaN errors for Gelu in Metal
-rw-r--r--candle-metal-kernels/src/tests.rs23
-rw-r--r--candle-metal-kernels/src/unary.metal11
2 files changed, 29 insertions, 5 deletions
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs
index 59f54fa9..37b07167 100644
--- a/candle-metal-kernels/src/tests.rs
+++ b/candle-metal-kernels/src/tests.rs
@@ -206,6 +206,25 @@ fn cos_strided_random() {
}
#[test]
+fn gelu_f16() {
+ let v: Vec<f16> = [-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0]
+ .iter()
+ .map(|v| f16::from_f32(*v))
+ .collect();
+ let expected: Vec<f32> = vec![-0.0, -0.16, 0.0, 0.84, 1.96, 3.0, 10.0, 20.0];
+ let results = run(&v, unary::contiguous::gelu::HALF);
+ assert_eq!(approx_f16(results, 2), expected);
+}
+
+#[test]
+fn gelu_f32() {
+ let v: Vec<f32> = vec![-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0];
+ let expected: Vec<f32> = vec![-0.0, -0.159, 0.0, 0.841, 1.955, 2.996, 10.0, 20.0];
+ let results = run(&v, unary::contiguous::gelu::FLOAT);
+ assert_eq!(approx(results, 3), expected);
+}
+
+#[test]
fn binary_add_f32() {
let left = vec![1.0f32, 2.0, 3.0];
let right = vec![2.0f32, 3.1, 4.2];
@@ -527,8 +546,8 @@ fn cos_f16() {
.collect();
let results = run(&v, unary::contiguous::cos::HALF);
let expected: Vec<f16> = v.iter().map(|v| f16::from_f32(v.to_f32().cos())).collect();
- assert_eq!(approx_f16(results, 4), vec![0.5405, -0.4163, -0.9902]);
- assert_eq!(approx_f16(expected, 4), vec![0.5405, -0.4163, -0.9902]);
+ assert_eq!(approx_f16(results, 2), vec![0.54, -0.42, -0.99]);
+ assert_eq!(approx_f16(expected, 2), vec![0.54, -0.42, -0.99]);
}
fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T> {
diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal
index 88139af9..529162bd 100644
--- a/candle-metal-kernels/src/unary.metal
+++ b/candle-metal-kernels/src/unary.metal
@@ -42,9 +42,14 @@ template <typename T> METAL_FUNC T erf(T in){
return T(sign*y);
}
-template <typename T> METAL_FUNC T id(T in){ return in; }
-template <typename T> METAL_FUNC T gelu_erf(T x){ return T(x * (1 + erf(x * M_SQRT1_2_F)) / 2); }
-template <typename T> METAL_FUNC T gelu(T x){
+template <typename T> METAL_FUNC T id(T in) { return in; }
+template <typename T> METAL_FUNC T gelu_erf(T x) {
+ return T(x * (1 + erf(x * M_SQRT1_2_F)) / 2);
+}
+template <typename T> METAL_FUNC T gelu(T x) {
+ if (x > 5) {
+ return x;
+ }
T x_sq = x * x;
T x_cube = x_sq * x;
T alpha = x + static_cast<T>(0.044715) * x_cube;