diff options
author | Nicolas Patry <patry.nicolas@protonmail.com> | 2023-12-09 19:46:36 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-12-09 19:46:36 +0100 |
commit | da0af3cb3e58d38476a20f4465744093a3b75dd4 (patch) | |
tree | c31b8520814719563b3d50eda66637dbf4b2e785 | |
parent | 803ac8405b49fbfc4e5aacca6e70f7955386df39 (diff) | |
parent | 6e25822d4fcd3321f1e078706683b990780ba1ae (diff) | |
download | candle-da0af3cb3e58d38476a20f4465744093a3b75dd4.tar.gz candle-da0af3cb3e58d38476a20f4465744093a3b75dd4.tar.bz2 candle-da0af3cb3e58d38476a20f4465744093a3b75dd4.zip |
Merge pull request #1408 from jbochi/metal_gelu2
Fix NaN errors for Gelu in Metal
-rw-r--r-- | candle-metal-kernels/src/tests.rs | 23 | ||||
-rw-r--r-- | candle-metal-kernels/src/unary.metal | 11 |
2 files changed, 29 insertions, 5 deletions
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 59f54fa9..37b07167 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -206,6 +206,25 @@ fn cos_strided_random() { } #[test] +fn gelu_f16() { + let v: Vec<f16> = [-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0] + .iter() + .map(|v| f16::from_f32(*v)) + .collect(); + let expected: Vec<f32> = vec![-0.0, -0.16, 0.0, 0.84, 1.96, 3.0, 10.0, 20.0]; + let results = run(&v, unary::contiguous::gelu::HALF); + assert_eq!(approx_f16(results, 2), expected); +} + +#[test] +fn gelu_f32() { + let v: Vec<f32> = vec![-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0]; + let expected: Vec<f32> = vec![-0.0, -0.159, 0.0, 0.841, 1.955, 2.996, 10.0, 20.0]; + let results = run(&v, unary::contiguous::gelu::FLOAT); + assert_eq!(approx(results, 3), expected); +} + +#[test] fn binary_add_f32() { let left = vec![1.0f32, 2.0, 3.0]; let right = vec![2.0f32, 3.1, 4.2]; @@ -527,8 +546,8 @@ fn cos_f16() { .collect(); let results = run(&v, unary::contiguous::cos::HALF); let expected: Vec<f16> = v.iter().map(|v| f16::from_f32(v.to_f32().cos())).collect(); - assert_eq!(approx_f16(results, 4), vec![0.5405, -0.4163, -0.9902]); - assert_eq!(approx_f16(expected, 4), vec![0.5405, -0.4163, -0.9902]); + assert_eq!(approx_f16(results, 2), vec![0.54, -0.42, -0.99]); + assert_eq!(approx_f16(expected, 2), vec![0.54, -0.42, -0.99]); } fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T> { diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index 88139af9..529162bd 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -42,9 +42,14 @@ template <typename T> METAL_FUNC T erf(T in){ return T(sign*y); } -template <typename T> METAL_FUNC T id(T in){ return in; } -template <typename T> METAL_FUNC T gelu_erf(T x){ return T(x * (1 + erf(x * M_SQRT1_2_F)) / 2); } -template <typename T> METAL_FUNC T gelu(T x){ +template <typename T> METAL_FUNC T id(T in) { return in; } +template <typename T> METAL_FUNC T gelu_erf(T x) { + return T(x * (1 + erf(x * M_SQRT1_2_F)) / 2); +} +template <typename T> METAL_FUNC T gelu(T x) { + if (x > 5) { + return x; + } T x_sq = x * x; T x_cube = x_sq * x; T alpha = x + static_cast<T>(0.044715) * x_cube; |