From 3bbb88fcb463a6bdbb0e71c7b2d211dd02681493 Mon Sep 17 00:00:00 2001
From: MilkFather <31627231+MilkFather@users.noreply.github.com>
Date: Mon, 29 Apr 2024 17:04:43 +0800
Subject: Fix sigmoid gradient calculation and move sigmoid into a specialized
 op (#2114)

* add sigmoid op

* small fix

* add as a method on `Tensor`

* implement gradient calculation for sigmoid

* add sigmoid tests

* we should have a specialized op for this

* fix clippy

* fix clippy 2

* Revert all previous commits in favor of a `CustomOp` based solution

* use `CustomOp1` implementation

* fix rustfmt

* add experimental metal impl

* add cuda kernel impl

* fix fmt

* Add a test + reduce some cuda duplication.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
---
 candle-metal-kernels/src/lib.rs      | 2 +-
 candle-metal-kernels/src/unary.metal | 5 +++++
 2 files changed, 6 insertions(+), 1 deletion(-)

(limited to 'candle-metal-kernels')

diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 8e075d5a..c08e44fe 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -129,7 +129,7 @@ macro_rules! ops{
 pub mod unary {
     ops!(
         cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, relu, round, erf, gelu_erf,
-        tanh, recip, silu, sign
+        tanh, recip, silu, sign, sigmoid
     );
 }
 pub mod binary {
diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal
index 143e9500..a82bfdbd 100644
--- a/candle-metal-kernels/src/unary.metal
+++ b/candle-metal-kernels/src/unary.metal
@@ -67,6 +67,9 @@ template <typename T> METAL_FUNC T relu(T in){
 template <typename T> METAL_FUNC T silu(T in){
     return in / (static_cast<T>(1) + exp(-in));
 }
+template <typename T> METAL_FUNC T sigmoid(T in) {
+    return recip(static_cast<T>(1) + exp(-in));
+}
 
 #define TILE_SIZE 2
 
@@ -155,6 +158,7 @@ UNARY_OP(tanh)
 UNARY_OP(recip)
 UNARY_OP(relu)
 UNARY_OP(sign)
+UNARY_OP(sigmoid)
 UNARY(id, float, copy_f32, copy_f32_strided)
 UNARY(id, half, copy_f16, copy_f16_strided)
 UNARY(id, uint8_t, copy_u8, copy_u8_strided)
@@ -185,6 +189,7 @@ BFLOAT_UNARY_OP(tanh)
 BFLOAT_UNARY_OP(recip)
 BFLOAT_UNARY_OP(relu)
 BFLOAT_UNARY_OP(sign)
+BFLOAT_UNARY_OP(sigmoid)
 
 UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
 
-- 
cgit v1.2.3