diff options
author | Juarez Bochi <jbochi@gmail.com> | 2024-01-10 12:27:17 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-10 18:27:17 +0100 |
commit | ae06cb74bb132913b4777cd119b915e665c013bb (patch) | |
tree | 55b679ad5f4ea42bb5d863ccca6483c925751dc7 /candle-metal-kernels | |
parent | a897fda74e372ff0e08c86a5468124b51f5941a7 (diff) | |
download | candle-ae06cb74bb132913b4777cd119b915e665c013bb.tar.gz candle-ae06cb74bb132913b4777cd119b915e665c013bb.tar.bz2 candle-ae06cb74bb132913b4777cd119b915e665c013bb.zip |
Add relu kernel for metal (#1488)
* Add relu kernel for metal
* Copy error messages proposed in #1491
* Revert non relu changes
* Fix name changes
* Fix the last of us (:
* Fix copy and paste mistakes
* Fix typo
* Revert order changes
* Revert order change
* Add deleted functions back
* Run rustfmt
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r-- | candle-metal-kernels/src/lib.rs | 4 | ||||
-rw-r--r-- | candle-metal-kernels/src/unary.metal | 8 |
2 files changed, 10 insertions, 2 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 5d34f61a..c872dc60 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -174,8 +174,8 @@ macro_rules! ops{ pub mod unary { ops!( - cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, round, erf, gelu_erf, tanh, - recip + cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, relu, round, erf, gelu_erf, + tanh, recip ); } pub mod binary { diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index 7fbb613d..f95f6ba9 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -58,6 +58,12 @@ template <typename T> METAL_FUNC T gelu(T x) { T beta = (static_cast<T>(M_2_SQRTPI_F * M_SQRT1_2_F) * alpha); return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(tanh(beta))); } +template <typename T> METAL_FUNC T relu(T in){ + if (in < 0) { + return 0; + } + return in; +} #define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \ kernel void FN_NAME( \ @@ -110,6 +116,7 @@ UNARY_OP(gelu_erf) UNARY_OP(erf) UNARY_OP(tanh) UNARY_OP(recip) +UNARY_OP(relu) UNARY(id, float, copy_f32, copy_f32_strided) UNARY(id, half, copy_f16, copy_f16_strided) @@ -136,6 +143,7 @@ BFLOAT_UNARY_OP(gelu_erf) BFLOAT_UNARY_OP(erf) BFLOAT_UNARY_OP(tanh) BFLOAT_UNARY_OP(recip) +BFLOAT_UNARY_OP(relu) UNARY(id, bfloat, copy_bf16, copy_bf16_strided) #endif |