| author | jamjamjon <51357717+jamjamjon@users.noreply.github.com> | 2023-11-03 01:20:27 +0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-11-02 18:20:27 +0100 |
| commit | d39d0c40fd27048cd8641d84f3e1da8b685302dd (patch) | |
| tree | e4c6132d901cfea9908f9c03845b5705908daf80 /candle-nn | |
| parent | b97463098ceeb80feabfd385d7062dfdb55068ee (diff) | |
Add hard-sigmoid and hard-swish activations (#1244)
* Add hard-sigmoid and hard-swish activations
* Update ops.rs
* Use / rather than div.
---------
Co-authored-by: Laurent <laurent.mazare@gmail.com>
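
For reference, the two activations added here are the usual piecewise-linear approximations of sigmoid and swish, exactly as implemented in the diff below:

$$
\operatorname{hard\_sigmoid}(x) = \operatorname{clamp}\!\left(\frac{x + 3}{6},\; 0,\; 1\right),
\qquad
\operatorname{hard\_swish}(x) = x \cdot \operatorname{hard\_sigmoid}(x)
$$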
Diffstat (limited to 'candle-nn')
-rw-r--r-- | candle-nn/src/activation.rs | 4 |
-rw-r--r-- | candle-nn/src/ops.rs | 5 |
2 files changed, 9 insertions, 0 deletions
diff --git a/candle-nn/src/activation.rs b/candle-nn/src/activation.rs
index 79cf9c82..799e2ee2 100644
--- a/candle-nn/src/activation.rs
+++ b/candle-nn/src/activation.rs
@@ -13,7 +13,9 @@ pub enum Activation {
     Relu6,
     Silu,
     Sigmoid,
+    HardSigmoid,
     Swish,
+    HardSwish,
     Elu(f64),
     LeakyRelu(f64),
 }
@@ -29,7 +31,9 @@ impl super::Module for Activation {
             Self::Relu6 => xs.clamp(0f32, 6f32),
             Self::Silu => crate::ops::silu(xs),
             Self::Sigmoid => crate::ops::sigmoid(xs),
+            Self::HardSigmoid => crate::ops::hard_sigmoid(xs),
             Self::Swish => xs * crate::ops::sigmoid(xs)?,
+            Self::HardSwish => xs * crate::ops::hard_sigmoid(xs)?,
             &Self::Elu(alpha) => xs.elu(alpha),
             &Self::LeakyRelu(negative_slope) => crate::ops::leaky_relu(xs, negative_slope),
         }
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index e9812108..a51ef2e3 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -44,6 +44,11 @@ pub fn sigmoid(xs: &Tensor) -> Result<Tensor> {
     (xs.neg()?.exp()? + 1.0)?.recip()
 }
 
+pub fn hard_sigmoid(xs: &Tensor) -> Result<Tensor> {
+    // TODO: Should we have a specialized op for this?
+    ((xs + 3.0)? / 6.0)?.clamp(0f32, 1f32)
+}
+
 pub fn leaky_relu(xs: &Tensor, negative_slope: f64) -> Result<Tensor> {
     let zeros = xs.zeros_like()?;
     xs.maximum(&zeros)? + xs.minimum(&zeros)? * negative_slope
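
Below is a minimal sketch of how the new ops could be exercised once this commit is in, assuming the usual `candle_core`/`candle_nn` crate names and the `Module` trait being reachable from the `candle_nn` root (neither of which is shown in this diff):

```rust
use candle_core::{Device, Result, Tensor};
use candle_nn::{Activation, Module};

fn main() -> Result<()> {
    // A few sample points on CPU.
    let xs = Tensor::new(&[-4f32, -1.0, 0.0, 1.0, 4.0], &Device::Cpu)?;

    // Direct call: hard_sigmoid(x) = clamp((x + 3) / 6, 0, 1).
    let hs = candle_nn::ops::hard_sigmoid(&xs)?;
    println!("hard_sigmoid: {hs}");

    // Via the enum: HardSwish forwards to xs * hard_sigmoid(xs).
    let hw = Activation::HardSwish.forward(&xs)?;
    println!("hard_swish: {hw}");

    Ok(())
}
```

`Activation::HardSigmoid.forward(&xs)` would likewise dispatch to `ops::hard_sigmoid`, mirroring the match arms added above.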