summaryrefslogtreecommitdiff
path: root/candle-metal-kernels
diff options
context:
space:
mode:
authorOlivierDehaene <Olivier.dehaene@gmail.com>2024-02-14 10:27:22 +0100
committerGitHub <noreply@github.com>2024-02-14 10:27:22 +0100
commitb60064780d09ab6733f5287b322ea5cb057d3136 (patch)
tree38bd95dc351a046dd5b7c67fdf9203a32b06e38b /candle-metal-kernels
parent14010a8498af3383b004be1f55a2fa39bce5389d (diff)
downloadcandle-b60064780d09ab6733f5287b322ea5cb057d3136.tar.gz
candle-b60064780d09ab6733f5287b322ea5cb057d3136.tar.bz2
candle-b60064780d09ab6733f5287b322ea5cb057d3136.zip
feat: add silu activation function (#1706)
* feat: add silu activation function * use silu/arg in grad * update candle-nn * use node
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r--candle-metal-kernels/src/lib.rs2
-rw-r--r--candle-metal-kernels/src/tests.rs19
-rw-r--r--candle-metal-kernels/src/unary.metal5
3 files changed, 25 insertions, 1 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 2d27d230..33bc3453 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -183,7 +183,7 @@ macro_rules! ops{
pub mod unary {
ops!(
cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, relu, round, erf, gelu_erf,
- tanh, recip
+ tanh, recip, silu
);
}
pub mod binary {
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs
index 655161e5..459c8edb 100644
--- a/candle-metal-kernels/src/tests.rs
+++ b/candle-metal-kernels/src/tests.rs
@@ -232,6 +232,25 @@ fn gelu_f32() {
}
#[test]
+fn silu_f16() {
+ let v: Vec<f16> = [-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0]
+ .iter()
+ .map(|v| f16::from_f32(*v))
+ .collect();
+ let expected: Vec<f32> = vec![-0.0, -0.27, 0.0, 0.73, 1.76, 2.86, 10.0, 20.0];
+ let results = run(&v, unary::contiguous::silu::HALF);
+ assert_eq!(approx_f16(results, 2), expected);
+}
+
+#[test]
+fn silu_f32() {
+ let v: Vec<f32> = vec![-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0];
+ let expected: Vec<f32> = vec![-0.0, -0.269, 0.0, 0.731, 1.762, 2.858, 10.0, 20.0];
+ let results = run(&v, unary::contiguous::silu::FLOAT);
+ assert_eq!(approx(results, 3), expected);
+}
+
+#[test]
fn binary_add_f32() {
let left = vec![1.0f32, 2.0, 3.0];
let right = vec![2.0f32, 3.1, 4.2];
diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal
index 7add58fd..1e0d5526 100644
--- a/candle-metal-kernels/src/unary.metal
+++ b/candle-metal-kernels/src/unary.metal
@@ -64,6 +64,9 @@ template <typename T> METAL_FUNC T relu(T in){
}
return in;
}
+template <typename T> METAL_FUNC T silu(T in){
+ return in / (static_cast<T>(1) + exp(-in));
+}
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
kernel void FN_NAME( \
@@ -108,6 +111,7 @@ UNARY_OP(neg)
UNARY_OP(exp)
UNARY_OP(log)
UNARY_OP(gelu)
+UNARY_OP(silu)
UNARY_OP(abs)
UNARY_OP(ceil)
UNARY_OP(floor)
@@ -135,6 +139,7 @@ BFLOAT_UNARY_OP(neg)
BFLOAT_UNARY_OP(exp)
BFLOAT_UNARY_OP(log)
BFLOAT_UNARY_OP(gelu)
+BFLOAT_UNARY_OP(silu)
BFLOAT_UNARY_OP(abs)
BFLOAT_UNARY_OP(ceil)
BFLOAT_UNARY_OP(floor)