diff options
Diffstat (limited to 'candle-nn')
-rw-r--r--  candle-nn/src/activation.rs | 4 ++++
-rw-r--r--  candle-nn/src/ops.rs        | 5 +++++
2 files changed, 9 insertions, 0 deletions
diff --git a/candle-nn/src/activation.rs b/candle-nn/src/activation.rs
index 79cf9c82..799e2ee2 100644
--- a/candle-nn/src/activation.rs
+++ b/candle-nn/src/activation.rs
@@ -13,7 +13,9 @@ pub enum Activation {
     Relu6,
     Silu,
     Sigmoid,
+    HardSigmoid,
     Swish,
+    HardSwish,
     Elu(f64),
     LeakyRelu(f64),
 }
@@ -29,7 +31,9 @@ impl super::Module for Activation {
             Self::Relu6 => xs.clamp(0f32, 6f32),
             Self::Silu => crate::ops::silu(xs),
             Self::Sigmoid => crate::ops::sigmoid(xs),
+            Self::HardSigmoid => crate::ops::hard_sigmoid(xs),
             Self::Swish => xs * crate::ops::sigmoid(xs)?,
+            Self::HardSwish => xs * crate::ops::hard_sigmoid(xs)?,
             &Self::Elu(alpha) => xs.elu(alpha),
             &Self::LeakyRelu(negative_slope) => crate::ops::leaky_relu(xs, negative_slope),
         }
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index e9812108..a51ef2e3 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -44,6 +44,11 @@ pub fn sigmoid(xs: &Tensor) -> Result<Tensor> {
     (xs.neg()?.exp()? + 1.0)?.recip()
 }
 
+pub fn hard_sigmoid(xs: &Tensor) -> Result<Tensor> {
+    // TODO: Should we have a specialized op for this?
+    ((xs + 3.0)? / 6.0)?.clamp(0f32, 1f32)
+}
+
 pub fn leaky_relu(xs: &Tensor, negative_slope: f64) -> Result<Tensor> {
     let zeros = xs.zeros_like()?;
     xs.maximum(&zeros)? + xs.minimum(&zeros)? * negative_slope