diff options
author | OlivierDehaene <Olivier.dehaene@gmail.com> | 2024-02-14 10:27:22 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-02-14 10:27:22 +0100 |
commit | b60064780d09ab6733f5287b322ea5cb057d3136 (patch) | |
tree | 38bd95dc351a046dd5b7c67fdf9203a32b06e38b /candle-metal-kernels | |
parent | 14010a8498af3383b004be1f55a2fa39bce5389d (diff) | |
download | candle-b60064780d09ab6733f5287b322ea5cb057d3136.tar.gz candle-b60064780d09ab6733f5287b322ea5cb057d3136.tar.bz2 candle-b60064780d09ab6733f5287b322ea5cb057d3136.zip |
feat: add silu activation function (#1706)
* feat: add silu activation function
* use silu/arg in grad
* update candle-nn
* use node
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r-- | candle-metal-kernels/src/lib.rs | 2 | ||||
-rw-r--r-- | candle-metal-kernels/src/tests.rs | 19 | ||||
-rw-r--r-- | candle-metal-kernels/src/unary.metal | 5 |
3 files changed, 25 insertions, 1 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 2d27d230..33bc3453 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -183,7 +183,7 @@ macro_rules! ops{ pub mod unary { ops!( cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, relu, round, erf, gelu_erf, - tanh, recip + tanh, recip, silu ); } pub mod binary { diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 655161e5..459c8edb 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -232,6 +232,25 @@ fn gelu_f32() { } #[test] +fn silu_f16() { + let v: Vec<f16> = [-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0] + .iter() + .map(|v| f16::from_f32(*v)) + .collect(); + let expected: Vec<f32> = vec![-0.0, -0.27, 0.0, 0.73, 1.76, 2.86, 10.0, 20.0]; + let results = run(&v, unary::contiguous::silu::HALF); + assert_eq!(approx_f16(results, 2), expected); +} + +#[test] +fn silu_f32() { + let v: Vec<f32> = vec![-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0]; + let expected: Vec<f32> = vec![-0.0, -0.269, 0.0, 0.731, 1.762, 2.858, 10.0, 20.0]; + let results = run(&v, unary::contiguous::silu::FLOAT); + assert_eq!(approx(results, 3), expected); +} + +#[test] fn binary_add_f32() { let left = vec![1.0f32, 2.0, 3.0]; let right = vec![2.0f32, 3.1, 4.2]; diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index 7add58fd..1e0d5526 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -64,6 +64,9 @@ template <typename T> METAL_FUNC T relu(T in){ } return in; } +template <typename T> METAL_FUNC T silu(T in){ + return in / (static_cast<T>(1) + exp(-in)); +} #define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \ kernel void FN_NAME( \ @@ -108,6 +111,7 @@ UNARY_OP(neg) UNARY_OP(exp) UNARY_OP(log) UNARY_OP(gelu) +UNARY_OP(silu) UNARY_OP(abs) UNARY_OP(ceil) UNARY_OP(floor) @@ -135,6 +139,7 @@ BFLOAT_UNARY_OP(neg) BFLOAT_UNARY_OP(exp) BFLOAT_UNARY_OP(log) BFLOAT_UNARY_OP(gelu) +BFLOAT_UNARY_OP(silu) BFLOAT_UNARY_OP(abs) BFLOAT_UNARY_OP(ceil) BFLOAT_UNARY_OP(floor) |