diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-11-02 20:01:34 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-11-02 20:01:34 +0100 |
commit | a2a20aeeccdcbcbd00e48d6f7ac97b2435b2378c (patch) | |
tree | 42fa7802e1b183670fc9b79aca60c09c28d69ff8 /candle-nn/src | |
parent | e08fbb654370ab465f01ba79f9f7b533cff03d15 (diff) | |
download | candle-a2a20aeeccdcbcbd00e48d6f7ac97b2435b2378c.tar.gz candle-a2a20aeeccdcbcbd00e48d6f7ac97b2435b2378c.tar.bz2 candle-a2a20aeeccdcbcbd00e48d6f7ac97b2435b2378c.zip |
Add the swiglu activation from the chatglm PR. (#1246)
Diffstat (limited to 'candle-nn/src')
-rw-r--r-- | candle-nn/src/activation.rs | 2 | ||||
-rw-r--r-- | candle-nn/src/ops.rs | 5 |
2 files changed, 7 insertions, 0 deletions
diff --git a/candle-nn/src/activation.rs b/candle-nn/src/activation.rs index 799e2ee2..77e709d2 100644 --- a/candle-nn/src/activation.rs +++ b/candle-nn/src/activation.rs @@ -14,6 +14,7 @@ pub enum Activation { Silu, Sigmoid, HardSigmoid, + Swiglu, Swish, HardSwish, Elu(f64), @@ -32,6 +33,7 @@ impl super::Module for Activation { Self::Silu => crate::ops::silu(xs), Self::Sigmoid => crate::ops::sigmoid(xs), Self::HardSigmoid => crate::ops::hard_sigmoid(xs), + Self::Swiglu => crate::ops::swiglu(xs), Self::Swish => xs * crate::ops::sigmoid(xs)?, Self::HardSwish => xs * crate::ops::hard_sigmoid(xs)?, &Self::Elu(alpha) => xs.elu(alpha), diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index a51ef2e3..a0269e59 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -39,6 +39,11 @@ pub fn silu(xs: &Tensor) -> Result<Tensor> { xs / (xs.neg()?.exp()? + 1.0)? } +pub fn swiglu(xs: &Tensor) -> Result<Tensor> { + let xs = xs.chunk(2, candle::D::Minus1)?; + crate::ops::silu(&xs[0])? * &xs[1] +} + pub fn sigmoid(xs: &Tensor) -> Result<Tensor> { // TODO: Should we have a specialized op for this? (xs.neg()?.exp()? + 1.0)?.recip() |