diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-02-28 21:02:41 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-02-28 21:02:41 +0100 |
commit | 4fd00b890036ef67391a9cc03f896247d0a75711 (patch) | |
tree | 19121cce4cab5406e9fda202de71a32427c096bd /candle-nn | |
parent | 57267cd53612ede04090853680125b17956804f3 (diff) | |
download | candle-4fd00b890036ef67391a9cc03f896247d0a75711.tar.gz candle-4fd00b890036ef67391a9cc03f896247d0a75711.tar.bz2 candle-4fd00b890036ef67391a9cc03f896247d0a75711.zip |
Add the StarCoder2 model. (#1779)
* Add the StarCoder2 model.
* Add the example code and get things to work.
* And also tweak the readme.
Diffstat (limited to 'candle-nn')
-rw-r--r-- | candle-nn/src/activation.rs | 4 |
1 file changed, 4 insertions, 0 deletions
```diff
diff --git a/candle-nn/src/activation.rs b/candle-nn/src/activation.rs
index 60a7a6d1..b9745375 100644
--- a/candle-nn/src/activation.rs
+++ b/candle-nn/src/activation.rs
@@ -5,6 +5,7 @@ use serde::Deserialize;
 #[serde(rename_all = "lowercase")]
 pub enum Activation {
     #[default]
+    #[serde(alias = "gelu")]
     Gelu,
     #[serde(alias = "gelu_new")]
     NewGelu,
@@ -19,6 +20,8 @@ pub enum Activation {
     HardSwish,
     Elu(f64),
     LeakyRelu(f64),
+    #[serde(alias = "gelu_pytorch_tanh")]
+    GeluPytorchTanh,
 }
 
 impl super::Module for Activation {
@@ -38,6 +41,7 @@ impl super::Module for Activation {
             Self::HardSwish => xs * crate::ops::hard_sigmoid(xs)?,
             &Self::Elu(alpha) => xs.elu(alpha),
             &Self::LeakyRelu(negative_slope) => crate::ops::leaky_relu(xs, negative_slope),
+            Self::GeluPytorchTanh => xs.gelu(),
         }
     }
 }
```