diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-09-13 06:08:36 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-09-13 07:08:36 +0200 |
commit | c09afc211c7e6177223883de263733240e3210fe (patch) | |
tree | 3c1b52a4525eecf236c70c2f0b52a2d87abc0f56 /candle-metal-kernels | |
parent | b60faebea4b39cdeceb949ac5db464e05983b153 (diff) | |
download | candle-c09afc211c7e6177223883de263733240e3210fe.tar.gz candle-c09afc211c7e6177223883de263733240e3210fe.tar.bz2 candle-c09afc211c7e6177223883de263733240e3210fe.zip |
Fix for metal tanh. (#2475)
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r-- | candle-metal-kernels/src/unary.metal | 11 |
1 files changed, 8 insertions, 3 deletions
diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index a82bfdbd..e3a18cfe 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -56,7 +56,7 @@ template <typename T> METAL_FUNC T gelu(T x) { T x_cube = x_sq * x; T alpha = x + static_cast<T>(0.044715) * x_cube; T beta = (static_cast<T>(M_2_SQRTPI_F * M_SQRT1_2_F) * alpha); - return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(tanh(beta))); + return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(precise::tanh(beta))); } template <typename T> METAL_FUNC T relu(T in){ if (in < 0) { @@ -154,7 +154,6 @@ UNARY_OP(floor) UNARY_OP(round) UNARY_OP(gelu_erf) UNARY_OP(erf) -UNARY_OP(tanh) UNARY_OP(recip) UNARY_OP(relu) UNARY_OP(sign) @@ -164,6 +163,11 @@ UNARY(id, half, copy_f16, copy_f16_strided) UNARY(id, uint8_t, copy_u8, copy_u8_strided) UNARY(id, uint32_t, copy_u32, copy_u32_strided) +// tanh may create NaN on large values, e.g. 45 rather than outputing 1. +// This has been an issue for the encodec example. +UNARY(precise::tanh, float, tanh_f32, tanh_f32_strided); +UNARY(precise::tanh, half, tanh_f16, tanh_f16_strided); + #if __METAL_VERSION__ >= 220 UNARY(id, int64_t, copy_i64, copy_i64_strided) COPY2D(copy2d_i64, int64_t) @@ -185,7 +189,6 @@ BFLOAT_UNARY_OP(floor) BFLOAT_UNARY_OP(round) BFLOAT_UNARY_OP(gelu_erf) BFLOAT_UNARY_OP(erf) -BFLOAT_UNARY_OP(tanh) BFLOAT_UNARY_OP(recip) BFLOAT_UNARY_OP(relu) BFLOAT_UNARY_OP(sign) @@ -193,5 +196,7 @@ BFLOAT_UNARY_OP(sigmoid) UNARY(id, bfloat, copy_bf16, copy_bf16_strided) +UNARY(precise::tanh, bfloat, tanh_bf16, tanh_bf16_strided); + COPY2D(copy2d_bf16, bfloat) #endif |