summaryrefslogtreecommitdiff
path: root/candle-metal-kernels
diff options
context:
space:
mode:
authorJuarez Bochi <jbochi@gmail.com>2024-01-10 12:27:17 -0500
committerGitHub <noreply@github.com>2024-01-10 18:27:17 +0100
commitae06cb74bb132913b4777cd119b915e665c013bb (patch)
tree55b679ad5f4ea42bb5d863ccca6483c925751dc7 /candle-metal-kernels
parenta897fda74e372ff0e08c86a5468124b51f5941a7 (diff)
downloadcandle-ae06cb74bb132913b4777cd119b915e665c013bb.tar.gz
candle-ae06cb74bb132913b4777cd119b915e665c013bb.tar.bz2
candle-ae06cb74bb132913b4777cd119b915e665c013bb.zip
Add relu kernel for metal (#1488)
* Add relu kernel for metal * Copy error messages proposed in #1491 * Revert non relu changes * Fix name changes * Fix the last of us (: * Fix copy and paste mistakes * Fix typo * Revert order changes * Revert order change * Add deleted functions back * Run rustfmt
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r--candle-metal-kernels/src/lib.rs4
-rw-r--r--candle-metal-kernels/src/unary.metal8
2 files changed, 10 insertions, 2 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 5d34f61a..c872dc60 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -174,8 +174,8 @@ macro_rules! ops{
pub mod unary {
ops!(
- cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, round, erf, gelu_erf, tanh,
- recip
+ cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, relu, round, erf, gelu_erf,
+ tanh, recip
);
}
pub mod binary {
diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal
index 7fbb613d..f95f6ba9 100644
--- a/candle-metal-kernels/src/unary.metal
+++ b/candle-metal-kernels/src/unary.metal
@@ -58,6 +58,12 @@ template <typename T> METAL_FUNC T gelu(T x) {
T beta = (static_cast<T>(M_2_SQRTPI_F * M_SQRT1_2_F) * alpha);
return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(tanh(beta)));
}
+template <typename T> METAL_FUNC T relu(T in){
+ if (in < 0) {
+ return 0;
+ }
+ return in;
+}
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
kernel void FN_NAME( \
@@ -110,6 +116,7 @@ UNARY_OP(gelu_erf)
UNARY_OP(erf)
UNARY_OP(tanh)
UNARY_OP(recip)
+UNARY_OP(relu)
UNARY(id, float, copy_f32, copy_f32_strided)
UNARY(id, half, copy_f16, copy_f16_strided)
@@ -136,6 +143,7 @@ BFLOAT_UNARY_OP(gelu_erf)
BFLOAT_UNARY_OP(erf)
BFLOAT_UNARY_OP(tanh)
BFLOAT_UNARY_OP(recip)
+BFLOAT_UNARY_OP(relu)
UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
#endif