summaryrefslogtreecommitdiff
path: root/candle-metal-kernels
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-09-13 06:08:36 +0100
committerGitHub <noreply@github.com>2024-09-13 07:08:36 +0200
commitc09afc211c7e6177223883de263733240e3210fe (patch)
tree3c1b52a4525eecf236c70c2f0b52a2d87abc0f56 /candle-metal-kernels
parentb60faebea4b39cdeceb949ac5db464e05983b153 (diff)
downloadcandle-c09afc211c7e6177223883de263733240e3210fe.tar.gz
candle-c09afc211c7e6177223883de263733240e3210fe.tar.bz2
candle-c09afc211c7e6177223883de263733240e3210fe.zip
Fix for metal tanh. (#2475)
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r--candle-metal-kernels/src/unary.metal11
1 files changed, 8 insertions, 3 deletions
diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal
index a82bfdbd..e3a18cfe 100644
--- a/candle-metal-kernels/src/unary.metal
+++ b/candle-metal-kernels/src/unary.metal
@@ -56,7 +56,7 @@ template <typename T> METAL_FUNC T gelu(T x) {
T x_cube = x_sq * x;
T alpha = x + static_cast<T>(0.044715) * x_cube;
T beta = (static_cast<T>(M_2_SQRTPI_F * M_SQRT1_2_F) * alpha);
- return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(tanh(beta)));
+ return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(precise::tanh(beta)));
}
template <typename T> METAL_FUNC T relu(T in){
if (in < 0) {
@@ -154,7 +154,6 @@ UNARY_OP(floor)
UNARY_OP(round)
UNARY_OP(gelu_erf)
UNARY_OP(erf)
-UNARY_OP(tanh)
UNARY_OP(recip)
UNARY_OP(relu)
UNARY_OP(sign)
@@ -164,6 +163,11 @@ UNARY(id, half, copy_f16, copy_f16_strided)
UNARY(id, uint8_t, copy_u8, copy_u8_strided)
UNARY(id, uint32_t, copy_u32, copy_u32_strided)
+// tanh may create NaN on large values, e.g. 45 rather than outputing 1.
+// This has been an issue for the encodec example.
+UNARY(precise::tanh, float, tanh_f32, tanh_f32_strided);
+UNARY(precise::tanh, half, tanh_f16, tanh_f16_strided);
+
#if __METAL_VERSION__ >= 220
UNARY(id, int64_t, copy_i64, copy_i64_strided)
COPY2D(copy2d_i64, int64_t)
@@ -185,7 +189,6 @@ BFLOAT_UNARY_OP(floor)
BFLOAT_UNARY_OP(round)
BFLOAT_UNARY_OP(gelu_erf)
BFLOAT_UNARY_OP(erf)
-BFLOAT_UNARY_OP(tanh)
BFLOAT_UNARY_OP(recip)
BFLOAT_UNARY_OP(relu)
BFLOAT_UNARY_OP(sign)
@@ -193,5 +196,7 @@ BFLOAT_UNARY_OP(sigmoid)
UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
+UNARY(precise::tanh, bfloat, tanh_bf16, tanh_bf16_strided);
+
COPY2D(copy2d_bf16, bfloat)
#endif